# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

import codecs
import gzip
import json
import os
import sys
import tempfile
import unicodedata
from functools import lru_cache
from typing import Dict, Tuple

from hypothesis.configuration import storage_directory
from hypothesis.errors import InvalidArgument
from hypothesis.internal.intervalsets import IntervalSet

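# Type aliases for the category/interval caches defined below.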
intervals = Tuple[Tuple[int, int], ...]
cache_type = Dict[Tuple[Tuple[str, ...], int, int, intervals, intervals], IntervalSet]


def charmap_file(fname="charmap"):
    return storage_directory(
        "unicode_data", unicodedata.unidata_version, f"{fname}.json.gz"
    )


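# Module-level cache for charmap(), populated lazily on first use.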
_charmap = None


def charmap():
    """Return a dict that maps a Unicode category to a tuple of 2-tuples
    covering the codepoint intervals for characters in that category.

    >>> charmap()['Co']
    ((57344, 63743), (983040, 1048573), (1048576, 1114109))
    """
    global _charmap
    # Best-effort caching in the face of missing files and/or unwritable
    # filesystems is fairly simple: check if loaded, else try loading,
    # else calculate and try writing the cache.
    if _charmap is None:
        f = charmap_file()
        try:
            with gzip.GzipFile(f, "rb") as i:
                tmp_charmap = dict(json.load(i))

        except Exception:
            # This loop is reduced to using only local variables for performance;
            # indexing and updating containers is a ~3x slowdown. This doesn't fix
            # https://github.com/HypothesisWorks/hypothesis/issues/2108 but it helps.
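            # Walk every codepoint and record a [start, end] run each time the
            # Unicode category changes - a simple run-length encoding of the table.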
            category = unicodedata.category  # Local variable -> ~20% speedup!
            tmp_charmap = {}
            last_cat = category(chr(0))
            last_start = 0
            for i in range(1, sys.maxunicode + 1):
                cat = category(chr(i))
                if cat != last_cat:
                    tmp_charmap.setdefault(last_cat, []).append([last_start, i - 1])
                    last_cat, last_start = cat, i
            tmp_charmap.setdefault(last_cat, []).append([last_start, sys.maxunicode])

            try:
                # Write the Unicode table atomically
                tmpdir = storage_directory("tmp")
                tmpdir.mkdir(exist_ok=True, parents=True)
                fd, tmpfile = tempfile.mkstemp(dir=tmpdir)
                os.close(fd)
                # Explicitly set the mtime to get reproducible output
                with gzip.GzipFile(tmpfile, "wb", mtime=1) as o:
                    result = json.dumps(sorted(tmp_charmap.items()))
                    o.write(result.encode())

                os.renames(tmpfile, f)
            except Exception:
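                # Writing the cache is best-effort only (the directory may be
                # missing or unwritable); we still have the table in memory.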
                pass

        # convert between lists and tuples
        _charmap = {
            k: tuple(tuple(pair) for pair in pairs) for k, pairs in tmp_charmap.items()
        }
        # each value is a tuple of 2-tuples (that is, tuples of length 2),
        # and both elements of each 2-tuple are integers.
        for vs in _charmap.values():
            ints = list(sum(vs, ()))
            assert all(isinstance(x, int) for x in ints)
            assert ints == sorted(ints)
            assert all(len(tup) == 2 for tup in vs)

    assert _charmap is not None
    return _charmap


@lru_cache(maxsize=None)
def intervals_from_codec(codec_name: str) -> IntervalSet:  # pragma: no cover
    """Return an IntervalSet of characters which are part of this codec."""
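    # Callers are expected to pass the canonical codec name, i.e. the value of
    # codecs.lookup(name).name, which the assert below checks.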
    assert codec_name == codecs.lookup(codec_name).name
    fname = charmap_file(f"codec-{codec_name}")
    try:
        with gzip.GzipFile(fname) as gzf:
            encodable_intervals = json.load(gzf)

    except Exception:
        # This loop is kinda slow, but hopefully we don't need to do it very often!
        encodable_intervals = []
        for i in range(sys.maxunicode + 1):
            try:
                chr(i).encode(codec_name)
            except Exception:  # usually _but not always_ UnicodeEncodeError
                pass
            else:
                encodable_intervals.append((i, i))

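    # Taking the union of the set with itself normalises it, coalescing the
    # per-codepoint (i, i) intervals collected above into contiguous runs.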
    res = IntervalSet(encodable_intervals)
    res = res.union(res)
    try:
        # Write the Unicode table atomically
        tmpdir = storage_directory("tmp")
        tmpdir.mkdir(exist_ok=True, parents=True)
        fd, tmpfile = tempfile.mkstemp(dir=tmpdir)
        os.close(fd)
        # Explicitly set the mtime to get reproducible output
        with gzip.GzipFile(tmpfile, "wb", mtime=1) as o:
            o.write(json.dumps(res.intervals).encode())
        os.renames(tmpfile, fname)
    except Exception:
        pass
    return res


_categories = None


def categories():
    """Return a tuple of Unicode categories in a normalised order.

    >>> categories() # doctest: +ELLIPSIS
    ('Zl', 'Zp', 'Co', 'Me', 'Pc', ..., 'Cc', 'Cs')
    """
    global _categories
    if _categories is None:
        cm = charmap()
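        # Categories spanning fewer intervals sort first; Cc and Cs are then
        # moved to the very end below, so control and surrogate characters
        # always come last.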
        _categories = sorted(cm.keys(), key=lambda c: len(cm[c]))
        _categories.remove("Cc")  # Other, Control
        _categories.remove("Cs")  # Other, Surrogate
        _categories.append("Cc")
        _categories.append("Cs")
    return tuple(_categories)


def as_general_categories(cats, name="cats"):
    """Return a tuple of Unicode categories in a normalised order.

    This function expands one-letter designations of a major class to include
    all subclasses:

    >>> as_general_categories(['N'])
    ('Nd', 'Nl', 'No')

    See section 4.5 of the Unicode standard for more on classes:
    https://www.unicode.org/versions/Unicode10.0.0/ch04.pdf

    If the collection ``cats`` includes any elements that do not represent a
    major class or a class with subclasses, InvalidArgument is raised.
    """
    if cats is None:
        return None
    major_classes = ("L", "M", "N", "P", "S", "Z", "C")
    cs = categories()
    out = set(cats)
    for c in cats:
        if c in major_classes:
            out.discard(c)
            out.update(x for x in cs if x.startswith(c))
        elif c not in cs:
            raise InvalidArgument(
                f"In {name}={cats!r}, {c!r} is not a valid Unicode category."
            )
    return tuple(c for c in cs if c in out)


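# Maps each normalised tuple of categories to the tuple of codepoint intervals
# covering them, seeded with the empty query.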
category_index_cache = {(): ()}


def _category_key(cats):
    """Return a normalised tuple of all Unicode categories that are in
    `cats`.

    If `cats` is None, default to including all categories.
    Any item in `cats` that is not a Unicode category is dropped.

    >>> _category_key(['Lu', 'Me', 'Cs'])
    ('Me', 'Lu', 'Cs')
    """
    cs = categories()
    if cats is None:
        cats = set(cs)
    return tuple(c for c in cs if c in cats)


def _query_for_key(key):
    """Return a tuple of codepoint intervals covering characters that match one
    or more categories in the tuple of categories `key`.

    >>> _query_for_key(categories())
    ((0, 1114111),)
    >>> _query_for_key(('Zl', 'Zp', 'Co'))
    ((8232, 8233), (57344, 63743), (983040, 1048573), (1048576, 1114109))
    """
    try:
        return category_index_cache[key]
    except KeyError:
        pass
    assert key
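    # Build the union one category at a time, so every prefix of `key` gets
    # its own cache entry and later queries sharing a prefix reuse the work.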
    if set(key) == set(categories()):
        result = IntervalSet([(0, sys.maxunicode)])
    else:
        result = IntervalSet(_query_for_key(key[:-1])).union(
            IntervalSet(charmap()[key[-1]])
        )
    assert isinstance(result, IntervalSet)
    category_index_cache[key] = result.intervals
    return result.intervals


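# Cache for query(): keyed by the normalised categories, the codepoint bounds,
# and the include/exclude character intervals.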
limited_category_index_cache: cache_type = {}


def query(
    *,
    categories=None,
    min_codepoint=None,
    max_codepoint=None,
    include_characters="",
    exclude_characters="",
):
    """Return a tuple of intervals covering the codepoints for all characters
    that meet the criteria.

    >>> query()
    ((0, 1114111),)
    >>> query(min_codepoint=0, max_codepoint=128)
    ((0, 128),)
    >>> query(min_codepoint=0, max_codepoint=128, categories=['Lu'])
    ((65, 90),)
    >>> query(min_codepoint=0, max_codepoint=128, categories=['Lu'],
    ...       include_characters='☃')
    ((65, 90), (9731, 9731))
    """
    if min_codepoint is None:
        min_codepoint = 0
    if max_codepoint is None:
        max_codepoint = sys.maxunicode
    catkey = _category_key(categories)
    character_intervals = IntervalSet.from_string(include_characters or "")
    exclude_intervals = IntervalSet.from_string(exclude_characters or "")
    qkey = (
        catkey,
        min_codepoint,
        max_codepoint,
        character_intervals.intervals,
        exclude_intervals.intervals,
    )
    try:
        return limited_category_index_cache[qkey]
    except KeyError:
        pass
    base = _query_for_key(catkey)
    result = []
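    # Clamp each interval to [min_codepoint, max_codepoint], dropping any that
    # fall entirely outside the requested range.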
    for u, v in base:
        if v >= min_codepoint and u <= max_codepoint:
            result.append((max(u, min_codepoint), min(v, max_codepoint)))
    result = (IntervalSet(result) | character_intervals) - exclude_intervals
    limited_category_index_cache[qkey] = result
    return result