1# Copyright 2017 The Abseil Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Contains base classes used to parse and convert arguments.
16
17Do NOT import this module directly. Import the flags package and use the
18aliases defined at the package level instead.
19"""
20
21import collections
22import csv
23import enum
24import io
25import string
26from typing import Any, Dict, Generic, Iterable, List, Optional, Sequence, Type, TypeVar, Union
27from xml.dom import minidom
28
29from absl.flags import _helpers
30
# Value type produced by ArgumentParser / consumed by ArgumentSerializer.
_T = TypeVar('_T')
# Enum member type handled by EnumClassParser and the enum-class serializers.
_ET = TypeVar('_ET', bound=enum.Enum)
# Numeric value type for NumericParser; constrained to exactly int or float.
_N = TypeVar('_N', int, float)
34
35
class _ArgumentParserCache(type):
  """Metaclass that shares argument parser instances between flags.

  Parsers constructed with the same class and the same positional arguments
  are built once and then reused from a dictionary that is shared by every
  class using this metaclass.
  """

  _instances: Dict[Any, Any] = {}

  def __call__(cls, *args, **kwargs):
    """Creates, or fetches from the cache, an instance of the parser ``cls``.

    This replaces normal instantiation for ArgumentParser and all of its
    subclasses. The cache key is ``cls`` together with the positional
    arguments, so repeating an earlier ``(cls, *args)`` combination yields the
    very same object.

    Caching is bypassed, and a brand-new instance returned, whenever a keyword
    argument is supplied or a positional argument is unhashable.

    Args:
      *args: Positional arguments forwarded to the class initializer.
      **kwargs: Keyword arguments forwarded to the class initializer.

    Returns:
      The shared cached instance of cls, or a freshly constructed one.
    """
    if kwargs:
      # Keyword arguments never participate in the cache key.
      return type.__call__(cls, *args, **kwargs)

    cache = cls._instances
    cache_key = (cls, *args)
    try:
      instance = cache[cache_key]
    except KeyError:
      # First request for this (cls, *args) combination: build and remember.
      instance = cache.setdefault(cache_key, type.__call__(cls, *args))
    except TypeError:
      # Hashing the key failed (an argument is unhashable): skip the cache.
      instance = type.__call__(cls, *args)
    return instance
74
75
class ArgumentParser(Generic[_T], metaclass=_ArgumentParserCache):
  """Base class used to parse and convert arguments.

  Subclasses implement :meth:`parse`, which validates a string argument and
  converts it to a native value. When the value is not acceptable, ``parse``
  is expected to raise ``ValueError`` with a human-readable explanation.

  Subclasses should also set ``syntactic_help``, a short description of the
  form legal values take, suitable for showing to the user.

  Instances are cached and shared between flags by the metaclass, so parser
  classes must be stateless: initializer arguments are fine, but every member
  variable must be derived solely from them.
  """

  syntactic_help: str = ''

  def parse(self, argument: str) -> Optional[_T]:
    """Converts the string argument into its native value.

    The base implementation accepts any string and returns it unchanged.

    Args:
      argument: string argument passed in the commandline.

    Raises:
      ValueError: Raised when it fails to parse the argument.
      TypeError: Raised when the argument is not a string.

    Returns:
      The parsed value in native type.
    """
    if isinstance(argument, str):
      return argument  # type: ignore[return-value]
    raise TypeError('flag value must be a string, found "{}"'.format(
        type(argument)))

  def flag_type(self) -> str:
    """Returns a string representing the type of the flag."""
    return 'string'

  def _custom_xml_dom_elements(
      self, doc: minidom.Document
  ) -> List[minidom.Element]:
    """Returns extra minidom.Element nodes describing this parser.

    The base class contributes nothing; subclasses override this to add
    parser-specific details.

    Args:
      doc: minidom.Document, the DOM document nodes should be created from.
    """
    del doc  # Unused.
    return []
128
129
class ArgumentSerializer(Generic[_T]):
  """Base class for generating string representations of a flag value."""

  def serialize(self, value: _T) -> str:
    """Returns ``str(value)``; subclasses override for other formats."""
    text = str(value)
    return text
136
137
class NumericParser(ArgumentParser[_N]):
  """Parser of numeric values.

  Parsed value may be bounded to a given upper and lower bound.
  """

  lower_bound: Optional[_N]
  upper_bound: Optional[_N]

  def is_outside_bounds(self, val: _N) -> bool:
    """Returns whether val is below lower_bound or above upper_bound.

    A bound set to None is not enforced; both bounds are inclusive. Note that a
    NaN compares false against every bound, so it is never reported as outside.
    """
    below = self.lower_bound is not None and val < self.lower_bound
    if below:
      return below
    return self.upper_bound is not None and val > self.upper_bound

  def parse(self, argument: Union[str, _N]) -> _N:
    """See base class."""
    value = self.convert(argument)
    if not self.is_outside_bounds(value):
      return value
    raise ValueError('%s is not %s' % (value, self.syntactic_help))

  def _custom_xml_dom_elements(
      self, doc: minidom.Document
  ) -> List[minidom.Element]:
    """Builds XML elements for whichever of the two bounds are set."""
    bounds = (('lower_bound', self.lower_bound),
              ('upper_bound', self.upper_bound))
    return [
        _helpers.create_xml_dom_element(doc, tag, bound)
        for tag, bound in bounds
        if bound is not None
    ]

  def convert(self, argument: Union[str, _N]) -> _N:
    """Returns the correct numeric value of argument.

    Subclasses must implement this method, raising TypeError when argument is
    neither a string nor of the expected numeric type.

    Args:
      argument: string argument passed in the commandline, or the numeric type.

    Raises:
      TypeError: Raised when argument is not a string or the right numeric type.
      ValueError: Raised when failed to convert argument to the numeric value.
    """
    raise NotImplementedError
185
186
class FloatParser(NumericParser[float]):
  """Parser of floating point values.

  Parsed value may be bounded to a given upper and lower bound.
  """
  number_article = 'a'
  number_name = 'number'
  syntactic_help = ' '.join((number_article, number_name))

  def __init__(
      self,
      lower_bound: Optional[float] = None,
      upper_bound: Optional[float] = None,
  ) -> None:
    super().__init__()
    self.lower_bound = lower_bound
    self.upper_bound = upper_bound
    self.syntactic_help = self._bounds_help(lower_bound, upper_bound)

  def _bounds_help(
      self, lower_bound: Optional[float], upper_bound: Optional[float]
  ) -> str:
    """Builds the syntactic help text that describes the given bounds."""
    # Reads the class-level default, since the instance value is not set yet.
    default_help = self.syntactic_help
    if lower_bound is not None and upper_bound is not None:
      return '%s in the range [%s, %s]' % (
          default_help, lower_bound, upper_bound)
    if lower_bound == 0:
      return 'a non-negative %s' % self.number_name
    if upper_bound == 0:
      return 'a non-positive %s' % self.number_name
    if upper_bound is not None:
      return '%s <= %s' % (self.number_name, upper_bound)
    if lower_bound is not None:
      return '%s >= %s' % (self.number_name, lower_bound)
    return default_help

  def convert(self, argument: Union[int, float, str]) -> float:
    """Returns the float value of argument.

    Accepts ints (but not bools), floats and strings.
    """
    is_plain_int = isinstance(argument, int) and not isinstance(argument, bool)
    if is_plain_int or isinstance(argument, (float, str)):
      return float(argument)
    raise TypeError(
        'Expect argument to be a string, int, or float, found {}'.format(
            type(argument)))

  def flag_type(self) -> str:
    """See base class."""
    return 'float'
233
234
class IntegerParser(NumericParser[int]):
  """Parser of an integer value.

  Parsed value may be bounded to a given upper and lower bound. String
  arguments are decimal unless they carry a ``0x`` (hex), ``0o`` (octal) or
  ``0b`` (binary) prefix, in either letter case, optionally preceded by a
  ``+``/``-`` sign.
  """
  number_article = 'an'
  number_name = 'integer'
  syntactic_help = ' '.join((number_article, number_name))

  # Maps the lowercased character following a leading '0' to the radix.
  _PREFIX_BASES = {'o': 8, 'x': 16, 'b': 2}

  def __init__(
      self, lower_bound: Optional[int] = None, upper_bound: Optional[int] = None
  ) -> None:
    super().__init__()
    self.lower_bound = lower_bound
    self.upper_bound = upper_bound
    sh = self.syntactic_help
    if lower_bound is not None and upper_bound is not None:
      sh = ('%s in the range [%s, %s]' % (sh, lower_bound, upper_bound))
    elif lower_bound == 1:
      sh = 'a positive %s' % self.number_name
    elif upper_bound == -1:
      sh = 'a negative %s' % self.number_name
    elif lower_bound == 0:
      sh = 'a non-negative %s' % self.number_name
    elif upper_bound == 0:
      sh = 'a non-positive %s' % self.number_name
    elif upper_bound is not None:
      sh = '%s <= %s' % (self.number_name, upper_bound)
    elif lower_bound is not None:
      sh = '%s >= %s' % (self.number_name, lower_bound)
    self.syntactic_help = sh

  def convert(self, argument: Union[int, str]) -> int:
    """Returns the int value of argument.

    Args:
      argument: an int (bools are rejected) or a string. A string is parsed
        in base 10 unless, after optional surrounding whitespace and an
        optional sign, it starts with a ``0x``, ``0o`` or ``0b`` prefix
        (case-insensitive), which selects base 16, 8 or 2 respectively.

    Raises:
      TypeError: When argument is neither a string nor an int, or is a bool.
      ValueError: When the string is not a valid integer in the detected base.
    """
    if isinstance(argument, int) and not isinstance(argument, bool):
      return argument
    if not isinstance(argument, str):
      raise TypeError('Expect argument to be a string or int, found {}'.format(
          type(argument)))
    digits = argument.strip()
    if digits[:1] in ('+', '-'):
      digits = digits[1:]
    base = 10
    # Require at least one digit after the prefix; '0x' alone stays base 10
    # and is rejected by int(), as before.
    if len(digits) > 2 and digits[0] == '0':
      base = self._PREFIX_BASES.get(digits[1].lower(), 10)
    # For bases 2, 8 and 16 int() itself accepts the matching prefix as well
    # as the sign and surrounding whitespace, so pass the original string.
    return int(argument, base)

  def flag_type(self) -> str:
    """See base class."""
    return 'int'
286
287
class BooleanParser(ArgumentParser[bool]):
  """Parser of boolean values.

  Accepts the strings 'true'/'t'/'1' and 'false'/'f'/'0' in any letter case,
  bools, and the ints 0 and 1.
  """

  def parse(self, argument: Union[str, int]) -> bool:
    """See base class."""
    if isinstance(argument, str):
      lowered = argument.lower()
      if lowered in ('true', 't', '1'):
        return True
      if lowered in ('false', 'f', '0'):
        return False
      raise ValueError('Non-boolean argument to boolean flag', argument)

    if isinstance(argument, int):
      # bool is an int subclass. Among other ints, only 0 and 1 compare equal
      # to their own truth value. Floats are not int instances, so they fall
      # through to the TypeError below.
      as_bool = bool(argument)
      if argument == as_bool:
        return as_bool
      raise ValueError('Non-boolean argument to boolean flag', argument)

    raise TypeError('Non-boolean argument to boolean flag', argument)

  def flag_type(self) -> str:
    """See base class."""
    return 'bool'
314
315
class EnumParser(ArgumentParser[str]):
  """Parser of a string enum value (a string value from a given set)."""

  def __init__(
      self, enum_values: Iterable[str], case_sensitive: bool = True
  ) -> None:
    """Initializes EnumParser.

    Args:
      enum_values: [str], a non-empty iterable of string values in the enum.
        It is materialized into a list, so one-shot iterators are fine.
      case_sensitive: bool, whether or not the enum is to be case-sensitive.

    Raises:
      ValueError: When enum_values is empty, or is a single str.
    """
    if not enum_values:
      raise ValueError(f'enum_values cannot be empty, found "{enum_values}"')
    if isinstance(enum_values, str):
      raise ValueError(f'enum_values cannot be a str, found "{enum_values}"')
    values = list(enum_values)
    if not values:
      # Iterators and generators are truthy even when they yield nothing, so
      # the check above cannot see that they are empty; re-check here.
      raise ValueError(f'enum_values cannot be empty, found "{enum_values}"')
    super().__init__()
    self.enum_values = values
    self.case_sensitive = case_sensitive

  def parse(self, argument: str) -> str:
    """Determines validity of argument and returns the correct element of enum.

    Args:
      argument: str, the supplied flag value.

    Returns:
      The first matching element from enum_values. When case-insensitive, the
      element is returned with its original spelling from enum_values.

    Raises:
      ValueError: Raised when argument didn't match anything in enum.
    """
    if self.case_sensitive:
      if argument in self.enum_values:
        return argument
    else:
      target = argument.upper()
      for value in self.enum_values:
        if value.upper() == target:
          return value
    raise ValueError('value should be one of <%s>' %
                     '|'.join(self.enum_values))

  def flag_type(self) -> str:
    """See base class."""
    return 'string enum'
368
369
class EnumClassParser(ArgumentParser[_ET]):
  """Parser of an Enum class member."""

  def __init__(
      self, enum_class: Type[_ET], case_sensitive: bool = True
  ) -> None:
    """Initializes EnumClassParser.

    Args:
      enum_class: class, the Enum class with all possible flag values.
      case_sensitive: bool, whether or not the enum is to be case-sensitive. If
        False, all member names must be unique when case is ignored.

    Raises:
      TypeError: When enum_class is not a subclass of Enum.
      ValueError: When enum_class has no members, or when case_sensitive is
        False and two member names differ only by letter case.
    """
    if not issubclass(enum_class, enum.Enum):
      raise TypeError(f'{enum_class} is not a subclass of Enum.')
    names = tuple(enum_class.__members__)
    if not names:
      raise ValueError('enum_class cannot be empty, but "{}" is empty.'
                       .format(enum_class))
    lowered = tuple(name.lower() for name in names)
    if not case_sensitive:
      name_counts = collections.Counter(lowered)
      clashes = {name for name, seen in name_counts.items() if seen > 1}
      if clashes:
        raise ValueError(
            'Duplicate enum values for {} using case_sensitive=False'.format(
                clashes))

    super().__init__()
    self.enum_class = enum_class
    self._case_sensitive = case_sensitive
    self._member_names = names if case_sensitive else lowered

  @property
  def member_names(self) -> Sequence[str]:
    """The accepted enum names, in lowercase if not case sensitive."""
    return self._member_names

  def parse(self, argument: Union[_ET, str]) -> _ET:
    """Determines validity of argument and returns the correct element of enum.

    Args:
      argument: str or Enum class member, the supplied flag value.

    Returns:
      The matching member of the Enum class.

    Raises:
      ValueError: Raised when argument is neither a member nor a string, or
        when the string doesn't name any member.
    """
    if isinstance(argument, self.enum_class):
      return argument  # pytype: disable=bad-return-type
    if not isinstance(argument, str):
      raise ValueError(
          '{} is not an enum member or a name of a member in {}'.format(
              argument, self.enum_class))

    # Delegate validation and the error message to a string EnumParser.
    name_parser = EnumParser(
        self._member_names, case_sensitive=self._case_sensitive)
    key = name_parser.parse(argument)
    if self._case_sensitive:
      return self.enum_class[key]

    # key comes from the lowercased names, so a member is guaranteed to match.
    wanted = key.lower()
    return next(member
                for name, member in self.enum_class.__members__.items()
                if name.lower() == wanted)

  def flag_type(self) -> str:
    """See base class."""
    return 'enum class'
448
449
class ListSerializer(Generic[_T], ArgumentSerializer[List[_T]]):
  """Serializes a list by joining ``str()`` of each element with list_sep."""

  def __init__(self, list_sep: str) -> None:
    self.list_sep = list_sep

  def serialize(self, value: List[_T]) -> str:
    """See base class."""
    return self.list_sep.join(map(str, value))
458
459
class EnumClassListSerializer(ListSerializer[_ET]):
  """A serializer for :class:`MultiEnumClass` flags.

  Each element is serialized by an `EnumClassSerializer` and the results are
  joined with the provided separator.
  """

  _element_serializer: 'EnumClassSerializer'

  def __init__(self, list_sep: str, **kwargs) -> None:
    """Initializes EnumClassListSerializer.

    Args:
      list_sep: String to be used as a separator when serializing
      **kwargs: Keyword arguments to the `EnumClassSerializer` used to serialize
        individual values.
    """
    super().__init__(list_sep)
    self._element_serializer = EnumClassSerializer(**kwargs)

  def serialize(self, value: Union[_ET, List[_ET]]) -> str:
    """See base class. A non-list value is serialized as a single member."""
    element_serializer = self._element_serializer
    if not isinstance(value, list):
      return element_serializer.serialize(value)
    parts = [element_serializer.serialize(member) for member in value]
    return self.list_sep.join(parts)
487
488
class CsvListSerializer(ListSerializer[str]):
  """Serializes a list as one CSV row, using list_sep as the delimiter."""

  def serialize(self, value: List[str]) -> str:
    """Serializes a list as a CSV string.

    Each element is converted with str() and written through csv.writer, so
    elements containing the delimiter or the quote character get quoted.
    Leading and trailing whitespace, including the row terminator that
    csv.writer appends, is stripped from the result.
    """
    row = [str(item) for item in value]
    buffer = io.StringIO()
    csv.writer(buffer, delimiter=self.list_sep).writerow(row)
    return buffer.getvalue().strip()
501
502
class EnumClassSerializer(ArgumentSerializer[_ET]):
  """Serializes an enum class flag value as its member name."""

  def __init__(self, lowercase: bool) -> None:
    """Initializes EnumClassSerializer.

    Args:
      lowercase: If True, enum member names are lowercased during serialization.
    """
    self._lowercase = lowercase

  def serialize(self, value: _ET) -> str:
    """Returns the member name of value, lowercased if so configured."""
    name = str(value.name)
    if self._lowercase:
      return name.lower()
    return name
518
519
class BaseListParser(ArgumentParser):
  """Base class for a parser of lists of strings.

  To extend, inherit from this class; from the subclass ``__init__``, call::

    super().__init__(token, name)

  where token is a character used to tokenize, and name is a description
  of the separator.
  """

  def __init__(
      self, token: Optional[str] = None, name: Optional[str] = None
  ) -> None:
    """Initializes BaseListParser.

    Args:
      token: separator passed to ``str.split``; None splits on runs of
        whitespace.
      name: non-empty human-readable description of the separator, used in
        the syntactic help and the flag type.

    Raises:
      ValueError: When name is empty or None.
    """
    # An explicit check rather than ``assert``, so that it still runs under
    # ``python -O``.
    if not name:
      raise ValueError(
          f'name must be a non-empty separator description, found {name!r}')
    super().__init__()
    self._token = token
    self._name = name
    self.syntactic_help = 'a %s separated list' % self._name

  def parse(self, argument: Union[str, List[str]]) -> List[str]:
    """Splits argument on the separator token and strips each item.

    A list argument is returned unchanged; an empty argument yields [].
    """
    if isinstance(argument, list):
      return argument
    elif not argument:
      return []
    else:
      return [s.strip() for s in argument.split(self._token)]

  def flag_type(self) -> str:
    """See base class."""
    return '%s separated list of strings' % self._name
552
553
class ListParser(BaseListParser):
  """Parser for a comma-separated list of strings."""

  def __init__(self) -> None:
    super().__init__(',', 'comma')

  def parse(self, argument: Union[str, List[str]]) -> List[str]:
    """Parses argument as comma-separated list of strings.

    The string is read as a single CSV row with the csv module's default
    dialect (so double-quoted fields may contain commas), and each field is
    stripped of surrounding whitespace. A list is returned unchanged and an
    empty argument yields [].
    """
    if isinstance(argument, list):
      return argument
    if not argument:
      return []
    try:
      rows = list(csv.reader([argument], strict=True))
    except csv.Error as e:
      # Turn csv failures (e.g. a bare newline inside the value, as produced by
      # --listflag="$(printf 'hello,\nworld')") into a readable ValueError
      # instead of letting csv.Error escape.
      raise ValueError('Unable to parse the value %r as a %s: %s'
                       % (argument, self.flag_type(), e))
    return [field.strip() for field in rows[0]]

  def _custom_xml_dom_elements(
      self, doc: minidom.Document
  ) -> List[minidom.Element]:
    """Adds a list_separator element for the comma to the base elements."""
    elements = super()._custom_xml_dom_elements(doc)
    separator = _helpers.create_xml_dom_element(doc, 'list_separator', repr(','))
    elements.append(separator)
    return elements
585
586
class WhitespaceSeparatedListParser(BaseListParser):
  """Parser for a whitespace-separated list of strings."""

  def __init__(self, comma_compat: bool = False) -> None:
    """Initializer.

    Args:
      comma_compat: bool, whether to support comma as an additional separator.
        If False then only whitespace is supported. This is intended only for
        backwards compatibility with flags that used to be comma-separated.
    """
    self._comma_compat = comma_compat
    if self._comma_compat:
      name = 'whitespace or comma'
    else:
      name = 'whitespace'
    super().__init__(None, name)

  def parse(self, argument: Union[str, List[str]]) -> List[str]:
    """Parses argument as whitespace-separated list of strings.

    When comma_compat is enabled, commas are treated as whitespace too.

    Args:
      argument: string argument passed in the commandline.

    Returns:
      [str], the parsed flag value.
    """
    if isinstance(argument, list):
      return argument
    if not argument:
      return []
    text = argument.replace(',', ' ') if self._comma_compat else argument
    return text.split()

  def _custom_xml_dom_elements(
      self, doc: minidom.Document
  ) -> List[minidom.Element]:
    """Adds one list_separator element per accepted separator character."""
    elements = super()._custom_xml_dom_elements(doc)
    separators = list(string.whitespace)
    # Only the comma-compatible variant is sorted; the plain whitespace list
    # keeps the order of string.whitespace.
    if self._comma_compat:
      separators = sorted(separators + [','])
    elements.extend(
        _helpers.create_xml_dom_element(doc, 'list_separator', repr(sep))
        for sep in separators)
    return elements