Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/py_checkpoint_reader.py: 41%
37 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Extending CheckpointReader for TensorFlow."""
16from tensorflow.python.framework import dtypes
17from tensorflow.python.framework import errors_impl
18from tensorflow.python.util import compat
19from tensorflow.python.util._pywrap_checkpoint_reader import CheckpointReader
20from tensorflow.python.util.tf_export import tf_export
def error_translator(e):
  """Re-raise a tensor_slice_reader.cc error as a TensorFlow `OpError`.

  The text of `str(e)` is searched for known substrings; the first group that
  matches decides which `errors_impl` exception class is raised.

  Args:
    e: The caught exception; only its string form is inspected.

  Raises:
    errors_impl.NotFoundError: Tensor or checkpoint files were not found.
    errors_impl.UnimplementedError: Sliced checkpoint or unsupported data type.
    errors_impl.InvalidArgumentError: Matching files could not be listed.
    errors_impl.DataLossError: The table file could not be opened.
    errors_impl.InternalError: Saved slices missing or dtype not convertible.
    errors_impl.OpError: Any other message (error code `UNKNOWN`).
  """
  # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
  # issue with throwing python exceptions from C++.
  message = str(e)

  # Ordered (exception class, substrings) pairs; the first match wins, so the
  # order here mirrors the priority of the checks.
  translations = (
      (errors_impl.NotFoundError,
       ('not found in checkpoint',
        'Failed to find any matching files for')),
      (errors_impl.UnimplementedError,
       ('Sliced checkpoints are not supported',
        'Data type not supported')),
      (errors_impl.InvalidArgumentError,
       ('Failed to get matching files on',)),
      (errors_impl.DataLossError,
       ('Unable to open table file',)),
      (errors_impl.InternalError,
       ('Failed to find the saved tensor slices',
        'not convertible to numpy dtype')),
  )
  for error_cls, markers in translations:
    if any(marker in message for marker in markers):
      raise error_cls(None, None, message)

  # Nothing recognised: surface the message with the generic UNKNOWN code.
  raise errors_impl.OpError(None, None, message, errors_impl.UNKNOWN)
def get_variable_to_dtype_map(self):
  """Return a dict mapping each name to a `dtypes.DType`.

  The names and type enums come from `self._GetVariableToDataTypeMap()`; each
  enum value is wrapped in `dtypes.DType`.
  """
  enum_by_name = self._GetVariableToDataTypeMap()  # pylint: disable=protected-access
  dtype_by_name = {}
  for var_name, enum_value in enum_by_name.items():
    dtype_by_name[var_name] = dtypes.DType(enum_value)
  return dtype_by_name


# Attach the helper to the reader class as a method.
CheckpointReader.get_variable_to_dtype_map = get_variable_to_dtype_map
def has_tensor(self, tensor_str):
  """Return the result of `self._HasTensor` for `tensor_str`.

  Args:
    tensor_str: Tensor name; converted with `compat.as_bytes` before the call.
  """
  tensor_key = compat.as_bytes(tensor_str)
  return self._HasTensor(tensor_key)  # pylint: disable=protected-access


# Attach the helper to the reader class as a method.
CheckpointReader.has_tensor = has_tensor
def get_tensor(self, tensor_str):
  """Fetch the tensor named `tensor_str` from this checkpoint reader.

  Args:
    tensor_str: Tensor name; converted with `compat.as_bytes` before the call.

  Returns:
    Whatever `CheckpointReader.CheckpointReader_GetTensor` returns.

  Raises:
    errors_impl.OpError: (or a subclass) when the underlying call raises a
      `RuntimeError`; the mapping is done by `error_translator`.
  """
  try:
    tensor_key = compat.as_bytes(tensor_str)
    return CheckpointReader.CheckpointReader_GetTensor(self, tensor_key)
  # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
  # issue with throwing python exceptions from C++.
  except RuntimeError as err:
    error_translator(err)


# Attach the helper to the reader class as a method.
CheckpointReader.get_tensor = get_tensor
77# Disable invalid name to keep backwards compatibility with that function.
78# It was previously exported from py_checkpoint_reader.i which did not conform
79# to pylint checks.
80# pylint: disable=invalid-name
@tf_export(v1=['train.NewCheckpointReader'])
def NewCheckpointReader(filepattern):
  """Create a `CheckpointReader` for `filepattern`.

  Args:
    filepattern: The filename; converted with `compat.as_bytes` before being
      passed to the `CheckpointReader` constructor.

  Returns:
    A CheckpointReader object.

  Raises:
    errors_impl.OpError: (or a subclass) when construction raises a
      `RuntimeError`; the mapping is done by `error_translator`.
  """
  try:
    pattern_bytes = compat.as_bytes(filepattern)
    return CheckpointReader(pattern_bytes)
  # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
  # issue with throwing python exceptions from C++.
  except RuntimeError as err:
    error_translator(err)