Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_image_ops.py: 41%
44 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Image operations for RaggedTensors."""
17from tensorflow.python.framework import dtypes
18from tensorflow.python.framework import ops
19from tensorflow.python.framework import tensor_shape
20from tensorflow.python.framework import tensor_spec
21from tensorflow.python.framework import tensor_util
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import cond
24from tensorflow.python.ops import image_ops
25from tensorflow.python.ops import map_fn
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops.ragged import ragged_tensor
28from tensorflow.python.util import dispatch
@dispatch.dispatch_for_api(image_ops.resize_images_v2)
def resize_images_v2(images: ragged_tensor.RaggedTensor,
                     size,
                     method=image_ops.ResizeMethod.BILINEAR,
                     preserve_aspect_ratio=False,
                     antialias=False,
                     name=None):
  """Handles `image_ops.resize_images_v2` (tf.image.resize, TF2) for ragged input.

  Opens a "RaggedResizeImages" name scope and delegates the per-image work to
  `_resize_images`, using `image_ops.resize_images_v2` as the dense resizer.

  Args:
    images: A `RaggedTensor` batch of images (must have rank 4).
    size: The target `[new_height, new_width]`.
    method: Resize method forwarded to the dense op.
    preserve_aspect_ratio: Forwarded to the dense op.
    antialias: Forwarded to the dense op.
    name: Optional name for the enclosing name scope.

  Returns:
    The result of `_resize_images` for this batch.
  """
  # Options handed through unchanged to the dense resize op.
  forwarded_options = dict(
      method=method,
      preserve_aspect_ratio=preserve_aspect_ratio,
      antialias=antialias,
  )
  scope_inputs = [images, size]
  with ops.name_scope(name, "RaggedResizeImages", scope_inputs):
    return _resize_images(
        image_ops.resize_images_v2, images, size, **forwarded_options)
@dispatch.dispatch_for_api(image_ops.resize_images)
def resize_images_v1(images: ragged_tensor.RaggedTensor,
                     size,
                     method=image_ops.ResizeMethodV1.BILINEAR,
                     align_corners=False,
                     preserve_aspect_ratio=False,
                     name=None):
  """Handles `image_ops.resize_images` (tf.image.resize, TF1) for ragged input.

  Opens a "RaggedResizeImages" name scope and delegates the per-image work to
  `_resize_images`, using `image_ops.resize_images` as the dense resizer.

  Args:
    images: A `RaggedTensor` batch of images (must have rank 4).
    size: The target `[new_height, new_width]`.
    method: Resize method forwarded to the dense op.
    align_corners: Forwarded to the dense op.
    preserve_aspect_ratio: Forwarded to the dense op.
    name: Optional name for the enclosing name scope.

  Returns:
    The result of `_resize_images` for this batch.
  """
  # Options handed through unchanged to the dense resize op.
  forwarded_options = {
      "method": method,
      "preserve_aspect_ratio": preserve_aspect_ratio,
      "align_corners": align_corners,
  }
  scope_inputs = [images, size]
  with ops.name_scope(name, "RaggedResizeImages", scope_inputs):
    return _resize_images(
        image_ops.resize_images, images, size, **forwarded_options)
def _resize_images(resize_op, images, size, **kwargs):
  """RaggedTensor dispatcher for tf.image.resize.

  Resizes each image of a ragged batch independently with `resize_op` (via
  `map_fn`) and returns the stacked dense result.

  Args:
    resize_op: The dense resize function applied to each image, e.g.
      `image_ops.resize_images_v2` or `image_ops.resize_images`.
    images: A `RaggedTensor` batch of images; its static rank must be 4.
    size: The target `[new_height, new_width]`; converted to an int32 tensor.
    **kwargs: Extra keyword arguments forwarded to `resize_op` (e.g. `method`,
      `preserve_aspect_ratio`, `antialias` / `align_corners`).

  Returns:
    A dense `Tensor` whose per-image shape is `size + images.shape[-1:]`.
    Its dtype is `float32`, except for nearest-neighbor resizing, which keeps
    `images.dtype` (mirroring the dense resize ops).

  Raises:
    ValueError: If `images.shape.rank` is not 4.
  """
  if images.shape.rank != 4:
    raise ValueError(
        "tf.image.resize: images.shape.rank must be 4 if images is ragged.")

  # Determine the output shape (excluding the batch dimension).
  static_batch_size = tensor_shape.dimension_value(images.shape[0])
  size = ops.convert_to_tensor(size, dtypes.int32, "size")
  size_as_shape = tensor_util.constant_value_as_shape(size).with_rank(2)
  out_shape = size_as_shape + images.shape[-1:]

  # The dense resize ops return float32 for every method except nearest
  # neighbor, which preserves the input dtype. The map_fn output signature and
  # the empty-batch result must use the same dtype the per-image op produces;
  # otherwise map_fn / cond reject e.g. uint8 nearest-neighbor results.
  method = kwargs.get("method")
  if method in (image_ops.ResizeMethod.NEAREST_NEIGHBOR,
                image_ops.ResizeMethodV1.NEAREST_NEIGHBOR):
    out_dtype = images.dtype
  else:
    out_dtype = dtypes.float32
  out_spec = tensor_spec.TensorSpec(out_shape, out_dtype)

  def resize_one(image):
    # Each mapped element may still be ragged; densify it before resizing.
    if isinstance(image, ragged_tensor.RaggedTensor):
      image = image.to_tensor()
    return resize_op(image, size, **kwargs)

  def resize_with_map():
    return map_fn.map_fn_v2(resize_one, images, fn_output_signature=out_spec)

  def empty_result():
    # Shape [0, new_height, new_width, channels], with channels read from the
    # last dimension of the flat values.
    channels = array_ops.shape(images.flat_values)[-1:]
    return array_ops.zeros(
        array_ops.concat([[0], size, channels], axis=0), dtype=out_dtype)

  # Pick the branch statically when the batch size is known; otherwise decide
  # at runtime based on the number of rows.
  if static_batch_size == 0:
    return empty_result()
  elif static_batch_size is not None:
    return resize_with_map()
  else:
    empty_batch = math_ops.equal(images.nrows(), 0)
    return cond.cond(empty_batch, empty_result, resize_with_map)