Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/core/function/polymorphism/function_cache.py: 47%

38 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Cache to manage concrete functions and their signatures.""" 

16 

17import collections 

18from typing import Any, NamedTuple, Optional 

19 

20from tensorflow.core.function.polymorphism import function_type as function_type_lib 

21from tensorflow.core.function.polymorphism import type_dispatch 

22 

23 

class FunctionContext(NamedTuple):
  """Immutable record describing a tf.function execution context.

  Being a NamedTuple, equality and hashing are tuple-based, so an instance is
  hashable (and usable as a dict key) whenever the wrapped `context` is.

  Attributes:
    context: The wrapped execution-context value; its type is unconstrained
      (annotated as `Any`).
  """

  context: Any

27 

28 

class FunctionCache:
  """A container for managing concrete functions.

  Concrete functions are stored under a `(FunctionContext, FunctionType)` key.
  For every context, a `type_dispatch.TypeDispatchTable` records the
  FunctionTypes added under that context. `lookup` asks that table which
  stored FunctionType to dispatch to, rather than requiring an exact key match.
  """

  # NOTE(review): "_garbage_collectors" is declared but never assigned by this
  # class. It is kept so that any external code that sets it keeps working.
  __slots__ = ["_primary", "_dispatch_dict", "_garbage_collectors"]

  def __init__(self):
    # Maps (FunctionContext, FunctionType) to a concrete function.
    self._primary = collections.OrderedDict()

    # Maps FunctionContext to a TypeDispatchTable containing the FunctionTypes
    # added under that particular context.
    self._dispatch_dict = {}

  def lookup(self, context: FunctionContext,
             function_type: function_type_lib.FunctionType) -> Optional[Any]:
    """Looks up a concrete function based on the context and type.

    Args:
      context: The FunctionContext to search within.
      function_type: The FunctionType to dispatch on.

    Returns:
      The concrete function stored for the FunctionType that the context's
      dispatch table selects for `function_type`. Returns None if nothing has
      been added under `context`, or if the table selects nothing.
    """
    # A single .get() suffices: stored values are always TypeDispatchTables,
    # never None.
    dispatch_table = self._dispatch_dict.get(context)
    if dispatch_table is None:
      return None

    dispatch_type = dispatch_table.dispatch(function_type)
    # Test for "no match" explicitly instead of relying on the truthiness of
    # the returned FunctionType.
    if dispatch_type is None:
      return None

    # add()/delete() update _primary and the dispatch table together, so a
    # dispatched type is expected to have a _primary entry.
    return self._primary[(context, dispatch_type)]

  def delete(self, context: FunctionContext,
             function_type: function_type_lib.FunctionType) -> bool:
    """Deletes a concrete function given the context and type.

    Args:
      context: The FunctionContext the function was added under.
      function_type: The exact FunctionType the function was added with.

    Returns:
      True if an entry was removed, False if no entry existed for the key.
    """
    key = (context, function_type)
    if key not in self._primary:
      return False

    del self._primary[key]
    # An entry in _primary implies add() created this context's table.
    self._dispatch_dict[context].delete(function_type)

    return True

  def add(self, context: FunctionContext,
          function_type: function_type_lib.FunctionType,
          concrete_fn: Any):
    """Adds a new concrete function alongside its key.

    Args:
      context: A FunctionContext representing the current context.
      function_type: A FunctionType representing concrete_fn signature.
      concrete_fn: The concrete function to be added to the cache.
    """
    self._primary[(context, function_type)] = concrete_fn

    dispatch_table = self._dispatch_dict.get(context)
    if dispatch_table is None:
      # First function for this context: create its dispatch table lazily.
      dispatch_table = type_dispatch.TypeDispatchTable()
      self._dispatch_dict[context] = dispatch_table

    dispatch_table.add_target(function_type)

  def generalize(
      self, context: FunctionContext,
      function_type: function_type_lib.FunctionType
  ) -> function_type_lib.FunctionType:
    """Try to generalize a FunctionType within a FunctionContext.

    Args:
      context: The FunctionContext whose dispatch table should be consulted.
      function_type: The FunctionType to generalize.

    Returns:
      The result of the context's `try_generalizing_function_type`, or
      `function_type` unchanged if nothing has been added under `context`.
    """
    dispatch_table = self._dispatch_dict.get(context)
    if dispatch_table is None:
      return function_type
    return dispatch_table.try_generalizing_function_type(function_type)

  # TODO(b/205971333): Remove this function.
  def clear(self):
    """Removes all concrete functions from the cache."""
    self._primary.clear()
    self._dispatch_dict.clear()

  def values(self):
    """Returns a list of all `ConcreteFunction` instances held by this cache.

    The list is a snapshot in insertion order of `_primary` (an OrderedDict).
    """
    return list(self._primary.values())