Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/model_utils/mode_keys.py: 88%

40 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15# LINT.IfChange 

16"""Utils for managing different mode strings used by Keras and Estimator models. 

17""" 

18 

19from tensorflow.python.util.compat import collections_abc 

20 

21 

class KerasModeKeys:
  """Canonical mode strings used by Keras models.

  Attributes:
    TRAIN: the string for training/fitting mode.
    TEST: the string for testing/evaluation mode.
    PREDICT: the string for prediction/inference mode.
  """

  TRAIN = 'train'
  TEST = 'test'
  PREDICT = 'predict'

35 

36 

# TODO(kathywu): Remove copy in Estimator after nightlies
class EstimatorModeKeys(object):
  """Standard names for Estimator model modes.

  The following standard keys are defined:

  * `TRAIN`: training/fitting mode.
  * `EVAL`: testing/evaluation mode.
  * `PREDICT`: prediction/inference mode.

  Note that Estimator spells the prediction mode 'infer', whereas
  `KerasModeKeys.PREDICT` spells the same mode 'predict'.
  """

  TRAIN = 'train'
  EVAL = 'eval'
  PREDICT = 'infer'

51 

52 

def is_predict(mode):
  """Return True if `mode` is the Keras or the Estimator prediction key.

  Args:
    mode: the mode value to classify (normally a string).

  Returns:
    A bool; True when `mode` equals either `KerasModeKeys.PREDICT` or
    `EstimatorModeKeys.PREDICT`.
  """
  predict_keys = (KerasModeKeys.PREDICT, EstimatorModeKeys.PREDICT)
  return mode in predict_keys

55 

56 

def is_eval(mode):
  """Return True if `mode` is the Keras or the Estimator evaluation key.

  Args:
    mode: the mode value to classify (normally a string).

  Returns:
    A bool; True when `mode` equals either `KerasModeKeys.TEST` or
    `EstimatorModeKeys.EVAL`.
  """
  evaluation_keys = (KerasModeKeys.TEST, EstimatorModeKeys.EVAL)
  return mode in evaluation_keys

59 

60 

def is_train(mode):
  """Return True if `mode` is the Keras or the Estimator training key.

  Args:
    mode: the mode value to classify (normally a string).

  Returns:
    A bool; True when `mode` equals either `KerasModeKeys.TRAIN` or
    `EstimatorModeKeys.TRAIN`.
  """
  training_keys = (KerasModeKeys.TRAIN, EstimatorModeKeys.TRAIN)
  return mode in training_keys

63 

64 

class ModeKeyMap(collections_abc.Mapping):
  """Map using ModeKeys as keys.

  This class creates an immutable mapping from modes to values. For example,
  SavedModel export of Keras and Estimator models uses this to map modes to
  their corresponding MetaGraph tags/SignatureDef keys.

  Since this class uses modes, rather than strings, as keys, both "predict"
  (Keras's PREDICT ModeKey) and "infer" (Estimator's PREDICT ModeKey) map to the
  same value.

  Indexing with a string that is not a recognized mode raises `ValueError`.
  Membership tests (`in`) and `get()` instead treat such keys as absent, as the
  `Mapping` protocol expects.
  """

  def __init__(self, **kwargs):
    """Builds the map from keyword arguments whose names are mode keys.

    Args:
      **kwargs: mode key (Keras or Estimator spelling) -> value.

    Raises:
      ValueError: if a keyword is not a recognized mode key, or if two
        keywords name the same mode (e.g. both `predict` and `infer`).
    """
    # Values keyed by the normalized (Keras) mode string.
    self._internal_dict = {}
    # The keyword names exactly as passed, in order. Iteration yields these,
    # so `keys()` reflects the caller's spelling (e.g. 'eval' rather than
    # 'test').
    self._keys = []
    for key, value in kwargs.items():
      self._keys.append(key)
      dict_key = self._get_internal_key(key)
      # Two aliases of one mode would otherwise silently overwrite each other.
      if dict_key in self._internal_dict:
        raise ValueError(
            'Error creating ModeKeyMap. Multiple keys/values found for {} mode.'
            .format(dict_key))
      self._internal_dict[dict_key] = value

  def _get_internal_key(self, key):
    """Return keys used for the internal dictionary.

    Normalizes Keras and Estimator spellings of a mode to the Keras constant.

    Raises:
      ValueError: if `key` is not a recognized mode key.
    """
    if is_train(key):
      return KerasModeKeys.TRAIN
    if is_eval(key):
      return KerasModeKeys.TEST
    if is_predict(key):
      return KerasModeKeys.PREDICT
    raise ValueError('Invalid mode key: {}.'.format(key))

  def __getitem__(self, key):
    """Return the value stored for the mode that `key` names.

    Raises:
      ValueError: if `key` is not a recognized mode key.
      KeyError: if `key` is a valid mode that was not given to the constructor.
    """
    return self._internal_dict[self._get_internal_key(key)]

  def __contains__(self, key):
    """Return True if `key` names a mode that is present in this map.

    The inherited `Mapping.__contains__` catches only `KeyError`. Because
    `__getitem__` raises `ValueError` for unrecognized strings, `'bogus' in m`
    would raise without this override instead of returning False.
    """
    try:
      internal_key = self._get_internal_key(key)
    except ValueError:
      return False
    return internal_key in self._internal_dict

  def get(self, key, default=None):
    """Return the value for the mode `key` names, or `default` if absent.

    An unrecognized mode key is treated as absent rather than raising, for the
    same reason as in `__contains__`.
    """
    try:
      internal_key = self._get_internal_key(key)
    except ValueError:
      return default
    return self._internal_dict.get(internal_key, default)

  def __iter__(self):
    # Yields the keyword names as originally passed to the constructor.
    return iter(self._keys)

  def __len__(self):
    # Duplicated modes are rejected in __init__, so this also equals the
    # number of distinct modes stored.
    return len(self._keys)

107# LINT.ThenChange(//tensorflow/python/keras/saving/utils_v1/mode_keys.py)