Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/py_checkpoint_reader.py: 41%

37 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Extending CheckpointReader for TensorFlow.""" 

16from tensorflow.python.framework import dtypes 

17from tensorflow.python.framework import errors_impl 

18from tensorflow.python.util import compat 

19from tensorflow.python.util._pywrap_checkpoint_reader import CheckpointReader 

20from tensorflow.python.util.tf_export import tf_export 

21 

22 

def error_translator(e):
  """Translate the tensor_slice_reader.cc errors.

  Inspects the text of `e` and re-raises it as the matching
  `errors_impl.OpError` subclass. Rules are checked in order; the first rule
  whose substrings appear in the message decides the error type. This
  function never returns normally.

  Args:
    e: The error raised by the C++ checkpoint reader. Only `str(e)` is used.

  Raises:
    errors_impl.NotFoundError: missing tensor or no files matching a pattern.
    errors_impl.UnimplementedError: sliced checkpoints or unsupported dtype.
    errors_impl.InvalidArgumentError: failure while listing matching files.
    errors_impl.DataLossError: the table file could not be opened.
    errors_impl.InternalError: missing tensor slices or a dtype that cannot
      be converted to numpy.
    errors_impl.OpError: with code `errors_impl.UNKNOWN` when nothing matches.
  """
  # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
  # issue with throwing python exceptions from C++.
  message = str(e)
  # Ordered (substrings, error class) pairs; the order mirrors the priority in
  # which messages must be classified.
  translation_rules = (
      (('not found in checkpoint', 'Failed to find any matching files for'),
       errors_impl.NotFoundError),
      (('Sliced checkpoints are not supported', 'Data type not supported'),
       errors_impl.UnimplementedError),
      (('Failed to get matching files on',),
       errors_impl.InvalidArgumentError),
      (('Unable to open table file',),
       errors_impl.DataLossError),
      (('Failed to find the saved tensor slices',
        'not convertible to numpy dtype'),
       errors_impl.InternalError),
  )
  for needles, error_cls in translation_rules:
    if any(needle in message for needle in needles):
      raise error_cls(None, None, message)
  # No rule matched: surface the message as a generic UNKNOWN OpError.
  raise errors_impl.OpError(None, None, message, errors_impl.UNKNOWN)

46 

47 

def get_variable_to_dtype_map(self):
  """Return the reader's variable map with values converted to `DType`s.

  Calls the native `_GetVariableToDataTypeMap()` and wraps each type enum it
  returns in `dtypes.DType`, keeping the original keys.

  Returns:
    A new dict mapping each key of the native map to a `dtypes.DType`.
  """
  raw_map = self._GetVariableToDataTypeMap()  # pylint: disable=protected-access
  dtype_map = {}
  for var_name, enum_value in raw_map.items():
    dtype_map[var_name] = dtypes.DType(enum_value)
  return dtype_map


CheckpointReader.get_variable_to_dtype_map = get_variable_to_dtype_map

55 

56 

def has_tensor(self, tensor_str):
  """Ask the native reader whether it has a tensor named `tensor_str`.

  Args:
    tensor_str: Tensor name; encoded with `compat.as_bytes` before the call.

  Returns:
    The result of the native `_HasTensor` call.
  """
  encoded_name = compat.as_bytes(tensor_str)
  return self._HasTensor(encoded_name)  # pylint: disable=protected-access


CheckpointReader.has_tensor = has_tensor

61 

62 

def get_tensor(self, tensor_str):
  """Get the tensor from the Checkpoint object.

  Args:
    tensor_str: Tensor name; encoded with `compat.as_bytes` before the call.

  Returns:
    Whatever `CheckpointReader.CheckpointReader_GetTensor` returns for the
    encoded name.

  Raises:
    errors_impl.OpError (or a subclass): if the native call raises
      `RuntimeError`; the error is translated by `error_translator`.
  """
  try:
    key = compat.as_bytes(tensor_str)
    return CheckpointReader.CheckpointReader_GetTensor(self, key)
  except RuntimeError as err:
    # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve
    # the issue with throwing python exceptions from C++.
    error_translator(err)


CheckpointReader.get_tensor = get_tensor

75 

76 

# Disable invalid name to keep backwards compatibility with that function.
# It was previously exported from py_checkpoint_reader.i which did not conform
# to pylint checks.
# pylint: disable=invalid-name
@tf_export(v1=['train.NewCheckpointReader'])
def NewCheckpointReader(filepattern):
  """A function that returns a CheckpointReader.

  Args:
    filepattern: The filename. Encoded with `compat.as_bytes` before it is
      passed to the `CheckpointReader` constructor.

  Returns:
    A CheckpointReader object.

  Raises:
    errors_impl.OpError (or a subclass): if the constructor raises
      `RuntimeError`; the error is translated by `error_translator`.
  """
  try:
    encoded_pattern = compat.as_bytes(filepattern)
    return CheckpointReader(encoded_pattern)
  except RuntimeError as err:
    # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve
    # the issue with throwing python exceptions from C++.
    error_translator(err)