Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/version_utils.py: 35%

55 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16"""Utilities for Keras classes with v1 and v2 versions.""" 

17 

18import tensorflow.compat.v2 as tf 

19 

20from keras.src.utils.generic_utils import LazyLoader 

21 

# TODO(b/134426265): Switch back to single-quotes once the issue
# with copybara is fixed.

# Lazily-loaded Keras modules that provide the v1/v2 class pairs used by the
# version selectors and helpers below. Each loader receives its name, this
# module's `globals()` and the dotted path of the module to load.
# NOTE(review): presumably loaded lazily to avoid circular imports with the
# engine/callback modules -- confirm.

# `Model`: v2 (`training`) and v1 (`training_v1`).
training = LazyLoader(
    "training", globals(), "keras.src.engine.training"
)
training_v1 = LazyLoader(
    "training_v1", globals(), "keras.src.engine.training_v1"
)

# `Layer`: v2 (`base_layer`) and v1 (`base_layer_v1`).
base_layer = LazyLoader(
    "base_layer", globals(), "keras.src.engine.base_layer"
)
base_layer_v1 = LazyLoader(
    "base_layer_v1", globals(), "keras.src.engine.base_layer_v1"
)

# `TensorBoard` callback: v2 (`callbacks`) and v1 (`callbacks_v1`).
callbacks = LazyLoader(
    "callbacks", globals(), "keras.src.callbacks"
)
callbacks_v1 = LazyLoader(
    "callbacks_v1", globals(), "keras.src.callbacks_v1"
)

33 

34 

class ModelVersionSelector:
    """Chooses between Keras v1 and v2 Model class.

    On instantiation, `swap_class` resolves the class being constructed
    against the (`training.Model`, `training_v1.Model`) pair, rewiring its
    base classes in place, based on `should_use_v2()`.
    """

    def __new__(cls, *args, **kwargs):
        # Constructor arguments are not forwarded to the base `__new__`;
        # Python's normal construction passes them to `__init__` afterwards.
        v2_enabled = should_use_v2()
        resolved_cls = swap_class(
            cls, training.Model, training_v1.Model, v2_enabled
        )
        base_new = super(ModelVersionSelector, resolved_cls).__new__
        return base_new(resolved_cls)

42 

43 

class LayerVersionSelector:
    """Chooses between Keras v1 and v2 Layer class.

    On instantiation, `swap_class` resolves the class being constructed
    against the (`base_layer.Layer`, `base_layer_v1.Layer`) pair, rewiring
    its base classes in place, based on `should_use_v2()`.
    """

    def __new__(cls, *args, **kwargs):
        # Constructor arguments are not forwarded to the base `__new__`;
        # Python's normal construction passes them to `__init__` afterwards.
        v2_enabled = should_use_v2()
        resolved_cls = swap_class(
            cls, base_layer.Layer, base_layer_v1.Layer, v2_enabled
        )
        base_new = super(LayerVersionSelector, resolved_cls).__new__
        return base_new(resolved_cls)

51 

52 

class TensorBoardVersionSelector:
    """Chooses between Keras v1 and v2 TensorBoard callback class.

    The requested class is resolved with `swap_class` against the
    (`callbacks.TensorBoard`, `callbacks_v1.TensorBoard`) pair. If the caller
    asked for the v1 callback but v2 is selected, a fully constructed v2
    instance is returned instead.
    """

    def __new__(cls, *args, **kwargs):
        v2_enabled = should_use_v2()
        requested_cls = cls
        resolved_cls = swap_class(
            requested_cls,
            callbacks.TensorBoard,
            callbacks_v1.TensorBoard,
            v2_enabled,
        )
        replaced_v1_with_v2 = (
            requested_cls == callbacks_v1.TensorBoard
            and resolved_cls == callbacks.TensorBoard
        )
        if replaced_v1_with_v2:
            # The v2 class is not a subclass of the v1 class, so Python will
            # not call `__init__` on the returned object automatically; build
            # the instance completely here instead.
            return resolved_cls(*args, **kwargs)
        base_new = super(TensorBoardVersionSelector, resolved_cls).__new__
        return base_new(resolved_cls)

70 

71 

def should_use_v2():
    """Determine if v1 or v2 version should be used.

    Returns:
      True when TensorFlow is executing eagerly. Otherwise, when eager
      execution is enabled outside of functions, True unless the current
      default graph is a v1 `wrap_function` graph (its name starts with
      "wrapped_function"). False in every other case.
    """
    if tf.executing_eagerly():
        return True
    if not tf.compat.v1.executing_eagerly_outside_functions():
        return False

    # Code inside a v1 `wrap_function` FuncGraph is treated like v1 code.
    default_graph = tf.compat.v1.get_default_graph()
    inside_wrapped_function = getattr(
        default_graph, "name", False
    ) and default_graph.name.startswith("wrapped_function")
    return not inside_wrapped_function

87 

88 

def swap_class(cls, v2_cls, v1_cls, use_v2):
    """Swaps in v2_cls or v1_cls depending on graph mode.

    Args:
      cls: The class to resolve.
      v2_cls: The v2 member of the class pair.
      v1_cls: The v1 member of the class pair.
      use_v2: If truthy, select `v2_cls`; otherwise select `v1_cls`.

    Returns:
      `object` unchanged if `cls` is `object`. The selected member of the
      pair if `cls` is `v2_cls` or `v1_cls`. Otherwise `cls` itself, after
      any base class deriving from the non-selected member of the pair has
      been resolved recursively.

    Note: this mutates `__bases__` in place, both on `cls` and on the base
    classes it recurses into. The assignment happens only when a base
    actually changes.
    """
    if cls is object:
        return cls
    if cls in (v2_cls, v1_cls):
        return v2_cls if use_v2 else v1_cls

    # Recursively search superclasses to swap in the right Keras class.
    # `v1_cls` often extends `v2_cls`, so this may recurse even when no swap
    # is needed. The check is deliberately kept simple: it stays correct when
    # the v1 and v2 classes don't extend each other, or when the inheritance
    # order differs.
    replaced_cls = v1_cls if use_v2 else v2_cls
    old_bases = cls.__bases__
    new_bases = tuple(
        swap_class(base, v2_cls, v1_cls, use_v2)
        if issubclass(base, replaced_cls)
        else base
        for base in old_bases
    )

    # Assigning `__bases__` makes Python recompute the MRO of `cls` (and its
    # subclasses, running any metaclass `mro()` override). This function runs
    # on every model/layer instantiation, so skip the assignment when nothing
    # changed.
    if any(new is not old for new, old in zip(new_bases, old_bases)):
        cls.__bases__ = new_bases
    return cls

115 

116 

def disallow_legacy_graph(cls_name, method_name):
    """Raises if eager execution is disabled outside of functions.

    Args:
      cls_name: Class name used in the error message.
      method_name: Method name used in the error message.

    Raises:
      ValueError: if `tf.compat.v1.executing_eagerly_outside_functions()`
        returns False. The message assumes the instance was constructed with
        eager mode enabled.
    """
    if tf.compat.v1.executing_eagerly_outside_functions():
        return
    raise ValueError(
        f"Calling `{cls_name}.{method_name}` in graph mode is not "
        f"supported when the `{cls_name}` instance was constructed with "
        f"eager mode enabled. Please construct your `{cls_name}` instance "
        f"in graph mode or call `{cls_name}.{method_name}` with "
        "eager mode enabled."
    )

127 

128 

def is_v1_layer_or_model(obj):
    """Returns True if `obj` is a v1 Keras `Layer` or v1 Keras `Model`."""
    v1_types = (base_layer_v1.Layer, training_v1.Model)
    return isinstance(obj, v1_types)

131