Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/image/connected_components.py: 25%

32 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Connected Components.""" 

16 

17import tensorflow as tf 

18 

19from tensorflow_addons.utils import types 

20from tensorflow_addons.utils.resource_loader import LazySO 

21 

22from typing import Optional, Text 

23 

# Handle to the custom-op shared library "custom_ops/image/_image_ops.so".
# `connected_components` below reaches its kernel through `_image_so.ops`.
# NOTE(review): LazySO presumably defers loading the .so until `.ops` is
# first accessed — confirm against tensorflow_addons.utils.resource_loader.
_image_so = LazySO("custom_ops/image/_image_ops.so")

25 

26 

@tf.function
def connected_components(
    images: types.TensorLike, name: Optional[Text] = None
) -> tf.Tensor:
    """Labels the connected components in a batch of images.

    A component is a set of pixels in a single input image which are all
    adjacent and all share the same non-zero value. Components are formed
    using a squared connectivity of one: each entry is joined with its
    equal neighbors above, below, left, and right. Components across all
    images in the batch receive consecutive ids 1 through n, assigned in
    the order in which each component's first pixel appears in row-major
    order (lexicographic order by image_index_in_batch, row, col).
    Zero entries all have an output id of 0.

    This op is equivalent to `scipy.ndimage.label` on a 2D array with the
    default structuring element (which is the connectivity used here).

    Args:
      images: A 2D (H, W) or 3D (N, H, W) `Tensor` of image (integer,
        floating point and boolean types are supported).
      name: The name of the op. Defaults to "connected_components".

    Returns:
      An int32 `Tensor` of component ids with the same shape as `images`.
      Entries that evaluate to False (e.g. 0/0.0f, False) in `images` have
      value 0, and all other entries map to a component id > 0.

    Raises:
      TypeError: if the static rank of `images` is not 2 or 3, including
        when the rank is not statically known.
    """
    with tf.name_scope(name or "connected_components"):
        image_or_images = tf.convert_to_tensor(images, name="images")
        static_shape = image_or_images.get_shape()
        # `rank` is None when the rank is not statically known; the
        # original `len(shape)` raised an opaque ValueError in that case.
        rank = static_shape.rank
        if rank not in (2, 3):
            raise TypeError(
                "images should have rank 2 (HW) or 3 (NHW). Static shape is %s"
                % static_shape
            )
        # The kernel is fed a batch, so lift a single image to (1, H, W).
        batch = image_or_images[None, :, :] if rank == 2 else image_or_images
        components = _image_so.ops.addons_image_connected_components(batch)

        # TODO(ringwalt): Component id renaming should be done in the op,
        # to avoid constructing multiple additional large tensors.
        components_flat = tf.reshape(components, [-1])
        # tf.unique lists ids in order of first occurrence in the flattened
        # (row-major) tensor; `id_index` maps each entry to its unique id.
        unique_ids, id_index = tf.unique(components_flat)
        # 1 for each nonzero (foreground) id, 0 for the background id.
        is_nonzero = tf.cast(tf.not_equal(unique_ids, 0), tf.int32)
        # A running count of nonzero ids numbers them 1..k in first-seen
        # order; multiplying by the mask pins the background id back to 0.
        # This replaces an equivalent tf.where + tf.cond + slice/concat.
        new_ids = tf.cumsum(is_nonzero) * is_nonzero
        components = tf.reshape(tf.gather(new_ids, id_index), tf.shape(components))
        if rank == 2:
            # Drop the batch dimension added above.
            return components[0, :, :]
        return components