Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/mixed_precision/device_compatibility_check.py: 22%

55 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Contains function to log if devices are compatible with mixed precision.""" 

16 

17import itertools 

18 

19from tensorflow.python.framework import config 

20from tensorflow.python.platform import tf_logging 

21 

22 

23_COMPAT_CHECK_PREFIX = 'Mixed precision compatibility check (mixed_float16): ' 

24_COMPAT_CHECK_OK_PREFIX = _COMPAT_CHECK_PREFIX + 'OK' 

25_COMPAT_CHECK_WARNING_PREFIX = _COMPAT_CHECK_PREFIX + 'WARNING' 

26_COMPAT_CHECK_WARNING_SUFFIX = ( 

27 'If you will use compatible GPU(s) not attached to this host, e.g. by ' 

28 'running a multi-worker model, you can ignore this warning. This message ' 

29 'will only be logged once') 

30 

31 

32def _dedup_strings(device_strs): 

33 """Groups together consecutive identical strings. 

34 

35 For example, given: 

36 ['GPU 1', 'GPU 2', 'GPU 2', 'GPU 3', 'GPU 3', 'GPU 3'] 

37 This function returns: 

38 ['GPU 1', 'GPU 2 (x2)', 'GPU 3 (x3)'] 

39 

40 Args: 

41 device_strs: A list of strings, each representing a device. 

42 

43 Returns: 

44 A copy of the input, but identical consecutive strings are merged into a 

45 single string. 

46 """ 

47 new_device_strs = [] 

48 for device_str, vals in itertools.groupby(device_strs): 

49 num = len(list(vals)) 

50 if num == 1: 

51 new_device_strs.append(device_str) 

52 else: 

53 new_device_strs.append('%s (x%d)' % (device_str, num)) 

54 return new_device_strs 

55 

56 

57def _log_device_compatibility_check(policy_name, gpu_details_list): 

58 """Logs a compatibility check if the devices support the policy. 

59 

60 Currently only logs for the policy mixed_float16. 

61 

62 Args: 

63 policy_name: The name of the dtype policy. 

64 gpu_details_list: A list of dicts, one dict per GPU. Each dict 

65 is the device details for a GPU, as returned by 

66 `tf.config.experimental.get_device_details()`. 

67 """ 

68 if policy_name != 'mixed_float16': 

69 # TODO(b/145686977): Log if the policy is 'mixed_bfloat16'. This requires 

70 # checking if a TPU is available. 

71 return 

72 supported_device_strs = [] 

73 unsupported_device_strs = [] 

74 for details in gpu_details_list: 

75 name = details.get('device_name', 'Unknown GPU') 

76 cc = details.get('compute_capability') 

77 if cc: 

78 device_str = '%s, compute capability %s.%s' % (name, cc[0], cc[1]) 

79 if cc >= (7, 0): 

80 supported_device_strs.append(device_str) 

81 else: 

82 unsupported_device_strs.append(device_str) 

83 else: 

84 unsupported_device_strs.append( 

85 name + ', no compute capability (probably not an Nvidia GPU)') 

86 

87 if unsupported_device_strs: 

88 warning_str = _COMPAT_CHECK_WARNING_PREFIX + '\n' 

89 if supported_device_strs: 

90 warning_str += ('Some of your GPUs may run slowly with dtype policy ' 

91 'mixed_float16 because they do not all have compute ' 

92 'capability of at least 7.0. Your GPUs:\n') 

93 elif len(unsupported_device_strs) == 1: 

94 warning_str += ('Your GPU may run slowly with dtype policy mixed_float16 ' 

95 'because it does not have compute capability of at least ' 

96 '7.0. Your GPU:\n') 

97 else: 

98 warning_str += ('Your GPUs may run slowly with dtype policy ' 

99 'mixed_float16 because they do not have compute ' 

100 'capability of at least 7.0. Your GPUs:\n') 

101 for device_str in _dedup_strings(supported_device_strs + 

102 unsupported_device_strs): 

103 warning_str += ' ' + device_str + '\n' 

104 warning_str += ('See https://developer.nvidia.com/cuda-gpus for a list of ' 

105 'GPUs and their compute capabilities.\n') 

106 warning_str += _COMPAT_CHECK_WARNING_SUFFIX 

107 tf_logging.warning(warning_str) 

108 elif not supported_device_strs: 

109 tf_logging.warning( 

110 '%s\n' 

111 'The dtype policy mixed_float16 may run slowly because ' 

112 'this machine does not have a GPU. Only Nvidia GPUs with ' 

113 'compute capability of at least 7.0 run quickly with ' 

114 'mixed_float16.\n%s' % (_COMPAT_CHECK_WARNING_PREFIX, 

115 _COMPAT_CHECK_WARNING_SUFFIX)) 

116 elif len(supported_device_strs) == 1: 

117 tf_logging.info('%s\n' 

118 'Your GPU will likely run quickly with dtype policy ' 

119 'mixed_float16 as it has compute capability of at least ' 

120 '7.0. Your GPU: %s' % (_COMPAT_CHECK_OK_PREFIX, 

121 supported_device_strs[0])) 

122 else: 

123 tf_logging.info('%s\n' 

124 'Your GPUs will likely run quickly with dtype policy ' 

125 'mixed_float16 as they all have compute capability of at ' 

126 'least 7.0' % _COMPAT_CHECK_OK_PREFIX) 

127 

128 

# One-shot module-level flag: once set to True, log_device_compatibility_check
# does nothing on later calls.
_logged_compatibility_check = False


def log_device_compatibility_check(policy_name):
  """Logs, at most once, whether the local GPUs suit a dtype policy.

  Currently only logs for the policy mixed_float16. On the first call, the
  physical GPUs visible to TensorFlow are enumerated, their device details are
  fetched, and the result is passed to `_log_device_compatibility_check`.
  Every later call returns immediately, whatever `policy_name` is.

  Args:
    policy_name: The name of the dtype policy.
  """
  global _logged_compatibility_check
  if not _logged_compatibility_check:
    # The flag is set before any device query. As a result, a first call
    # whose query raises is not retried, and a first call with a policy other
    # than mixed_float16 still uses up the one-shot log.
    _logged_compatibility_check = True
    gpu_details_list = [
        config.get_device_details(gpu)
        for gpu in config.list_physical_devices('GPU')
    ]
    _log_device_compatibility_check(policy_name, gpu_details_list)