Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/default_gradient.py: 29%
34 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for computing default gradients."""
16from tensorflow.python.framework import dtypes
17from tensorflow.python.framework import tensor_shape
18from tensorflow.python.ops import array_ops
19from tensorflow.python.ops import resource_variable_ops
def get_zeros_dtype(t):
  """Return the dtype for the default gradient for a Tensor.

  For a resource handle the dtype is read from the single `shape_and_type`
  entry of the handle data; for any other tensor it is simply `t.dtype`.

  Args:
    t: A Tensor, possibly a resource handle.

  Returns:
    A `DType`. The `dtype` field stored in the handle data is converted with
    `dtypes.as_dtype` (as `shape_and_dtype` does), so both branches return the
    same kind of object instead of the resource branch leaking the raw field.

  Raises:
    ValueError: If `t` is a resource handle whose handle data is missing,
      not set, or does not contain exactly one shape-and-type entry.
  """
  if t.dtype == dtypes.resource:
    handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
    if (handle_data is None or not handle_data.is_set or
        len(handle_data.shape_and_type) != 1):
      raise ValueError("Internal error: Tried to take gradients (or similar) "
                       "of a variable without handle data:\n%s" % str(t))
    # Normalize to a DType so callers get a consistent return type.
    return dtypes.as_dtype(handle_data.shape_and_type[0].dtype)
  return t.dtype
def shape_and_dtype(t):
  """Compute the (shape, dtype) pair used for the default gradient of `t`.

  Args:
    t: A Tensor, possibly a resource handle.

  Returns:
    A `(shape, dtype)` tuple. For a non-resource tensor this is
    `(t.shape, t.dtype)`. For a resource handle both values are built from
    the single `shape_and_type` entry of the handle data, as a `TensorShape`
    and a `DType` respectively.

  Raises:
    ValueError: If `t` is a resource handle whose handle data is missing,
      not set, or does not contain exactly one shape-and-type entry.
  """
  is_resource = t.dtype == dtypes.resource
  if not is_resource:
    return t.shape, t.dtype

  data = resource_variable_ops.get_eager_safe_handle_data(t)
  has_single_entry = (data is not None and data.is_set and
                      len(data.shape_and_type) == 1)
  if not has_single_entry:
    raise ValueError("Internal error: Tried to take gradients (or similar) "
                     "of a variable without handle data:\n%s" % str(t))

  entry = data.shape_and_type[0]
  shape = tensor_shape.TensorShape(entry.shape)
  dtype = dtypes.as_dtype(entry.dtype)
  return shape, dtype
def zeros_like(t):
  """Build a zeros tensor matching `t`, with support for resource handles.

  Non-resource tensors are forwarded to `array_ops.zeros_like`. For a
  resource handle, the shape and dtype come from `shape_and_dtype(t)` and the
  result is created with `array_ops.zeros`.
  """
  if t.dtype == dtypes.resource:
    shape, dtype = shape_and_dtype(t)
    return array_ops.zeros(shape, dtype)
  return array_ops.zeros_like(t)
def ones_like(t):
  """Build a ones tensor matching `t`, with support for resource handles.

  Non-resource tensors are forwarded to `array_ops.ones_like`. For a
  resource handle, the shape and dtype come from `shape_and_dtype(t)` and the
  result is created with `array_ops.ones`.
  """
  if t.dtype == dtypes.resource:
    shape, dtype = shape_and_dtype(t)
    return array_ops.ones(shape, dtype)
  return array_ops.ones_like(t)
def supports_default_grad(t):
  """Report whether a default gradient can be created for tensor `t`.

  This function assumes that `t` is of a trainable type.

  Args:
    t: Tensor

  Returns:
    Bool. Always True for a non-resource tensor. For a resource handle, True
    only when its handle data is present, set, and holds exactly one
    shape-and-type entry.
  """
  if t.dtype == dtypes.resource:
    data = resource_variable_ops.get_eager_safe_handle_data(t)
    return bool(data is not None and data.is_set and
                len(data.shape_and_type) == 1)
  return True