Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/coordinator/utils.py: 33%
36 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TF2 parameter server training utilities.
17Parameter server training in TF2 is currently under development.
18"""
19import threading
20import time
22from tensorflow.python.platform import tf_logging as logging
23from tensorflow.python.training import server_lib
def start_server(cluster_resolver, protocol):
  """Start a server and block the process from exiting.

  This function is for multi-processing tests or users who would like to have
  every job run the same binary for simplicity.

  Args:
    cluster_resolver: Object exposing `task_type`, `task_id` and
      `cluster_spec()`; the server is created for that task.
    protocol: Forwarded unchanged as the `protocol` argument of
      `server_lib.Server`.

  Raises:
    ValueError: If `cluster_resolver.task_type` is neither 'worker' nor 'ps'.
  """
  task_type = cluster_resolver.task_type
  # Only worker and parameter-server tasks host a server; reject anything else
  # before a server is created.
  if task_type not in ('worker', 'ps'):
    raise ValueError(
        'Unexpected task_type to start a server: {}'.format(task_type))

  server = server_lib.Server(
      cluster_resolver.cluster_spec().as_cluster_def(),
      job_name=cluster_resolver.task_type,
      task_index=cluster_resolver.task_id,
      protocol=protocol)

  logging.info('TensorFlow server started for job %s, task %d.',
               cluster_resolver.task_type, cluster_resolver.task_id)

  # Blocking the process that starts a server from exiting.
  server.join()
class RepeatedTimer(object):
  """Threaded Repeated Timer from http://shortn/_3hMZTFr1Iv.

  Invokes `function(*args)` repeatedly, `interval` seconds apart, using a fresh
  `threading.Timer` for every tick. The timer is armed on construction. Each
  tick schedules the next one before invoking the callback.

  `start()`, `stop()` and the timer callback are serialized by a lock so that
  a `stop()` racing with a tick that has already fired cannot leave the timer
  re-armed.

  Attributes:
    interval: Delay passed to `threading.Timer` for each tick.
    function: Callable invoked on every tick.
    args: Positional arguments passed to `function`.
    start_time: `time.time()` recorded at construction.
    is_running: True while a tick is scheduled.
  """

  def __init__(self, interval, function, *args):
    # Guards `_timer`, `_stopped` and `is_running`, which are touched both by
    # callers and by the timer threads.
    self._lock = threading.Lock()
    self._timer = None
    # Set by stop(); tells a tick that already fired not to re-arm.
    self._stopped = False
    self.interval = interval
    self.function = function
    self.args = args
    self.start_time = time.time()
    self.is_running = False
    self.start()

  def _get_duration_sec(self):
    """Returns whole seconds elapsed since construction."""
    return int(time.time() - self.start_time)

  def _arm_locked(self):
    """Schedules the next tick unless one is pending. Caller holds the lock."""
    if not self.is_running:
      self._timer = threading.Timer(self.interval, self._run)
      self._timer.start()
      self.is_running = True

  def _run(self):
    """Timer callback: re-arms the timer, then invokes the user function."""
    with self._lock:
      if self._stopped:
        # stop() ran after this tick fired; cancel() could no longer prevent
        # this call, so bail out here instead of re-arming.
        return
      self.is_running = False
      self._arm_locked()
    # Called outside the lock so the callback may itself call stop()/start().
    self.function(*self.args)

  def start(self):
    """Arms the timer if it is not already running."""
    with self._lock:
      self._stopped = False
      self._arm_locked()

  def stop(self):
    """Cancels the pending tick.

    Returns:
      Whole seconds elapsed since this object was constructed.
    """
    duration = self._get_duration_sec()
    with self._lock:
      self._stopped = True
      if self._timer is not None:
        self._timer.cancel()
      self.is_running = False
    return duration