Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/coordinator/utils.py: 33%

36 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""TF2 parameter server training utilities. 

16 

17Parameter server training in TF2 is currently under development. 

18""" 

19import threading 

20import time 

21 

22from tensorflow.python.platform import tf_logging as logging 

23from tensorflow.python.training import server_lib 

24 

25 

def start_server(cluster_resolver, protocol):
  """Starts a `server_lib.Server` for the current task and blocks forever.

  Meant for multi-process tests, or for users who prefer to run the same
  binary for every job in the cluster.

  Args:
    cluster_resolver: Object exposing `task_type`, `task_id` and
      `cluster_spec()` (presumably a `ClusterResolver` — confirm at call
      sites). Its `task_type` must be either 'worker' or 'ps'.
    protocol: Communication protocol forwarded unchanged to
      `server_lib.Server`.

  Raises:
    ValueError: If `cluster_resolver.task_type` is neither 'worker' nor 'ps'.
  """
  task_type = cluster_resolver.task_type
  if task_type not in ('worker', 'ps'):
    raise ValueError(
        'Unexpected task_type to start a server: {}'.format(task_type))

  task_id = cluster_resolver.task_id
  cluster_def = cluster_resolver.cluster_spec().as_cluster_def()
  server = server_lib.Server(
      cluster_def, job_name=task_type, task_index=task_id, protocol=protocol)

  logging.info('TensorFlow server started for job %s, task %d.', task_type,
               task_id)

  # Wait on the server so the process that started it does not exit.
  server.join()

46 

47 

class RepeatedTimer(object):
  """Repeatedly calls `function(*args)` every `interval` seconds.

  Each tick runs on a fresh `threading.Timer` thread, and the first timer is
  armed as soon as the object is constructed. On every tick the next timer is
  armed *before* `function` is called. The schedule is therefore not delayed
  by the callback's run time, but a callback that takes longer than
  `interval` can overlap with the next one.

  Call `stop()` to cancel further ticks; it returns the number of whole
  seconds elapsed since construction. `start()` re-arms a stopped timer.
  """

  def __init__(self, interval, function, *args):
    """Creates the timer and immediately schedules the first tick.

    Args:
      interval: Delay in seconds between ticks, passed to `threading.Timer`.
      function: Callable invoked on each tick.
      *args: Positional arguments forwarded to `function` on every call.
    """
    self._timer = None  # The currently armed `threading.Timer`, if any.
    # Guards `_timer`, `is_running` and `_stopped` against the timer thread.
    self._lock = threading.Lock()
    # Set by `stop()` so that a tick already in flight does not re-arm.
    self._stopped = False
    self.interval = interval
    self.function = function
    self.args = args
    # Not reset by `start()`; `stop()` measures from construction.
    self.start_time = time.time()
    self.is_running = False
    self.start()

  def _get_duration_sec(self):
    """Returns whole seconds elapsed since construction (truncated)."""
    return int(time.time() - self.start_time)

  def _schedule_locked(self):
    """Arms the next timer if none is pending. Caller must hold `_lock`."""
    if not self.is_running:
      self._timer = threading.Timer(self.interval, self._run)
      self._timer.start()
      self.is_running = True

  def _run(self):
    """Timer-thread callback: re-arms the next tick, then calls `function`."""
    with self._lock:
      # The timer may have fired just before `stop()` ran. Without this check,
      # the re-arm below would restart the timer after `stop()` returned.
      if self._stopped:
        return
      self.is_running = False
      self._schedule_locked()
    # Invoke the callback outside the lock so it may itself call `stop()`.
    self.function(*self.args)

  def start(self):
    """Arms the timer if it is not already running (also undoes `stop()`)."""
    with self._lock:
      self._stopped = False
      self._schedule_locked()

  def stop(self):
    """Cancels further ticks.

    Returns:
      Whole seconds elapsed since this timer was constructed.
    """
    duration = self._get_duration_sec()
    with self._lock:
      self._stopped = True
      if self._timer is not None:
        self._timer.cancel()
      self.is_running = False
    return duration

79