Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/feature_column/utils.py: 23%

61 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Defines functions common to multiple feature column files.""" 

16 

import numbers

import six

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest

24 

25 

def sequence_length_from_sparse_tensor(sp_tensor, num_elements=1):
  """Returns a [batch_size] Tensor with per-example sequence length.

  For each row of `sp_tensor`, the largest stored column index `c` is turned
  into a raw length `c + 1`. Raw values are then grouped `num_elements` at a
  time, so the reported length is `ceil((c + 1) / num_elements)`. Trailing rows
  that store no values are given length 0.

  Args:
    sp_tensor: A `SparseTensor`; `indices[:, 0]` is used as the row (example)
      id and `indices[:, 1]` as the column position. NOTE(review):
      `segment_max` expects sorted segment ids, so this presumably assumes
      canonically ordered indices; confirm at call sites.
    num_elements: How many raw values form one sequence element.

  Returns:
    An int64 `Tensor` of per-row lengths, zero-padded up to the first
    dimension of `sp_tensor`'s shape.
  """
  with ops.name_scope(None, 'sequence_length') as scope:
    indices = sp_tensor.indices
    example_ids = indices[:, 0]
    positions = indices[:, 1]
    # A zero-based column index c means the row reaches at least c + 1 values.
    extents = positions + array_ops.ones_like(positions)
    # The largest extent within each row is that row's raw length.
    raw_lengths = math_ops.segment_max(extents, segment_ids=example_ids)

    # Group raw values num_elements at a time; a trailing partial group still
    # counts as one element. E.g. for [[1, 2], [3]]: col_ids = (0, 1, 1),
    # row_ids = (0, 0, 1), raw lengths [2, 1]; with num_elements = 2 the
    # grouped lengths become [1, 1].
    grouped_lengths = math_ops.cast(
        math_ops.ceil(raw_lengths / num_elements), dtypes.int64)

    # When the last n rows hold no ids, the lengths above only cover
    # batch_size - n rows, so append zeros for the missing tail.
    dense_rows = array_ops.shape(sp_tensor)[:1]
    covered_rows = array_ops.shape(grouped_lengths)[:1]
    tail = array_ops.zeros(
        dense_rows - covered_rows, dtype=grouped_lengths.dtype)
    return array_ops.concat([grouped_lengths, tail], axis=0, name=scope)

49 

50 

def assert_string_or_int(dtype, prefix):
  """Raises a `ValueError` unless `dtype` is `dtypes.string` or an integer type.

  Args:
    dtype: The `DType` to validate; its `is_integer` attribute is consulted
      when it is not `dtypes.string`.
    prefix: Text prepended to the error message.

  Raises:
    ValueError: if `dtype` is neither `dtypes.string` nor an integer type.
  """
  if dtype != dtypes.string and not dtype.is_integer:
    message = '{} dtype must be string or integer. dtype: {}.'.format(
        prefix, dtype)
    raise ValueError(message)

55 

56 

def assert_key_is_string(key):
  """Raises a `ValueError` if `key` is not a `str`.

  Args:
    key: The key to validate.

  Raises:
    ValueError: if `key` is not a string.
  """
  # On Python 3 `six.string_types` is exactly `(str,)`; checking the builtin
  # directly keeps the same behavior without the Python 2 compatibility shim.
  if not isinstance(key, str):
    raise ValueError(
        'key must be a string. Got: type {}. Given key: {}.'.format(
            type(key), key))

62 

63 

def check_default_value(shape, default_value, dtype, key):
  """Returns default value as tuple if it's valid, otherwise raises errors.

  This function verifies that `default_value` is compatible with both `shape`
  and `dtype`. If it is not compatible, it raises an error. If it is
  compatible, it converts `default_value` to a (possibly nested) tuple and
  returns it. `key` is used only for error messages.

  Args:
    shape: An iterable of integers specifying the shape of the `Tensor`.
    default_value: If a single value is provided, the same value will be
      applied as the default value for every item. If an iterable of values is
      provided, the shape of the `default_value` should be equal to the given
      `shape`. Objects exposing a callable `tolist()` (NumPy arrays and NumPy
      scalars such as `np.int64(0)`) are first converted to Python values.
    dtype: The `DType` the values must be compatible with. Integer values are
      accepted for any `dtype`; values containing floats are accepted only
      when `dtype.is_floating`.
    key: Column name, used only for error messages.

  Returns:
    `None` if `default_value` is `None`. Otherwise a nested tuple of shape
    `shape`; for a scalar `default_value` and an empty `shape` the scalar
    itself is returned.

  Raises:
    ValueError: if `default_value` is an iterable but not compatible with
      `shape`.
    TypeError: if `default_value` is not compatible with `dtype`.
  """
  if default_value is None:
    return None

  # Convert NumPy arrays *and* NumPy scalars to native Python values before
  # the scalar checks below. NumPy scalars such as np.int64 or np.float32 are
  # not `int`/`float` instances, so checking them first would wrongly reject
  # them with a TypeError.
  if callable(getattr(default_value, 'tolist', None)):
    default_value = default_value.tolist()

  if isinstance(default_value, int):
    return _create_tuple(shape, default_value)

  if isinstance(default_value, float) and dtype.is_floating:
    return _create_tuple(shape, default_value)

  if nest.is_nested(default_value):
    if not _is_shape_and_default_value_compatible(default_value, shape):
      raise ValueError(
          'The shape of default_value must be equal to given shape. '
          'default_value: {}, shape: {}, key: {}'.format(
              default_value, shape, key))
    flat_values = nest.flatten(default_value)
    # All-integer values are accepted for any dtype.
    if all(isinstance(v, int) for v in flat_values):
      return _as_tuple(default_value)
    # Otherwise every leaf must be a real number (so e.g. [1.0, 'a'] is
    # rejected rather than silently accepted) and the dtype floating point.
    # `numbers.Real` also covers NumPy scalar leaves such as np.float32.
    is_all_numeric = all(isinstance(v, numbers.Real) for v in flat_values)
    if is_all_numeric and dtype.is_floating:
      return _as_tuple(default_value)
  raise TypeError('default_value must be compatible with dtype. '
                  'default_value: {}, dtype: {}, key: {}'.format(
                      default_value, dtype, key))

121 

122 

123def _create_tuple(shape, value): 

124 """Returns a tuple with given shape and filled with value.""" 

125 if shape: 

126 return tuple([_create_tuple(shape[1:], value) for _ in range(shape[0])]) 

127 return value 

128 

129 

def _as_tuple(value):
  """Recursively converts every nested container in `value` into a tuple.

  Values that `nest.is_nested` reports as non-nested are returned unchanged.
  """
  if nest.is_nested(value):
    return tuple(_as_tuple(item) for item in value)
  return value

134 

135 

def _is_shape_and_default_value_compatible(default_value, shape):
  """Returns True iff the nesting of `default_value` matches `shape` exactly.

  A non-nested value is compatible only with an empty shape. A nested value
  must have `shape[0]` entries, each compatible with `shape[1:]`.
  """
  # Mismatch: nested value with an empty shape, or scalar with a non-empty one.
  if nest.is_nested(default_value) != bool(shape):
    return False
  if not shape:
    return True
  leading_dim = shape[0]
  if len(default_value) != leading_dim:
    return False
  trailing_shape = shape[1:]
  return all(
      _is_shape_and_default_value_compatible(default_value[idx],
                                             trailing_shape)
      for idx in range(leading_dim))