Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/core/function/polymorphism/function_cache.py: 47%
38 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Cache to manage concrete functions and their signatures."""
17import collections
18from typing import Any, NamedTuple, Optional
20from tensorflow.core.function.polymorphism import function_type as function_type_lib
21from tensorflow.core.function.polymorphism import type_dispatch
class FunctionContext(NamedTuple):
  """Immutable record holding the execution context of a tf.function.

  Being a NamedTuple, an instance is hashable whenever its `context` value is
  hashable, which lets it be used as a dictionary key.
  """

  # The wrapped context value; this record does not interpret it.
  context: Any
class FunctionCache:
  """A container for managing concrete functions.

  Concrete functions are stored under a (FunctionContext, FunctionType) key.
  For every context that has had a function added, a
  `type_dispatch.TypeDispatchTable` records the FunctionTypes registered under
  that context. `lookup` uses it to map a requested type to a registered one.
  """

  # NOTE(review): `_garbage_collectors` is declared but never assigned by any
  # method of this class. It is kept so the set of settable attributes stays
  # unchanged; confirm nothing external relies on it before removing.
  __slots__ = ["_primary", "_dispatch_dict", "_garbage_collectors"]

  def __init__(self):
    # Maps (FunctionContext, FunctionType) to a concrete function. Insertion
    # order is preserved and is the order in which `values()` returns them.
    self._primary = collections.OrderedDict()

    # Maps FunctionContext to a TypeDispatchTable containing FunctionTypes of
    # that particular context.
    self._dispatch_dict = {}

  def lookup(self, context: FunctionContext,
             function_type: function_type_lib.FunctionType) -> Optional[Any]:
    """Looks up a concrete function based on the context and type.

    Args:
      context: The FunctionContext to search within.
      function_type: The FunctionType being requested.

    Returns:
      The concrete function registered under `context` for the type that the
      context's dispatch table selects for `function_type`, or None if the
      context has no registered functions or no registered type is selected.
    """
    dispatch_table = self._dispatch_dict.get(context)
    if dispatch_table is None:
      return None

    dispatch_type = dispatch_table.dispatch(function_type)
    # `dispatch` yields an optional type; test explicitly for "nothing
    # selected" rather than relying on the truthiness of a FunctionType.
    if dispatch_type is None:
      return None

    return self._primary[(context, dispatch_type)]

  def delete(self, context: FunctionContext,
             function_type: function_type_lib.FunctionType) -> bool:
    """Deletes a concrete function given the context and type.

    Args:
      context: The FunctionContext the function was added under.
      function_type: The exact FunctionType the function was added with.

    Returns:
      True if an entry was removed; False if no entry existed for the exact
      (context, function_type) key.
    """
    key = (context, function_type)
    if key not in self._primary:
      return False

    del self._primary[key]
    # Any key present in `_primary` was inserted by `add`, which also created
    # the dispatch table for its context, so this index is expected to exist.
    self._dispatch_dict[context].delete(function_type)
    return True

  def add(self, context: FunctionContext,
          function_type: function_type_lib.FunctionType,
          concrete_fn: Any):
    """Adds a new concrete function alongside its key.

    An existing entry for the same (context, function_type) key is replaced.

    Args:
      context: A FunctionContext representing the current context.
      function_type: A FunctionType representing concrete_fn signature.
      concrete_fn: The concrete function to be added to the cache.
    """
    self._primary[(context, function_type)] = concrete_fn

    dispatch_table = self._dispatch_dict.get(context)
    if dispatch_table is None:
      # The first function added under a context creates its dispatch table.
      dispatch_table = type_dispatch.TypeDispatchTable()
      self._dispatch_dict[context] = dispatch_table
    dispatch_table.add_target(function_type)

  def generalize(
      self, context: FunctionContext,
      function_type: function_type_lib.FunctionType
  ) -> function_type_lib.FunctionType:
    """Try to generalize a FunctionType within a FunctionContext.

    Args:
      context: The FunctionContext whose registered types are consulted.
      function_type: The FunctionType to generalize.

    Returns:
      The result of the context's dispatch table
      `try_generalizing_function_type`, or `function_type` unchanged when no
      function has been added under `context`.
    """
    dispatch_table = self._dispatch_dict.get(context)
    if dispatch_table is None:
      return function_type
    return dispatch_table.try_generalizing_function_type(function_type)

  # TODO(b/205971333): Remove this function.
  def clear(self):
    """Removes all concrete functions and dispatch tables from the cache."""
    self._primary.clear()
    self._dispatch_dict.clear()

  def values(self):
    """Returns a list of all `ConcreteFunction` instances held by this cache.

    The list is a new object in insertion order; mutating it does not affect
    the cache.
    """
    return list(self._primary.values())