1# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Module for the TFMethodTarget Class."""
16
17import weakref
18
19from tensorflow.python.util import tf_inspect
20
21
# When a method is bound to objects of this type, it allows AutoGraph to
# recover a weak reference to the original method's self pointer, so that it
# can execute it consistently with class_method_to_instance_method's
# bound_method_wrapper.
# TODO(b/119246461): This is not pretty. Use a descriptor instead?
class TfMethodTarget:
  """Binding target for methods replaced by function and defun.

  An instance stands in as the `self` of a bound method wrapper. It exposes
  the original receiver through `target` / `target_class` and can re-invoke
  the original Python function through `call`. The original function is held
  only through a `weakref.ref`; the receiver is held through whatever
  reference object the caller supplies (expected to be a weak reference).
  """

  __slots__ = ("weakrefself_target__", "weakrefself_func__")

  def __init__(self, target, original_python_function):
    """Creates a binding target.

    Args:
      target: A zero-argument callable, presumably a `weakref.ref`, that
        returns the object the method is bound to: an instance, or a class
        for class methods. It is stored as-is (not wrapped again).
      original_python_function: The original Python function. A
        `weakref.ref` to it is stored, so it must support weak references.
    """
    self.weakrefself_target__ = target
    self.weakrefself_func__ = weakref.ref(original_python_function)

  @property
  def target(self):
    """The bound object, as returned by dereferencing `target`.

    For a weak-reference target this is None once the referent has been
    garbage collected; no error is raised here.
    """
    return self.weakrefself_target__()

  def _live_target(self):
    """Dereferences the target, raising ReferenceError if it is gone."""
    true_self = self.weakrefself_target__()
    if true_self is None:
      raise ReferenceError(
          "The object this method was bound to has been garbage collected.")
    return true_self

  @property
  def target_class(self):
    """The class associated with the bound object.

    Returns the target itself when it is a class (class methods); otherwise
    returns the target's `__class__`.

    Raises:
      ReferenceError: If the target has been garbage collected. Previously
        this case silently yielded `NoneType`.
    """
    true_self = self._live_target()
    if tf_inspect.isclass(true_self):
      # Class method: the target already is the class.
      return true_self
    return true_self.__class__

  def call(self, args, kwargs):
    """Calls the original Python function with the target prepended.

    Args:
      args: Iterable of positional arguments, passed after the target.
      kwargs: Mapping of keyword arguments.

    Returns:
      Whatever the original Python function returns.

    Raises:
      ReferenceError: If the original function or the target has been
        garbage collected (instead of calling `None` or passing `self=None`).
    """
    wrapped_fn = self.weakrefself_func__()
    if wrapped_fn is None:
      raise ReferenceError(
          "The original Python function has been garbage collected.")
    return wrapped_fn(self._live_target(), *args, **kwargs)