1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains function to log if devices are compatible with mixed precision."""
16
17import itertools
18
19from tensorflow.python.framework import config
20from tensorflow.python.platform import tf_logging
21
22
# Common prefix for every compatibility-check log message, whatever the
# outcome, so all such messages are easy to recognize in logs.
_COMPAT_CHECK_PREFIX = 'Mixed precision compatibility check (mixed_float16): '
# Prefix used when the detected GPUs are expected to run mixed_float16 well.
_COMPAT_CHECK_OK_PREFIX = _COMPAT_CHECK_PREFIX + 'OK'
# Prefix used when some or all devices may run mixed_float16 slowly.
_COMPAT_CHECK_WARNING_PREFIX = _COMPAT_CHECK_PREFIX + 'WARNING'
# Note appended to the end of every warning message.
_COMPAT_CHECK_WARNING_SUFFIX = (
    'If you will use compatible GPU(s) not attached to this host, e.g. by '
    'running a multi-worker model, you can ignore this warning. This message '
    'will only be logged once')
30
31
32def _dedup_strings(device_strs):
33 """Groups together consecutive identical strings.
34
35 For example, given:
36 ['GPU 1', 'GPU 2', 'GPU 2', 'GPU 3', 'GPU 3', 'GPU 3']
37 This function returns:
38 ['GPU 1', 'GPU 2 (x2)', 'GPU 3 (x3)']
39
40 Args:
41 device_strs: A list of strings, each representing a device.
42
43 Returns:
44 A copy of the input, but identical consecutive strings are merged into a
45 single string.
46 """
47 new_device_strs = []
48 for device_str, vals in itertools.groupby(device_strs):
49 num = len(list(vals))
50 if num == 1:
51 new_device_strs.append(device_str)
52 else:
53 new_device_strs.append('%s (x%d)' % (device_str, num))
54 return new_device_strs
55
56
def _log_device_compatibility_check(policy_name, gpu_details_list):
  """Logs a compatibility check if the devices support the policy.

  Currently only logs for the policy mixed_float16. GPUs with compute
  capability 7.0 or higher are treated as fast for mixed_float16; anything
  else (including devices with no reported compute capability) triggers a
  warning.

  Args:
    policy_name: The name of the dtype policy.
    gpu_details_list: A list of dicts, one dict per GPU. Each dict
      is the device details for a GPU, as returned by
      `tf.config.experimental.get_device_details()`.
  """
  if policy_name != 'mixed_float16':
    # TODO(b/145686977): Log if the policy is 'mixed_bfloat16'. This requires
    # checking if a TPU is available.
    return

  # Partition the GPUs into those fast with mixed_float16 and those not.
  supported = []
  unsupported = []
  for details in gpu_details_list:
    gpu_name = details.get('device_name', 'Unknown GPU')
    cc = details.get('compute_capability')
    if not cc:
      unsupported.append(
          gpu_name + ', no compute capability (probably not an Nvidia GPU)')
      continue
    desc = '%s, compute capability %s.%s' % (gpu_name, cc[0], cc[1])
    target = supported if cc >= (7, 0) else unsupported
    target.append(desc)

  if unsupported:
    # At least one GPU may be slow: pick the sentence matching whether some,
    # one, or all of the GPUs are affected, then list every GPU.
    if supported:
      reason = ('Some of your GPUs may run slowly with dtype policy '
                'mixed_float16 because they do not all have compute '
                'capability of at least 7.0. Your GPUs:\n')
    elif len(unsupported) == 1:
      reason = ('Your GPU may run slowly with dtype policy mixed_float16 '
                'because it does not have compute capability of at least '
                '7.0. Your GPU:\n')
    else:
      reason = ('Your GPUs may run slowly with dtype policy '
                'mixed_float16 because they do not have compute '
                'capability of at least 7.0. Your GPUs:\n')
    pieces = [_COMPAT_CHECK_WARNING_PREFIX, '\n', reason]
    for desc in _dedup_strings(supported + unsupported):
      pieces.append(' ' + desc + '\n')
    pieces.append('See https://developer.nvidia.com/cuda-gpus for a list of '
                  'GPUs and their compute capabilities.\n')
    pieces.append(_COMPAT_CHECK_WARNING_SUFFIX)
    tf_logging.warning(''.join(pieces))
  elif not supported:
    # No GPUs at all: mixed_float16 will run on CPU, which is slow.
    tf_logging.warning(
        '%s\n'
        'The dtype policy mixed_float16 may run slowly because '
        'this machine does not have a GPU. Only Nvidia GPUs with '
        'compute capability of at least 7.0 run quickly with '
        'mixed_float16.\n%s' % (_COMPAT_CHECK_WARNING_PREFIX,
                                _COMPAT_CHECK_WARNING_SUFFIX))
  elif len(supported) == 1:
    tf_logging.info('%s\n'
                    'Your GPU will likely run quickly with dtype policy '
                    'mixed_float16 as it has compute capability of at least '
                    '7.0. Your GPU: %s' % (_COMPAT_CHECK_OK_PREFIX,
                                           supported[0]))
  else:
    tf_logging.info('%s\n'
                    'Your GPUs will likely run quickly with dtype policy '
                    'mixed_float16 as they all have compute capability of at '
                    'least 7.0' % _COMPAT_CHECK_OK_PREFIX)
127
128
# Process-wide flag: set to True once log_device_compatibility_check has run,
# so the compatibility check is only logged a single time.
_logged_compatibility_check = False
130
131
def log_device_compatibility_check(policy_name):
  """Logs a compatibility check if the devices support the policy.

  Only the policy mixed_float16 currently produces a log, and only the first
  call in the process logs anything; later calls are no-ops.

  Args:
    policy_name: The name of the dtype policy.
  """
  global _logged_compatibility_check
  if _logged_compatibility_check:
    return
  # Mark as logged *before* doing the check so a failure cannot cause the
  # message to be attempted again on every call.
  _logged_compatibility_check = True
  details_per_gpu = [
      config.get_device_details(gpu)
      for gpu in config.list_physical_devices('GPU')
  ]
  _log_device_compatibility_check(policy_name, details_per_gpu)