Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/image/utils.py: 16%
57 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Image util ops."""
17import tensorflow as tf
def get_ndims(image):
    """Return the rank of `image`, statically if possible.

    Args:
      image: A `Tensor`.

    Returns:
      A Python int when the static rank is known (including 0 for scalars),
      otherwise a scalar int32 `Tensor` computed with `tf.rank`.
    """
    ndims = image.get_shape().ndims
    # Compare against None explicitly: a statically known rank of 0 is falsy
    # and must not fall through to the dynamic `tf.rank` path.
    if ndims is None:
        return tf.rank(image)
    return ndims
def to_4D_image(image):
    """Promote a 2D, 3D or 4D image tensor to rank 4.

    A rank-2 input gains a leading and a trailing size-1 axis, a rank-3
    input gains a leading size-1 axis, and a rank-4 input is returned as is.
    If the static rank is unknown, the reshape is computed at run time.

    Args:
      image: 2/3/4D `Tensor`.

    Returns:
      4D `Tensor` with the same type.
    """
    rank_check = tf.debugging.assert_rank_in(
        image, [2, 3, 4], message="`image` must be 2/3/4D tensor"
    )
    with tf.control_dependencies([rank_check]):
        static_rank = image.get_shape().ndims
        if static_rank is None:
            return _dynamic_to_4D_image(image)
        if static_rank == 2:
            return image[None, :, :, None]
        if static_rank == 3:
            return image[None, :, :, :]
        return image
def _dynamic_to_4D_image(image):
    """Reshape an image whose rank is only known at run time to rank 4.

    Size-1 axes are padded around the existing dimensions:
      rank 4 => unchanged, e.g. [N, H, W, C] or [N, C, H, W]
      rank 3 => one leading axis, e.g. [1, H, W, C] or [1, C, H, W]
      rank 2 => one leading and one trailing axis, i.e. [1, H, W, 1]
    """
    rank = tf.rank(image)
    # Each pad length is either 0 or 1, chosen from the run-time rank.
    num_leading = tf.cast(tf.less_equal(rank, 3), dtype=tf.int32)
    num_trailing = tf.cast(tf.equal(rank, 2), dtype=tf.int32)
    leading_ones = tf.ones(shape=num_leading, dtype=tf.int32)
    trailing_ones = tf.ones(shape=num_trailing, dtype=tf.int32)
    target_shape = tf.concat(
        [leading_ones, tf.shape(image), trailing_ones], axis=0
    )
    return tf.reshape(image, target_shape)
def from_4D_image(image, ndims):
    """Reverse `to_4D_image`, returning an image of rank `ndims`.

    Args:
      image: 4D `Tensor`.
      ndims: The original rank of the image, either a Python int or a
        `tf.Tensor` (in which case the reshape happens at run time).

    Returns:
      `ndims`-D `Tensor` with the same type.
    """
    rank_check = tf.debugging.assert_rank(
        image, 4, message="`image` must be 4D tensor"
    )
    with tf.control_dependencies([rank_check]):
        if isinstance(ndims, tf.Tensor):
            return _dynamic_from_4D_image(image, ndims)
        if ndims == 2:
            # Drop both the leading and the trailing size-1 axes.
            return tf.squeeze(image, [0, 3])
        if ndims == 3:
            # Drop only the leading size-1 axis.
            return tf.squeeze(image, [0])
        return image
def _dynamic_from_4D_image(image, original_rank):
    """Reshape a 4D image back to `original_rank` (a run-time tensor).

    The shape is sliced to drop the padding added by `_dynamic_to_4D_image`:
      rank 4 <= [N, H, W, C] or [N, C, H, W]
      rank 3 <= [1, H, W, C] or [1, C, H, W]
      rank 2 <= [1, H, W, 1]
    """
    full_shape = tf.shape(image)
    # Skip the leading axis for rank <= 3 and the trailing axis for rank 2.
    start = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32)
    stop = 4 - tf.cast(tf.equal(original_rank, 2), dtype=tf.int32)
    return tf.reshape(image, full_shape[start:stop])
def wrap(image):
    """Return `image` with one extra channel, filled with 1s, appended.

    Args:
      image: A 3D image `Tensor`; the new channel is concatenated on axis 2.

    Returns:
      A `Tensor` of the same dtype with one more entry along axis 2.
    """
    dims = tf.shape(image)
    ones_channel = tf.ones([dims[0], dims[1], 1], image.dtype)
    return tf.concat([image, ones_channel], 2)
def unwrap(image, replace):
    """Unwraps an image produced by `wrap`.

    The last channel acts as a validity mask. At every spatial position where
    that channel is not exactly 1, the remaining channels are overwritten
    with `replace`. Operations like translate and shear on a wrapped tensor
    leave 0s (or, with interpolation, partial values) in the mask where no
    source pixel was available. Some transformations look at pixel
    intensities, so those positions should take a neutral fill value rather
    than pure black.

    Args:
      image: A 3D image `Tensor` of shape [H, W, C + 1] whose last channel is
        the mask added by `wrap` (typically 4 channels, i.e. C = 3).
      replace: A scalar, or a 1D `Tensor` with either one value or one value
        per non-mask channel, used to fill masked positions. It is cast to
        `image.dtype`.

    Returns:
      image: A 3D image `Tensor` of shape [H, W, C] (the mask channel removed).
    """
    image_shape = tf.shape(image)
    num_channels = image_shape[2]
    # Flatten the spatial dimensions: [H * W, C + 1].
    flattened_image = tf.reshape(image, [-1, num_channels])
    alpha_channel = flattened_image[:, -1]

    # Normalize `replace` to one fill value per non-mask channel. Flattening
    # then broadcasting handles scalars and 1-element vectors uniformly, and
    # (unlike a Python `if` on `tf.rank(...)`) also works in graph mode.
    replace = tf.reshape(tf.cast(replace, image.dtype), [-1])
    replace = tf.broadcast_to(replace, [num_channels - 1])
    # The mask channel itself is filled with 1; it is sliced away below.
    replace = tf.concat([replace, tf.ones([1], dtype=image.dtype)], 0)

    # Keep pixels whose mask is exactly 1; fill every other pixel with
    # `replace`. The [H * W, 1] condition broadcasts across all channels.
    keep = tf.expand_dims(tf.equal(alpha_channel, 1), 1)
    flattened_image = tf.where(keep, flattened_image, replace)

    image = tf.reshape(flattened_image, image_shape)
    # Drop the mask channel.
    return image[:, :, :-1]