Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/coordinator/metric_utils.py: 24%
41 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Metrics collecting utilities for single client training."""
17import time
19from tensorflow.python.eager import monitoring
20from tensorflow.python.util import tf_contextlib
# Module-level switch read by `monitored_timer`: while it is falsy the timer
# simply yields and records nothing.
22enable_metrics = False
# Maps a short metric name (e.g. 'closure_execution') to its sampler object.
# Starts out empty and is replaced with the populated mapping by `_init()`.
23_METRICS_MAPPING = {}
def _init():
  """Create the coordinator samplers and publish them in `_METRICS_MAPPING`.

  Each sampler shares one set of exponential time buckets. The module-level
  mapping is rebound only after every sampler has been constructed.
  """
  global _METRICS_MAPPING

  # Buckets (in seconds) for the distribution of execution times. Range from
  # 0.001s (i.e., 1ms) to 1000s.
  duration_buckets = monitoring.ExponentialBuckets(0.001, 10, 6)

  # (mapping key, monitoring path, human-readable description), listed in the
  # order the samplers are created.
  sampler_specs = (
      ('function_tracing',
       '/tensorflow/api/ps_strategy/coordinator/function_tracing',
       'Sampler to track the time (in seconds) for tracing functions.'),
      ('closure_execution',
       '/tensorflow/api/ps_strategy/coordinator/closure_execution',
       'Sampler to track the time (in seconds) for executing closures.'),
      ('remote_value_fetch',
       '/tensorflow/api/ps_strategy/coordinator/remote_value_fetch',
       'Sampler to track the time (in seconds) for fetching remote_value.'),
      ('server_def_update',
       '/tensorflow/api/ps_strategy/coordinator/server_def_update',
       'Sample to track the time (in seconds) for updating the server def upon '
       'worker recovery.'),
  )

  _METRICS_MAPPING = {
      key: monitoring.Sampler(path, duration_buckets, description)
      for key, path, description in sampler_specs
  }
@tf_contextlib.contextmanager
def monitored_timer(metric_name, state_tracker=None):
  """Monitor the execution time and collect it into the specified metric.

  When the module-level `enable_metrics` flag is falsy this is a no-op context
  manager. Otherwise the time spent inside the `with` body is measured and
  added to the sampler registered under `metric_name`.

  Args:
    metric_name: Key into `_METRICS_MAPPING` selecting the sampler to update.
    state_tracker: Optional zero-argument callable. When given, it is called
      once before and once after the body, and the duration is recorded only
      if the two results differ.

  Yields:
    None.

  Raises:
    KeyError: If a duration is about to be recorded and `metric_name` is not a
      key of `_METRICS_MAPPING`. The lookup happens after the body has run.
  """
  if not enable_metrics:
    yield
    return

  # Samplers are created lazily on the first timed call.
  if not _METRICS_MAPPING:
    _init()

  # Use a monotonic, high-resolution clock: wall-clock time (`time.time()`) can
  # jump when the system clock is adjusted and would then yield wrong or even
  # negative durations.
  start_time = time.perf_counter()
  # Test for presence with `is not None` both here and below so the start and
  # end checks agree.
  start_state = state_tracker() if state_tracker is not None else None

  # There is no try/finally around the yield, so an exception propagating out
  # of the body skips the recording below.
  yield

  duration_sec = time.perf_counter() - start_time
  # If a state_tracker is provided, record the metric only if the end state is
  # different from the start state.
  if state_tracker is None or state_tracker() != start_state:
    metric = _METRICS_MAPPING[metric_name]
    metric.get_cell().add(duration_sec)
def get_metric_summary(metric_name):
  """Get summary for the specified metric.

  Args:
    metric_name: Key into `_METRICS_MAPPING` selecting the sampler to read.

  Returns:
    A dict with the keys 'min', 'max', 'num' and 'sum' copied from the
    sampler's histogram value, plus 'histogram', a dict mapping
    `(lower_limit, upper_limit)` tuples to the corresponding bucket value.

  Raises:
    KeyError: If `metric_name` is not a key of `_METRICS_MAPPING`.
  """
  histogram_proto = _METRICS_MAPPING[metric_name].get_cell().value()

  # Prepend 0 as the lowest limit, since all these metrics are durations. Each
  # pair of adjacent limits then bounds one bucket.
  limits = histogram_proto.bucket_limit
  limits.insert(0, 0)
  lower_limits = limits[:-1]
  upper_limits = limits[1:]

  histogram = {
      (lower, upper): count
      for lower, upper, count in zip(
          lower_limits, upper_limits, histogram_proto.bucket)
  }

  return {
      'min': histogram_proto.min,
      'max': histogram_proto.max,
      'num': histogram_proto.num,
      'sum': histogram_proto.sum,
      'histogram': histogram,
  }