# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras text dataset generation utilities."""

import numpy as np
import tensorflow.compat.v2 as tf

from keras.src.utils import dataset_utils

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export(
    "keras.utils.text_dataset_from_directory",
    "keras.preprocessing.text_dataset_from_directory",
    v1=[],
)
def text_dataset_from_directory(
    directory,
    labels="inferred",
    label_mode="int",
    class_names=None,
    batch_size=32,
    max_length=None,
    shuffle=True,
    seed=None,
    validation_split=None,
    subset=None,
    follow_links=False,
):
    """Generates a `tf.data.Dataset` from text files in a directory.

    If your directory structure is:

    ```
    main_directory/
    ...class_a/
    ......a_text_1.txt
    ......a_text_2.txt
    ...class_b/
    ......b_text_1.txt
    ......b_text_2.txt
    ```

    Then calling `text_dataset_from_directory(main_directory,
    labels='inferred')` will return a `tf.data.Dataset` that yields batches of
    texts from the subdirectories `class_a` and `class_b`, together with labels
    0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`).

    Only `.txt` files are supported at this time.
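
    Example usage (a minimal sketch; the directory name and the argument
    values shown are illustrative):

    ```
    train_ds = tf.keras.utils.text_dataset_from_directory(
        "main_directory",
        labels="inferred",
        label_mode="int",
        batch_size=32,
    )
    for texts, labels in train_ds.take(1):
        print(texts.dtype)   # tf.string
        print(labels.dtype)  # tf.int32
    ```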

    Args:
        directory: Directory where the data is located.
            If `labels` is "inferred", it should contain
            subdirectories, each containing text files for a class.
            Otherwise, the directory structure is ignored.
        labels: Either "inferred"
            (labels are generated from the directory structure),
            None (no labels),
            or a list/tuple of integer labels of the same size as the
            number of text files found in the directory. Labels should be
            sorted according to the alphanumeric order of the text file
            paths (obtained via `os.walk(directory)` in Python).
        label_mode: String describing the encoding of `labels`. Options are:
            - 'int': means that the labels are encoded as integers
                (e.g. for `sparse_categorical_crossentropy` loss).
            - 'categorical' means that the labels are
                encoded as a categorical vector
                (e.g. for `categorical_crossentropy` loss).
            - 'binary' means that the labels (there can be only 2)
                are encoded as `float32` scalars with values 0 or 1
                (e.g. for `binary_crossentropy`).
            - None (no labels).
        class_names: Only valid if `labels` is "inferred". This is the
            explicit list of class names (must match names of
            subdirectories). Used to control the order of the classes
            (otherwise alphanumerical order is used).
        batch_size: Size of the batches of data. Defaults to 32.
            If `None`, the data will not be batched
            (the dataset will yield individual samples).
        max_length: Maximum size of a text string. Texts longer than this
            will be truncated to `max_length`.
        shuffle: Whether to shuffle the data. Defaults to `True`.
            If set to `False`, sorts the data in alphanumeric order.
        seed: Optional random seed for shuffling and transformations.
        validation_split: Optional float between 0 and 1,
            fraction of data to reserve for validation.
        subset: Subset of the data to return.
            One of "training", "validation", or "both".
            Only used if `validation_split` is set.
            When `subset="both"`, the utility returns a tuple of two
            datasets (the training and validation datasets respectively).
        follow_links: Whether to visit subdirectories pointed to by
            symlinks. Defaults to `False`.

    Returns:
        A `tf.data.Dataset` object.
        - If `label_mode` is None, it yields `string` tensors of shape
          `(batch_size,)`, containing the contents of a batch of text files.
        - Otherwise, it yields a tuple `(texts, labels)`, where `texts`
          has shape `(batch_size,)` and `labels` follows the format
          described below.

    Rules regarding labels format:
        - if `label_mode` is `int`, the labels are an `int32` tensor of
          shape `(batch_size,)`.
        - if `label_mode` is `binary`, the labels are a `float32` tensor
          of 1s and 0s of shape `(batch_size, 1)`.
        - if `label_mode` is `categorical`, the labels are a `float32`
          tensor of shape `(batch_size, num_classes)`, representing a
          one-hot encoding of the class index.
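
    For example, to reserve 20% of the files for validation and return both
    splits in one call (a sketch; the directory name, split fraction, and
    seed value are illustrative, and a `seed` must be provided whenever
    `validation_split` is combined with shuffling):

    ```
    train_ds, val_ds = tf.keras.utils.text_dataset_from_directory(
        "main_directory",
        validation_split=0.2,
        subset="both",
        seed=1337,
    )
    ```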

    """
    if labels not in ("inferred", None):
        if not isinstance(labels, (list, tuple)):
            raise ValueError(
                "`labels` argument should be a list/tuple of integer labels, "
                "of the same size as the number of text files in the target "
                "directory. If you wish to infer the labels from the "
                "subdirectory names in the target directory, "
                'pass `labels="inferred"`. '
                "If you wish to get a dataset that only contains text samples "
                f"(no labels), pass `labels=None`. Received: labels={labels}"
            )
        if class_names:
            raise ValueError(
                "You can only pass `class_names` if "
                f'`labels="inferred"`. Received: labels={labels}, and '
                f"class_names={class_names}"
            )
    if label_mode not in {"int", "categorical", "binary", None}:
        raise ValueError(
            '`label_mode` argument must be one of "int", '
            '"categorical", "binary", '
            f"or None. Received: label_mode={label_mode}"
        )
    if labels is None or label_mode is None:
        labels = None
        label_mode = None
    dataset_utils.check_validation_split_arg(
        validation_split, subset, shuffle, seed
    )

    if seed is None:
        seed = np.random.randint(1e6)
    file_paths, labels, class_names = dataset_utils.index_directory(
        directory,
        labels,
        formats=(".txt",),
        class_names=class_names,
        shuffle=shuffle,
        seed=seed,
        follow_links=follow_links,
    )

    if label_mode == "binary" and len(class_names) != 2:
        raise ValueError(
            'When passing `label_mode="binary"`, there must be exactly 2 '
            f"class_names. Received: class_names={class_names}"
        )

    if subset == "both":
        (
            file_paths_train,
            labels_train,
        ) = dataset_utils.get_training_or_validation_split(
            file_paths, labels, validation_split, "training"
        )
        (
            file_paths_val,
            labels_val,
        ) = dataset_utils.get_training_or_validation_split(
            file_paths, labels, validation_split, "validation"
        )
        if not file_paths_train:
            raise ValueError(
                f"No training text files found in directory {directory}. "
                "Allowed format: .txt"
            )
        if not file_paths_val:
            raise ValueError(
                f"No validation text files found in directory {directory}. "
                "Allowed format: .txt"
            )
        train_dataset = paths_and_labels_to_dataset(
            file_paths=file_paths_train,
            labels=labels_train,
            label_mode=label_mode,
            num_classes=len(class_names),
            max_length=max_length,
        )
        val_dataset = paths_and_labels_to_dataset(
            file_paths=file_paths_val,
            labels=labels_val,
            label_mode=label_mode,
            num_classes=len(class_names),
            max_length=max_length,
        )

        train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
        val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)
        if batch_size is not None:
            if shuffle:
                # Shuffle locally at each iteration
                train_dataset = train_dataset.shuffle(
                    buffer_size=batch_size * 8, seed=seed
                )
            train_dataset = train_dataset.batch(batch_size)
            val_dataset = val_dataset.batch(batch_size)
        else:
            if shuffle:
                train_dataset = train_dataset.shuffle(
                    buffer_size=1024, seed=seed
                )
        # Users may need to reference `class_names`.
        train_dataset.class_names = class_names
        val_dataset.class_names = class_names
        dataset = [train_dataset, val_dataset]
    else:
        file_paths, labels = dataset_utils.get_training_or_validation_split(
            file_paths, labels, validation_split, subset
        )
        if not file_paths:
            raise ValueError(
                f"No text files found in directory {directory}. "
                "Allowed format: .txt"
            )
        dataset = paths_and_labels_to_dataset(
            file_paths=file_paths,
            labels=labels,
            label_mode=label_mode,
            num_classes=len(class_names),
            max_length=max_length,
        )
        dataset = dataset.prefetch(tf.data.AUTOTUNE)
        if batch_size is not None:
            if shuffle:
                # Shuffle locally at each iteration
                dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed)
            dataset = dataset.batch(batch_size)
        else:
            if shuffle:
                dataset = dataset.shuffle(buffer_size=1024, seed=seed)
        # Users may need to reference `class_names`.
        dataset.class_names = class_names
    return dataset


def paths_and_labels_to_dataset(
    file_paths, labels, label_mode, num_classes, max_length
):
    """Constructs a dataset of text strings and labels."""
    path_ds = tf.data.Dataset.from_tensor_slices(file_paths)
    string_ds = path_ds.map(
        lambda x: path_to_string_content(x, max_length),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    if label_mode:
        label_ds = dataset_utils.labels_to_dataset(
            labels, label_mode, num_classes
        )
        string_ds = tf.data.Dataset.zip((string_ds, label_ds))
    return string_ds


def path_to_string_content(path, max_length):
    """Reads a file's contents, truncated to `max_length` bytes if set."""
    txt = tf.io.read_file(path)
    if max_length is not None:
        txt = tf.compat.v1.strings.substr(txt, 0, max_length)
    return txt
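

# Minimal usage sketch, runnable as a script (an illustration, not part of
# the library API; the temporary directory layout built below is
# hypothetical and only shows the expected on-disk structure).
if __name__ == "__main__":
    import os
    import tempfile

    root = tempfile.mkdtemp()
    # One subdirectory per class, each holding `.txt` files.
    for class_name, text in [("neg", "bad movie"), ("pos", "great movie")]:
        os.makedirs(os.path.join(root, class_name), exist_ok=True)
        with open(os.path.join(root, class_name, "sample.txt"), "w") as f:
            f.write(text)

    # Labels are inferred from subdirectory names: "neg" -> 0, "pos" -> 1.
    ds = text_dataset_from_directory(root, batch_size=2, shuffle=False)
    for texts, labels in ds:
        print(texts.numpy(), labels.numpy())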