Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/mixed_precision/device_compatibility_check.py: 22%
55 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains function to log if devices are compatible with mixed precision."""
17import itertools
19import tensorflow.compat.v2 as tf
21# isort: off
22from tensorflow.python.platform import tf_logging
# Common prefix for every message this module logs, so users can easily grep
# for the mixed-precision compatibility check output.
_COMPAT_CHECK_PREFIX = "Mixed precision compatibility check (mixed_float16): "
# Header used when all attached GPUs support mixed_float16 well.
_COMPAT_CHECK_OK_PREFIX = _COMPAT_CHECK_PREFIX + "OK"
# Header used when at least one device may run mixed_float16 slowly.
_COMPAT_CHECK_WARNING_PREFIX = _COMPAT_CHECK_PREFIX + "WARNING"
# Trailer appended to every warning, reminding the user the check only sees
# GPUs attached to this host and is emitted at most once.
_COMPAT_CHECK_WARNING_SUFFIX = (
    "If you will use compatible GPU(s) not attached to this host, e.g. by "
    "running a multi-worker model, you can ignore this warning. This message "
    "will only be logged once"
)
34def _dedup_strings(device_strs):
35 """Groups together consecutive identical strings.
37 For example, given:
38 ['GPU 1', 'GPU 2', 'GPU 2', 'GPU 3', 'GPU 3', 'GPU 3']
39 This function returns:
40 ['GPU 1', 'GPU 2 (x2)', 'GPU 3 (x3)']
42 Args:
43 device_strs: A list of strings, each representing a device.
45 Returns:
46 A copy of the input, but identical consecutive strings are merged into a
47 single string.
48 """
49 new_device_strs = []
50 for device_str, vals in itertools.groupby(device_strs):
51 num = len(list(vals))
52 if num == 1:
53 new_device_strs.append(device_str)
54 else:
55 new_device_strs.append("%s (x%d)" % (device_str, num))
56 return new_device_strs
def _log_device_compatibility_check(policy_name, gpu_details_list):
    """Logs a compatibility check if the devices support the policy.

    Currently only logs for the policy mixed_float16.

    Args:
        policy_name: The name of the dtype policy.
        gpu_details_list: A list of dicts, one dict per GPU. Each dict is the
            device details for a GPU, as returned by
            `tf.config.experimental.get_device_details()`.
    """
    if policy_name != "mixed_float16":
        # TODO(b/145686977): Log if the policy is 'mixed_bfloat16'. This
        # requires checking if a TPU is available.
        return

    # Partition the GPUs into those that run mixed_float16 quickly (Nvidia
    # compute capability >= 7.0) and those that do not.
    supported = []
    unsupported = []
    for details in gpu_details_list:
        name = details.get("device_name", "Unknown GPU")
        cc = details.get("compute_capability")
        if not cc:
            unsupported.append(
                name + ", no compute capability (probably not an Nvidia GPU)"
            )
            continue
        desc = f"{name}, compute capability {cc[0]}.{cc[1]}"
        if cc >= (7, 0):
            supported.append(desc)
        else:
            unsupported.append(desc)

    if unsupported:
        # At least one GPU may run slowly: warn, listing every GPU so the
        # user can see which devices are the problem.
        if supported:
            reason = (
                "Some of your GPUs may run slowly with dtype policy "
                "mixed_float16 because they do not all have compute "
                "capability of at least 7.0. Your GPUs:\n"
            )
        elif len(unsupported) == 1:
            reason = (
                "Your GPU may run slowly with dtype policy mixed_float16 "
                "because it does not have compute capability of at least "
                "7.0. Your GPU:\n"
            )
        else:
            reason = (
                "Your GPUs may run slowly with dtype policy "
                "mixed_float16 because they do not have compute "
                "capability of at least 7.0. Your GPUs:\n"
            )
        parts = [_COMPAT_CHECK_WARNING_PREFIX + "\n", reason]
        parts.extend(
            "  " + desc + "\n"
            for desc in _dedup_strings(supported + unsupported)
        )
        parts.append(
            "See https://developer.nvidia.com/cuda-gpus for a list of "
            "GPUs and their compute capabilities.\n"
        )
        parts.append(_COMPAT_CHECK_WARNING_SUFFIX)
        tf_logging.warning("".join(parts))
    elif not supported:
        # No GPUs attached at all.
        tf_logging.warning(
            "%s\n"
            "The dtype policy mixed_float16 may run slowly because "
            "this machine does not have a GPU. Only Nvidia GPUs with "
            "compute capability of at least 7.0 run quickly with "
            "mixed_float16.\n%s"
            % (_COMPAT_CHECK_WARNING_PREFIX, _COMPAT_CHECK_WARNING_SUFFIX)
        )
    elif len(supported) == 1:
        tf_logging.info(
            "%s\n"
            "Your GPU will likely run quickly with dtype policy "
            "mixed_float16 as it has compute capability of at least "
            "7.0. Your GPU: %s"
            % (_COMPAT_CHECK_OK_PREFIX, supported[0])
        )
    else:
        tf_logging.info(
            "%s\n"
            "Your GPUs will likely run quickly with dtype policy "
            "mixed_float16 as they all have compute capability of at "
            "least 7.0" % _COMPAT_CHECK_OK_PREFIX
        )
# Module-level flag so log_device_compatibility_check() logs at most once per
# process.
_logged_compatibility_check = False
def log_device_compatibility_check(policy_name):
    """Logs a compatibility check if the devices support the policy.

    Currently only logs for the policy mixed_float16. A log is shown only the
    first time this function is called.

    Args:
        policy_name: The name of the dtype policy.
    """
    global _logged_compatibility_check
    if _logged_compatibility_check:
        return
    # Set the flag up front so the check is attempted at most once.
    _logged_compatibility_check = True
    gpu_details_list = [
        tf.config.experimental.get_device_details(gpu)
        for gpu in tf.config.list_physical_devices("GPU")
    ]
    _log_device_compatibility_check(policy_name, gpu_details_list)