Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/coordinator/metric_utils.py: 24%

41 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Metrics collecting utilities for single client training.""" 

16 

17import time 

18 

19from tensorflow.python.eager import monitoring 

20from tensorflow.python.util import tf_contextlib 

21 

# Module-level switch read by monitored_timer: while it is falsy the timer
# only yields and neither measures nor records anything.
22enable_metrics = False 

# Maps a short metric name to its sampler object. Starts out empty and is
# replaced with the populated mapping by _init().
23_METRICS_MAPPING = {} 

24 

25 

def _init():
  """Create the duration samplers and publish them in `_METRICS_MAPPING`.

  Rebinds the module-level `_METRICS_MAPPING` to a dict keyed by short metric
  name ('function_tracing', 'closure_execution', 'remote_value_fetch',
  'server_def_update'). All samplers share one exponential bucket
  configuration. The mapping is only rebound after every sampler has been
  constructed.
  """
  global _METRICS_MAPPING

  # Shared buckets for execution times in seconds: exponential, scale 0.001
  # (i.e., 1ms), growth factor 10, 6 buckets.
  duration_buckets = monitoring.ExponentialBuckets(0.001, 10, 6)

  # (short metric name, sampler description), in creation order. The short
  # name is also the last component of the sampler's full path.
  sampler_specs = (
      ('function_tracing',
       'Sampler to track the time (in seconds) for tracing functions.'),
      ('closure_execution',
       'Sampler to track the time (in seconds) for executing closures.'),
      ('remote_value_fetch',
       'Sampler to track the time (in seconds) for fetching remote_value.'),
      ('server_def_update',
       'Sample to track the time (in seconds) for updating the server def upon '
       'worker recovery.'),
  )

  _METRICS_MAPPING = {
      short_name: monitoring.Sampler(
          '/tensorflow/api/ps_strategy/coordinator/' + short_name,
          duration_buckets,
          description)
      for short_name, description in sampler_specs
  }

59 

60 

@tf_contextlib.contextmanager
def monitored_timer(metric_name, state_tracker=None):
  """Time the wrapped block and add its duration to the named metric.

  When the module-level `enable_metrics` flag is falsy this only yields.
  Otherwise the metrics mapping is initialized on first use, the elapsed time
  of the wrapped block is measured in seconds, and that duration is added to
  the sampler registered under `metric_name`.

  Args:
    metric_name: Key into `_METRICS_MAPPING` selecting the sampler to record
      into.
    state_tracker: Optional zero-argument callable. When provided it is called
      once before and once after the wrapped block, and the duration is
      recorded only if the two results differ.

  Yields:
    None.

  Raises:
    KeyError: If a duration is about to be recorded and `metric_name` is not a
      key of the metrics mapping. The lookup happens after the wrapped block
      has finished.
  """
  if not enable_metrics:
    yield
    return

  if not _METRICS_MAPPING:
    _init()

  # Use a monotonic clock: the wall clock (time.time) can be adjusted while the
  # block runs, which would skew the measured duration or make it negative.
  start_time = time.monotonic()
  start_state = state_tracker() if state_tracker is not None else None
  yield
  duration_sec = time.monotonic() - start_time

  # If a state_tracker is provided, record the metric only if the end state is
  # different from the start state.
  if state_tracker is None or state_tracker() != start_state:
    metric = _METRICS_MAPPING[metric_name]
    metric.get_cell().add(duration_sec)

78 

79 

def get_metric_summary(metric_name):
  """Summarize the histogram currently held by the named metric.

  Args:
    metric_name: Key into `_METRICS_MAPPING` selecting the metric to read.

  Returns:
    A dict with the keys 'min', 'max', 'num' and 'sum' copied from the
    metric's histogram value, plus 'histogram', a dict mapping each
    `(lower_bound, upper_bound)` pair to that bucket's value.

  Raises:
    KeyError: If `metric_name` is not a key of `_METRICS_MAPPING` (the mapping
      is empty until `_init()` has run).
  """
  histogram = _METRICS_MAPPING[metric_name].get_cell().value()

  # Every metric here is a duration, so 0 is prepended as the lower edge of
  # the first bucket; each bucket then spans two consecutive edges.
  edges = histogram.bucket_limit
  edges.insert(0, 0)

  return {
      'min': histogram.min,
      'max': histogram.max,
      'num': histogram.num,
      'sum': histogram.sum,
      'histogram': {
          (lower, upper): count
          for lower, upper, count in zip(
              edges[:-1], edges[1:], histogram.bucket)
      },
  }