Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_image_ops.py: 41%

44 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Image operations for RaggedTensors.""" 

16 

17from tensorflow.python.framework import dtypes 

18from tensorflow.python.framework import ops 

19from tensorflow.python.framework import tensor_shape 

20from tensorflow.python.framework import tensor_spec 

21from tensorflow.python.framework import tensor_util 

22from tensorflow.python.ops import array_ops 

23from tensorflow.python.ops import cond 

24from tensorflow.python.ops import image_ops 

25from tensorflow.python.ops import map_fn 

26from tensorflow.python.ops import math_ops 

27from tensorflow.python.ops.ragged import ragged_tensor 

28from tensorflow.python.util import dispatch 

29 

30 

@dispatch.dispatch_for_api(image_ops.resize_images_v2)
def resize_images_v2(images: ragged_tensor.RaggedTensor,
                     size,
                     method=image_ops.ResizeMethod.BILINEAR,
                     preserve_aspect_ratio=False,
                     antialias=False,
                     name=None):
  """Handles `tf.image.resize` (TF2 API) when `images` is a RaggedTensor.

  Args:
    images: A rank-4 `RaggedTensor` holding a batch of images.
    size: The target size, forwarded to `image_ops.resize_images_v2`.
    method: Resize method, forwarded unchanged.
    preserve_aspect_ratio: Forwarded unchanged.
    antialias: Forwarded unchanged.
    name: Optional name for the enclosing name scope.

  Returns:
    The dense result produced by `_resize_images`.
  """
  # Options that are passed straight through to the dense resize op.
  resize_options = dict(
      method=method,
      preserve_aspect_ratio=preserve_aspect_ratio,
      antialias=antialias)
  with ops.name_scope(name, "RaggedResizeImages", [images, size]):
    return _resize_images(image_ops.resize_images_v2, images, size,
                          **resize_options)

47 

48 

@dispatch.dispatch_for_api(image_ops.resize_images)
def resize_images_v1(images: ragged_tensor.RaggedTensor,
                     size,
                     method=image_ops.ResizeMethodV1.BILINEAR,
                     align_corners=False,
                     preserve_aspect_ratio=False,
                     name=None):
  """Handles `tf.image.resize` (TF1 API) when `images` is a RaggedTensor.

  Args:
    images: A rank-4 `RaggedTensor` holding a batch of images.
    size: The target size, forwarded to `image_ops.resize_images`.
    method: Resize method, forwarded unchanged.
    align_corners: Forwarded unchanged.
    preserve_aspect_ratio: Forwarded unchanged.
    name: Optional name for the enclosing name scope.

  Returns:
    The dense result produced by `_resize_images`.
  """
  # Options that are passed straight through to the dense resize op.
  resize_options = dict(
      method=method,
      preserve_aspect_ratio=preserve_aspect_ratio,
      align_corners=align_corners)
  with ops.name_scope(name, "RaggedResizeImages", [images, size]):
    return _resize_images(image_ops.resize_images, images, size,
                          **resize_options)

65 

66 

def _resize_images(resize_op, images, size, **kwargs):
  """Resizes every image in a rank-4 RaggedTensor to the same `size`.

  This is the shared implementation behind the v1 and v2 `tf.image.resize`
  dispatchers. Each row of `images` is densified (if it is itself ragged),
  resized with `resize_op`, and the per-image results are stacked into a
  dense `Tensor` via `map_fn`. An empty batch short-circuits to a zero-sized
  result, so `resize_op` is never called on it.

  Args:
    resize_op: Dense resize function applied to each image, e.g.
      `image_ops.resize_images_v2` or `image_ops.resize_images`.
    images: A `RaggedTensor` of rank 4 (a batch of images).
    size: Target size. It is converted to an int32 tensor whose static shape
      (when known) must describe exactly 2 values.
    **kwargs: Keyword arguments forwarded to `resize_op` (e.g. `method`,
      `preserve_aspect_ratio`, `antialias` / `align_corners`).

  Returns:
    A dense `Tensor` whose shape is the batch size, followed by `size`,
    followed by the channel dimension of `images`. Its dtype is
    `images.dtype` when `method` is nearest-neighbor, and `float32`
    otherwise.

  Raises:
    ValueError: If `images.shape.rank` is not 4.
  """
  if images.shape.rank != 4:
    raise ValueError(
        "tf.image.resize: images.shape.rank must be 4 if images is ragged.")

  # Determine the output shape (excluding the batch dimension).
  static_batch_size = tensor_shape.dimension_value(images.shape[0])
  size = ops.convert_to_tensor(size, dtypes.int32, "size")
  size_as_shape = tensor_util.constant_value_as_shape(size).with_rank(2)
  out_shape = size_as_shape + images.shape[-1:]
  # NOTE(review): with preserve_aspect_ratio=True the dense resize op may
  # return images smaller than `size` (and differing per image), which this
  # fixed out_shape does not model -- confirm intended behavior.

  # Nearest-neighbor resizing keeps the input dtype, while every other method
  # returns float32. Both the map_fn output signature and the empty-batch
  # result must use the dtype `resize_op` actually produces. Otherwise, e.g.,
  # uint8 images resized with "nearest" clash with a float32 signature.
  # Both enum spellings are accepted because the v1 op takes either.
  method = kwargs.get("method")
  if method in (image_ops.ResizeMethod.NEAREST_NEIGHBOR,
                image_ops.ResizeMethodV1.NEAREST_NEIGHBOR):
    out_dtype = images.dtype
  else:
    out_dtype = dtypes.float32
  out_spec = tensor_spec.TensorSpec(out_shape, out_dtype)

  def resize_one(image):
    # Rows of a RaggedTensor with ragged_rank > 1 are themselves ragged;
    # the dense resize op needs a regular Tensor.
    if isinstance(image, ragged_tensor.RaggedTensor):
      image = image.to_tensor()
    return resize_op(image, size, **kwargs)

  def resize_with_map():
    return map_fn.map_fn_v2(resize_one, images, fn_output_signature=out_spec)

  def empty_result():
    # Shape [0, height, width, channels], with channels read dynamically.
    channels = array_ops.shape(images.flat_values)[-1:]
    return array_ops.zeros(
        array_ops.concat([[0], size, channels], axis=0), dtype=out_dtype)

  if static_batch_size == 0:
    return empty_result()
  elif static_batch_size is not None:
    return resize_with_map()
  else:
    # Batch size unknown until runtime: choose the branch dynamically.
    empty_batch = math_ops.equal(images.nrows(), 0)
    return cond.cond(empty_batch, empty_result, resize_with_map)