Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/utils/version_utils.py: 35%

57 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15# pylint: disable=protected-access 

16"""Utilities for Keras classes with v1 and v2 versions.""" 

17 

18from tensorflow.python.eager import context 

19from tensorflow.python.framework import ops 

20from tensorflow.python.keras.utils.generic_utils import LazyLoader 

21 

22# TODO(b/134426265): Switch back to single-quotes once the issue 

23# with copybara is fixed. 

24# pylint: disable=g-inconsistent-quotes 

25training = LazyLoader( 

26 "training", globals(), 

27 "tensorflow.python.keras.engine.training") 

28training_v1 = LazyLoader( 

29 "training_v1", globals(), 

30 "tensorflow.python.keras.engine.training_v1") 

31base_layer = LazyLoader( 

32 "base_layer", globals(), 

33 "tensorflow.python.keras.engine.base_layer") 

34base_layer_v1 = LazyLoader( 

35 "base_layer_v1", globals(), 

36 "tensorflow.python.keras.engine.base_layer_v1") 

37callbacks = LazyLoader( 

38 "callbacks", globals(), 

39 "tensorflow.python.keras.callbacks") 

40callbacks_v1 = LazyLoader( 

41 "callbacks_v1", globals(), 

42 "tensorflow.python.keras.callbacks_v1") 

43 

44 

45# pylint: enable=g-inconsistent-quotes 

46 

47 

48class ModelVersionSelector(object): 

49 """Chooses between Keras v1 and v2 Model class.""" 

50 

51 def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument 

52 use_v2 = should_use_v2() 

53 cls = swap_class(cls, training.Model, training_v1.Model, use_v2) # pylint: disable=self-cls-assignment 

54 return super(ModelVersionSelector, cls).__new__(cls) 

55 

56 

57class LayerVersionSelector(object): 

58 """Chooses between Keras v1 and v2 Layer class.""" 

59 

60 def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument 

61 use_v2 = should_use_v2() 

62 cls = swap_class(cls, base_layer.Layer, base_layer_v1.Layer, use_v2) # pylint: disable=self-cls-assignment 

63 return super(LayerVersionSelector, cls).__new__(cls) 

64 

65 

66class TensorBoardVersionSelector(object): 

67 """Chooses between Keras v1 and v2 TensorBoard callback class.""" 

68 

69 def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument 

70 use_v2 = should_use_v2() 

71 start_cls = cls 

72 cls = swap_class(start_cls, callbacks.TensorBoard, callbacks_v1.TensorBoard, 

73 use_v2) 

74 if start_cls == callbacks_v1.TensorBoard and cls == callbacks.TensorBoard: 

75 # Since the v2 class is not a subclass of the v1 class, __init__ has to 

76 # be called manually. 

77 return cls(*args, **kwargs) 

78 return super(TensorBoardVersionSelector, cls).__new__(cls) 

79 

80 

81def should_use_v2(): 

82 """Determine if v1 or v2 version should be used.""" 

83 if context.executing_eagerly(): 

84 return True 

85 elif ops.executing_eagerly_outside_functions(): 

86 # Check for a v1 `wrap_function` FuncGraph. 

87 # Code inside a `wrap_function` is treated like v1 code. 

88 graph = ops.get_default_graph() 

89 if (getattr(graph, "name", False) and 

90 graph.name.startswith("wrapped_function")): 

91 return False 

92 return True 

93 else: 

94 return False 

95 

96 

97def swap_class(cls, v2_cls, v1_cls, use_v2): 

98 """Swaps in v2_cls or v1_cls depending on graph mode.""" 

99 if cls == object: 

100 return cls 

101 if cls in (v2_cls, v1_cls): 

102 return v2_cls if use_v2 else v1_cls 

103 

104 # Recursively search superclasses to swap in the right Keras class. 

105 new_bases = [] 

106 for base in cls.__bases__: 

107 if ((use_v2 and issubclass(base, v1_cls) 

108 # `v1_cls` often extends `v2_cls`, so it may still call `swap_class` 

109 # even if it doesn't need to. That being said, it may be the safest 

110 # not to over optimize this logic for the sake of correctness, 

111 # especially if we swap v1 & v2 classes that don't extend each other, 

112 # or when the inheritance order is different. 

113 or (not use_v2 and issubclass(base, v2_cls)))): 

114 new_base = swap_class(base, v2_cls, v1_cls, use_v2) 

115 else: 

116 new_base = base 

117 new_bases.append(new_base) 

118 cls.__bases__ = tuple(new_bases) 

119 return cls 

120 

121 

122def disallow_legacy_graph(cls_name, method_name): 

123 if not ops.executing_eagerly_outside_functions(): 

124 error_msg = ( 

125 "Calling `{cls_name}.{method_name}` in graph mode is not supported " 

126 "when the `{cls_name}` instance was constructed with eager mode " 

127 "enabled. Please construct your `{cls_name}` instance in graph mode or" 

128 " call `{cls_name}.{method_name}` with eager mode enabled.") 

129 error_msg = error_msg.format(cls_name=cls_name, method_name=method_name) 

130 raise ValueError(error_msg) 

131 

132 

133def is_v1_layer_or_model(obj): 

134 return isinstance(obj, (base_layer_v1.Layer, training_v1.Model))