Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/distribute/distributed_file_utils.py: 29%

51 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

15"""Utilities that help manage directory path in distributed settings. 

16 

17In multi-worker training, the need to write a file to distributed file 

18location often requires only one copy done by one worker despite many workers 

19that are involved in training. The option to only perform saving by chief is 

20not feasible for a couple of reasons: 1) Chief and workers may each contain 

21a client that runs the same piece of code and it's preferred not to make 

22any distinction between the code run by chief and other workers, and 2) 

23saving of model or model's related information may require SyncOnRead 

24variables to be read, which needs the cooperation of all workers to perform 

25all-reduce. 

26 

27This set of utility is used so that only one copy is written to the needed 

28directory, by supplying a temporary write directory path for workers that don't 

29need to save, and removing the temporary directory once file writing is done. 

30 

31Example usage: 

32``` 

33# Before using a directory to write file to. 

34self.log_write_dir = write_dirpath(self.log_dir, get_distribution_strategy()) 

35# Now `self.log_write_dir` can be safely used to write file to. 

36 

37... 

38 

39# After the file is written to the directory. 

40remove_temp_dirpath(self.log_dir, get_distribution_strategy()) 

41 

42``` 

43 

44Experimental. API is subject to change. 

45""" 

import os

import requests
import tensorflow.compat.v2 as tf

GCP_METADATA_HEADER = {"Metadata-Flavor": "Google"}
_GCE_METADATA_URL_ENV_VARIABLE = "GCE_METADATA_IP"


def _get_base_dirpath(strategy):
    task_id = strategy.extended._task_id
    return "workertemp_" + str(task_id)


def _is_temp_dir(dirpath, strategy):
    return dirpath.endswith(_get_base_dirpath(strategy))


def _get_temp_dir(dirpath, strategy):
    if _is_temp_dir(dirpath, strategy):
        temp_dir = dirpath
    else:
        temp_dir = os.path.join(dirpath, _get_base_dirpath(strategy))
    tf.io.gfile.makedirs(temp_dir)
    return temp_dir
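
# Illustrative note (not part of the original source): for a non-chief worker
# whose `strategy.extended._task_id` is 3, `_get_temp_dir("/tmp/logs", strategy)`
# creates and returns "/tmp/logs/workertemp_3". The chief never reaches this
# helper from `write_dirpath`, because `write_dirpath` returns `dirpath`
# directly when `strategy.extended.should_checkpoint` is True.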


def write_dirpath(dirpath, strategy):
    """Returns the directory to be used to save files with distribution.

    `dirpath` would be created if it doesn't exist.

    Args:
      dirpath: Original dirpath that would be used without distribution.
      strategy: The tf.distribute strategy object currently used.

    Returns:
      The directory path that should be used to save with distribution.
    """
    if strategy is None:
        # Infer strategy from `tf.distribute` if not given.
        strategy = tf.distribute.get_strategy()
    if strategy is None:
        # If strategy is still not available, this is not in distributed
        # training. Fall back to the original dirpath.
        return dirpath
    if not strategy.extended._in_multi_worker_mode():
        return dirpath
    if strategy.extended.should_checkpoint:
        return dirpath
    # If this worker is not the chief and hence should not save files, save
    # them to a temporary directory to be removed later.
    return _get_temp_dir(dirpath, strategy)


def remove_temp_dirpath(dirpath, strategy):
    """Removes the temp path after writing is finished.

    Args:
      dirpath: Original dirpath that would be used without distribution.
      strategy: The tf.distribute strategy object currently used.
    """
    if strategy is None:
        # Infer strategy from `tf.distribute` if not given.
        strategy = tf.distribute.get_strategy()
    if strategy is None:
        # If strategy is still not available, this is not in distributed
        # training. Fall back to a no-op.
        return
    # TODO(anjalisridhar): Consider removing the check for multi worker mode
    # since it is redundant when used with the should_checkpoint property.
    if (
        strategy.extended._in_multi_worker_mode()
        and not strategy.extended.should_checkpoint
    ):
        # If this worker is not the chief and hence should not save files,
        # remove the temporary directory.
        tf.compat.v1.gfile.DeleteRecursively(_get_temp_dir(dirpath, strategy))
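
# Hedged usage sketch (an illustration, not part of the upstream module): the
# two helpers above are meant to be used as a pair around a directory write in
# multi-worker training. The strategy object and paths below are placeholders.
#
#   strategy = tf.distribute.MultiWorkerMirroredStrategy()
#   log_dir = "/tmp/logs"  # hypothetical log directory
#   write_dir = write_dirpath(log_dir, strategy)
#   # Chief: write_dir == log_dir; other workers: log_dir/workertemp_<task_id>.
#   tf.io.gfile.makedirs(write_dir)
#   with tf.io.gfile.GFile(os.path.join(write_dir, "events.txt"), "w") as f:
#       f.write("summary data")
#   # Non-chief workers remove their temporary directory afterwards.
#   remove_temp_dirpath(log_dir, strategy)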


def write_filepath(filepath, strategy):
    """Returns the file path to be used to save files with distribution.

    The directory containing `filepath` would be created if it doesn't exist.

    Args:
      filepath: Original filepath that would be used without distribution.
      strategy: The tf.distribute strategy object currently used.

    Returns:
      The file path that should be used to save files with distribution.
    """
    dirpath = os.path.dirname(filepath)
    base = os.path.basename(filepath)
    return os.path.join(write_dirpath(dirpath, strategy), base)


def remove_temp_dir_with_filepath(filepath, strategy):
    """Removes the temp path for `filepath` after writing is finished.

    Args:
      filepath: Original filepath that would be used without distribution.
      strategy: The tf.distribute strategy object currently used.
    """
    remove_temp_dirpath(os.path.dirname(filepath), strategy)
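
# Hedged sketch (an illustration, not part of the upstream module): the
# file-level helpers wrap the same redirect-then-clean-up pattern for a single
# file, e.g. a model saved by every worker. `model`, `strategy`, and the path
# below are placeholders.
#
#   checkpoint_path = "/tmp/ckpt/model.h5"  # hypothetical path
#   save_path = write_filepath(checkpoint_path, strategy)
#   model.save(save_path)
#   remove_temp_dir_with_filepath(checkpoint_path, strategy)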


def _on_gcp():
    """Detects whether the current environment is running on GCP."""
    gce_metadata_endpoint = "http://" + os.environ.get(
        _GCE_METADATA_URL_ENV_VARIABLE, "metadata.google.internal"
    )

    try:
        # Time out after 5 seconds, in case the test environment has
        # connectivity issues. Without an explicit timeout the request could
        # block forever.
        response = requests.get(
            f"{gce_metadata_endpoint}/computeMetadata/v1/{'instance/hostname'}",
            headers=GCP_METADATA_HEADER,
            timeout=5,
        )
        return response.status_code == 200
    except requests.exceptions.RequestException:
        return False


def support_on_demand_checkpoint_callback(strategy):
    """Returns whether `strategy` supports the on-demand checkpoint callback.

    This is only the case when running on GCP with
    `tf.distribute.MultiWorkerMirroredStrategy`.
    """
    if _on_gcp() and isinstance(
        strategy, tf.distribute.MultiWorkerMirroredStrategy
    ):
        return True

    return False

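# Hedged usage sketch (an illustration, not part of the upstream module): a
# caller could gate an on-demand checkpointing feature on this helper, which
# only returns True on GCP under MultiWorkerMirroredStrategy.
#
#   strategy = tf.distribute.MultiWorkerMirroredStrategy()
#   if support_on_demand_checkpoint_callback(strategy):
#       ...  # enable the on-demand checkpoint path (hypothetical caller logic)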