Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/image/connected_components.py: 25%
32 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Connected Components."""
17import tensorflow as tf
19from tensorflow_addons.utils import types
20from tensorflow_addons.utils.resource_loader import LazySO
22from typing import Optional, Text
# Handle to the compiled custom-op library that provides
# `addons_image_connected_components`. NOTE(review): `LazySO` presumably defers
# loading the shared object until first use — confirm in resource_loader.
_image_so = LazySO("custom_ops/image/_image_ops.so")


@tf.function
def connected_components(
    images: types.TensorLike, name: Optional[Text] = None
) -> tf.Tensor:
    """Labels the connected components in a batch of images.

    A component is a set of pixels in a single input image, which are
    all adjacent and all have the same non-zero value. Components are
    found using a squared connectivity of one (all equal entries are
    joined with their neighbors above, below, left, and right).
    Components across all images have consecutive ids 1 through n.
    Components are labeled according to the first pixel of the
    component appearing in row-major order (lexicographic order by
    image_index_in_batch, row, col).
    Zero entries all have an output id of 0.
    This op is equivalent to `scipy.ndimage.measurements.label`
    on a 2D array with the default structuring element
    (which is the connectivity used here).

    Args:
      images: A 2D (H, W) or 3D (N, H, W) `Tensor` of image (integer,
        floating point and boolean types are supported).
      name: The name of the op.

    Returns:
      An `int32` tensor of component ids with the same shape as `images`.
      Entries that evaluate to False (e.g. 0/0.0f, False) in `images` have
      value 0, and all other entries map to a component id > 0.

    Raises:
      TypeError: if the static rank of `images` is not 2 or 3 (this
        includes an unknown static rank).
    """
    with tf.name_scope(name or "connected_components"):
        image_or_images = tf.convert_to_tensor(images, name="images")
        # `rank` is None when the static rank is unknown; route that to the
        # documented TypeError instead of letting `len()` raise ValueError.
        rank = image_or_images.shape.rank
        if rank == 2:
            # Add a batch dimension so the op always sees (N, H, W).
            images = image_or_images[None, :, :]
        elif rank == 3:
            images = image_or_images
        else:
            raise TypeError(
                "images should have rank 2 (HW) or 3 (NHW). Static shape is %s"
                % image_or_images.get_shape()
            )
        components = _image_so.ops.addons_image_connected_components(images)

        # TODO(ringwalt): Component id renaming should be done in the op,
        # to avoid constructing multiple additional large tensors.
        components_flat = tf.reshape(components, [-1])
        # tf.unique lists distinct values in order of first occurrence in the
        # flattened (row-major) tensor, which yields the documented ordering.
        unique_ids, id_index = tf.unique(components_flat)
        # For each distinct id: 1 if nonzero, 0 if it is the background id.
        is_nonzero = tf.cast(tf.not_equal(unique_ids, 0), tf.int32)
        # Running count of nonzero ids gives the k-th distinct nonzero id the
        # value k; multiplying by the mask keeps the zero id mapped to 0.
        new_ids = tf.cumsum(is_nonzero) * is_nonzero
        components = tf.reshape(tf.gather(new_ids, id_index), tf.shape(components))
        if rank == 2:
            # Drop the batch dimension that was added above.
            return components[0, :, :]
        return components