Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/mixed_precision/device_compatibility_check.py: 22%

55 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Contains function to log if devices are compatible with mixed precision.""" 

16 

17import itertools 

18 

19import tensorflow.compat.v2 as tf 

20 

21# isort: off 

22from tensorflow.python.platform import tf_logging 

23 

# Common leading text of every message logged by the mixed_float16 check.
_COMPAT_CHECK_PREFIX = "Mixed precision compatibility check (mixed_float16): "
# Header used for the informational "all GPUs are fine" messages.
_COMPAT_CHECK_OK_PREFIX = f"{_COMPAT_CHECK_PREFIX}OK"
# Header used for every warning message.
_COMPAT_CHECK_WARNING_PREFIX = f"{_COMPAT_CHECK_PREFIX}WARNING"
# Trailing note appended to the end of each warning this module logs.
_COMPAT_CHECK_WARNING_SUFFIX = " ".join(
    (
        "If you will use compatible GPU(s) not attached to this host, e.g. by",
        "running a multi-worker model, you can ignore this warning. This message",
        "will only be logged once",
    )
)

32 

33 

def _dedup_strings(device_strs):
    """Collapses each run of equal adjacent strings into one annotated entry.

    A run of length one is kept as-is; a run of N > 1 copies becomes
    "<string> (xN)". Only adjacent duplicates are merged, so equal strings
    separated by a different string remain separate entries.

    Example:
      ['GPU 1', 'GPU 2', 'GPU 2', 'GPU 3', 'GPU 3', 'GPU 3']
      -> ['GPU 1', 'GPU 2 (x2)', 'GPU 3 (x3)']

    Args:
      device_strs: A list of strings, each representing a device.

    Returns:
      A new list in which every run of identical consecutive strings is
      replaced by a single string.
    """
    merged = []
    for text, run in itertools.groupby(device_strs):
        # Count the run lazily instead of materialising it as a list.
        count = sum(1 for _ in run)
        merged.append(text if count == 1 else "%s (x%d)" % (text, count))
    return merged

57 

58 

def _log_device_compatibility_check(policy_name, gpu_details_list):
    """Logs whether the given GPUs are expected to run a dtype policy quickly.

    Only the policy mixed_float16 is checked; for any other policy name the
    function returns immediately without logging.

    Each GPU is classified by its "compute_capability" entry: a capability
    comparing >= (7, 0) counts as supported, while a lower capability or a
    missing/falsy one counts as unsupported. Exactly one message is then
    logged:

    * a warning listing every GPU, if at least one GPU is unsupported;
    * a warning that the machine has no GPU, if the list is empty;
    * an info message otherwise (every GPU is supported).

    Args:
      policy_name: The name of the dtype policy.
      gpu_details_list: A list of dicts, one dict per GPU. Each dict
        is the device details for a GPU, as returned by
        `tf.config.experimental.get_device_details()`. The keys read are
        "device_name" and "compute_capability".
    """
    if policy_name != "mixed_float16":
        # TODO(b/145686977): Log if the policy is 'mixed_bfloat16'. This
        # requires checking if a TPU is available.
        return

    fast_gpus = []
    slow_gpus = []
    for gpu in gpu_details_list:
        gpu_name = gpu.get("device_name", "Unknown GPU")
        capability = gpu.get("compute_capability")
        if not capability:
            # Without a compute capability the GPU is always reported as
            # unsupported.
            slow_gpus.append(
                gpu_name
                + ", no compute capability (probably not an Nvidia GPU)"
            )
            continue
        label = (
            f"{gpu_name}, compute capability "
            f"{capability[0]}.{capability[1]}"
        )
        target = fast_gpus if capability >= (7, 0) else slow_gpus
        target.append(label)

    if slow_gpus:
        # At least one GPU falls short: warn and list every GPU (supported
        # ones first), with consecutive duplicates collapsed.
        if fast_gpus:
            reason = (
                "Some of your GPUs may run slowly with dtype policy "
                "mixed_float16 because they do not all have compute "
                "capability of at least 7.0. Your GPUs:\n"
            )
        elif len(slow_gpus) == 1:
            reason = (
                "Your GPU may run slowly with dtype policy mixed_float16 "
                "because it does not have compute capability of at least "
                "7.0. Your GPU:\n"
            )
        else:
            reason = (
                "Your GPUs may run slowly with dtype policy "
                "mixed_float16 because they do not have compute "
                "capability of at least 7.0. Your GPUs:\n"
            )
        listing = "".join(
            " " + entry + "\n"
            for entry in _dedup_strings(fast_gpus + slow_gpus)
        )
        nvidia_hint = (
            "See https://developer.nvidia.com/cuda-gpus for a list of "
            "GPUs and their compute capabilities.\n"
        )
        tf_logging.warning(
            "".join(
                (
                    _COMPAT_CHECK_WARNING_PREFIX,
                    "\n",
                    reason,
                    listing,
                    nvidia_hint,
                    _COMPAT_CHECK_WARNING_SUFFIX,
                )
            )
        )
        return

    if not fast_gpus:
        # Reached only when gpu_details_list was empty.
        tf_logging.warning(
            f"{_COMPAT_CHECK_WARNING_PREFIX}\n"
            "The dtype policy mixed_float16 may run slowly because "
            "this machine does not have a GPU. Only Nvidia GPUs with "
            "compute capability of at least 7.0 run quickly with "
            f"mixed_float16.\n{_COMPAT_CHECK_WARNING_SUFFIX}"
        )
    elif len(fast_gpus) == 1:
        tf_logging.info(
            f"{_COMPAT_CHECK_OK_PREFIX}\n"
            "Your GPU will likely run quickly with dtype policy "
            "mixed_float16 as it has compute capability of at least "
            f"7.0. Your GPU: {fast_gpus[0]}"
        )
    else:
        tf_logging.info(
            f"{_COMPAT_CHECK_OK_PREFIX}\n"
            "Your GPUs will likely run quickly with dtype policy "
            "mixed_float16 as they all have compute capability of at "
            "least 7.0"
        )

144 

145 

# Flipped to True by the first call to log_device_compatibility_check; later
# calls see it and return without doing anything.
_logged_compatibility_check = False


def log_device_compatibility_check(policy_name):
    """Logs, at most once, whether the local GPUs suit the given dtype policy.

    Only the first call does any work: it marks the check as done, queries
    the physical GPUs and their device details through `tf.config`, and
    passes the details on to the compatibility logger. Every later call is a
    no-op, whatever `policy_name` it receives. Only the policy mixed_float16
    actually produces a log message.

    Args:
      policy_name: The name of the dtype policy.
    """
    global _logged_compatibility_check
    if not _logged_compatibility_check:
        # The flag is set before the device query, so the check is not
        # repeated even if the query below raises.
        _logged_compatibility_check = True
        details = [
            tf.config.experimental.get_device_details(device)
            for device in tf.config.list_physical_devices("GPU")
        ]
        _log_device_compatibility_check(policy_name, details)

167