Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/experimental/ops/cardinality.py: 68%

25 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Cardinality analysis of `Dataset` objects.""" 

16from tensorflow.python.data.ops import dataset_ops 

17from tensorflow.python.framework import dtypes 

18from tensorflow.python.framework import ops 

19from tensorflow.python.ops import gen_dataset_ops 

20from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops 

21from tensorflow.python.util.tf_export import tf_export 

22 

23 

# Sentinel value reported as the cardinality of a dataset that contains an
# infinite number of elements (exported below as
# `tf.data.experimental.INFINITE_CARDINALITY`).
INFINITE = -1
# Sentinel value reported when the analysis fails to determine the number of
# elements in a dataset (exported below as
# `tf.data.experimental.UNKNOWN_CARDINALITY`).
UNKNOWN = -2
# Publish both sentinels as public API constants. The exported API name and the
# module attribute name are deliberately passed as string literals; keep them
# literal, since API-generation tooling may locate these calls statically.
tf_export("data.experimental.INFINITE_CARDINALITY").export_constant(
    __name__, "INFINITE")
tf_export("data.experimental.UNKNOWN_CARDINALITY").export_constant(
    __name__, "UNKNOWN")

30 

31 

# TODO(b/157691652): Deprecate this function after migrating users to the new
# API.
@tf_export("data.experimental.cardinality")
def cardinality(dataset):
  """Returns the cardinality of `dataset`, if known.

  Computes how many elements `dataset` will produce. When `dataset` contains
  an infinite number of elements the result is
  `tf.data.experimental.INFINITE_CARDINALITY`; when the analysis cannot
  determine the count (for example, when the dataset source is a file) the
  result is `tf.data.experimental.UNKNOWN_CARDINALITY`.

  >>> dataset = tf.data.Dataset.range(42)
  >>> print(tf.data.experimental.cardinality(dataset).numpy())
  42
  >>> dataset = dataset.repeat()
  >>> cardinality = tf.data.experimental.cardinality(dataset)
  >>> print((cardinality == tf.data.experimental.INFINITE_CARDINALITY).numpy())
  True
  >>> dataset = dataset.filter(lambda x: True)
  >>> cardinality = tf.data.experimental.cardinality(dataset)
  >>> print((cardinality == tf.data.experimental.UNKNOWN_CARDINALITY).numpy())
  True

  Args:
    dataset: The `tf.data.Dataset` whose number of elements is requested.

  Returns:
    A scalar `tf.int64` `Tensor` holding the number of elements in `dataset`.
    When that number is infinite or cannot be determined, the tensor holds
    `INFINITE_CARDINALITY` or `UNKNOWN_CARDINALITY`, respectively.
  """
  # The cardinality op consumes the dataset's variant tensor rather than the
  # Python `Dataset` object itself.
  variant = dataset._variant_tensor  # pylint: disable=protected-access
  return gen_dataset_ops.dataset_cardinality(variant)

65 

66 

@tf_export("data.experimental.assert_cardinality")
def assert_cardinality(expected_cardinality):
  """Asserts the cardinality of the input dataset.

  Builds a transformation that declares the number of elements its input
  dataset contains. After the transformation is applied,
  `tf.data.experimental.cardinality` reports the declared value even when the
  input's own cardinality is unknown. The declaration is checked while the
  dataset is iterated.

  NOTE: The following assumes that "examples.tfrecord" contains 42 records.

  >>> dataset = tf.data.TFRecordDataset("examples.tfrecord")
  >>> cardinality = tf.data.experimental.cardinality(dataset)
  >>> print((cardinality == tf.data.experimental.UNKNOWN_CARDINALITY).numpy())
  True
  >>> dataset = dataset.apply(tf.data.experimental.assert_cardinality(42))
  >>> print(tf.data.experimental.cardinality(dataset).numpy())
  42

  Args:
    expected_cardinality: The number of elements the input dataset is expected
      to contain.

  Returns:
    A `Dataset` transformation function, suitable for passing to
    `tf.data.Dataset.apply`.

  Raises:
    FailedPreconditionError: Raised at runtime, while the dataset is being
      iterated, if the actual cardinality differs from the expected one.
  """

  def _transformation(input_dataset):
    """Wraps `input_dataset` so its cardinality is checked."""
    return _AssertCardinalityDataset(input_dataset, expected_cardinality)

  return _transformation

97 

98 

class _AssertCardinalityDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that asserts the cardinality of its input."""

  def __init__(self, input_dataset, expected_cardinality):
    """Creates a dataset that checks `input_dataset` has the expected size.

    Args:
      input_dataset: The `tf.data.Dataset` whose cardinality is asserted.
      expected_cardinality: A value convertible to a `tf.int64` tensor giving
        the number of elements `input_dataset` is expected to contain.
    """
    # Assigned before `self._flat_structure` is read below. As an
    # unchanged-structure dataset, its element structure presumably comes
    # from the input dataset (confirm against `UnaryUnchangedStructureDataset`).
    self._input_dataset = input_dataset
    self._expected_cardinality = ops.convert_to_tensor(
        expected_cardinality, dtype=dtypes.int64, name="expected_cardinality")

    variant_tensor = ged_ops.assert_cardinality_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._expected_cardinality,
        **self._flat_structure)
    super().__init__(input_dataset, variant_tensor)