1import logging
2import re
3
4from collections import OrderedDict
5from inspect import isclass
6
7from .errors import RestError
8
log = logging.getLogger(__name__)

# Tokenizer for mask strings: either one structural character ("{", "}" or ",")
# or a run of field-name characters (word characters, ":", "*" and "-").
# Characters matching neither alternative are never returned, so they are
# effectively skipped by ``findall``.
LEXER = re.compile(r"[{},]|[\w:*-]+")
12
13
class MaskError(RestError):
    """Raised when a mask cannot be applied; base class of mask parsing errors."""
18
19
class ParseError(MaskError):
    """Raised when a mask string is malformed (misplaced brackets or commas)."""
24
25
class Mask(OrderedDict):
    """
    Hold a parsed mask.

    A mask maps each selected field name either to ``True`` (keep the field
    value as-is) or to a nested :class:`Mask` (keep the field and filter its
    value recursively). The special key ``"*"`` additionally keeps every key
    of the data that the other entries did not already produce.

    :param str|dict|Mask mask: A mask, parsed or not
    :param bool skip: If ``True``, missing fields won't appear in result
    """

    def __init__(self, mask=None, skip=False, **kwargs):
        self.skip = skip
        if isinstance(mask, str):
            # NOTE: ``kwargs`` are not applied when ``mask`` is a string.
            super(Mask, self).__init__()
            self.parse(mask)
        elif isinstance(mask, dict):
            # OrderedDict and Mask are dict subclasses, so they are copied here too.
            super(Mask, self).__init__(mask, **kwargs)
        else:
            super(Mask, self).__init__(**kwargs)

    def parse(self, mask):
        """
        Parse a fields mask into this mask.

        Expect something in the form::

            {field,nested{nested_field,another},last}

        External brackets are optional so it can also be written::

            field,nested{nested_field,another},last

        All extra characters will be ignored.

        :param str mask: the mask string to parse
        :raises ParseError: when a mask is unparseable/invalid
        """
        if not mask:
            return

        mask = self.clean(mask)
        if not mask:
            # Nothing left once whitespace and optional outer brackets are removed.
            return

        fields = self  # mapping currently receiving fields
        previous = None  # last token seen
        stack = []  # enclosing mappings of ``fields``

        for token in LEXER.findall(mask):
            if token == "{":
                # An opening bracket must directly follow the field it nests under.
                if previous not in fields:
                    raise ParseError("Unexpected opening bracket")
                fields[previous] = Mask(skip=self.skip)
                stack.append(fields)
                fields = fields[previous]
            elif token == "}":
                if not stack:
                    raise ParseError("Unexpected closing bracket")
                fields = stack.pop()
            elif token == ",":
                if previous in (",", "{", None):
                    raise ParseError("Unexpected comma")
            else:
                fields[token] = True

            previous = token

        if stack:
            raise ParseError("Missing closing bracket")

    def clean(self, mask):
        """
        Remove unnecessary characters from a raw mask string.

        Newlines and surrounding whitespace are removed, then one pair of
        enclosing brackets is stripped (external brackets are optional).

        :param str mask: the raw mask string
        :return str: the cleaned mask, possibly empty
        :raises ParseError: when the mask starts with ``{`` but does not end with ``}``
        """
        mask = mask.replace("\n", "").strip()
        # ``startswith``/``endswith`` (rather than indexing) also cope with a
        # mask that is empty after stripping, e.g. whitespace-only input.
        if mask.startswith("{"):
            if not mask.endswith("}"):
                raise ParseError("Missing closing bracket")
            mask = mask[1:-1]
        return mask

    def apply(self, data):
        """
        Apply a fields mask to the data.

        :param data: The data or model to apply mask on
        :raises MaskError: when unable to apply the mask
        """
        from . import fields

        # Should handle lists
        if isinstance(data, (list, tuple, set)):
            return [self.apply(d) for d in data]
        if isinstance(data, (fields.Nested, fields.List, fields.Polymorph)):
            return data.clone(self)
        # Exact type (not isinstance): Raw subclasses must reach the error below.
        if type(data) is fields.Raw:
            return fields.Raw(default=data.default, attribute=data.attribute, mask=self)
        # Identity check: ``==`` would invoke the data's own ``__eq__``, whose
        # result may not be a plain bool.
        if data is fields.Raw:
            return fields.Raw(mask=self)
        if isinstance(data, fields.Raw) or (
            isclass(data) and issubclass(data, fields.Raw)
        ):
            # Not possible to apply a mask on these remaining fields types
            raise MaskError("Mask is inconsistent with model")
        # Should handle objects
        if not isinstance(data, dict) and hasattr(data, "__dict__"):
            data = data.__dict__

        return self.filter_data(data)

    def filter_data(self, data):
        """
        Handle the data filtering given a parsed mask.

        :param dict data: the raw data to filter
        :return dict: a new dict holding the masked fields. Absent fields map to
            ``None``. When ``self.skip`` is set, they are omitted instead, and so
            are nested-mask fields whose value is ``None``. A ``"*"`` entry
            copies every key of ``data`` not already in the result.
        """
        out = {}
        for field, content in self.items():
            if field == "*":
                continue
            if isinstance(content, Mask):
                nested = data.get(field, None)
                if nested is None:
                    if not self.skip:
                        out[field] = None
                else:
                    out[field] = content.apply(nested)
            elif self.skip and field not in data:
                continue
            else:
                out[field] = data.get(field, None)

        if "*" in self:
            for key, value in data.items():
                out.setdefault(key, value)
        return out

    def __str__(self):
        """Serialize back to mask syntax, e.g. ``{field,nested{sub}}``."""
        parts = [k + str(v) if isinstance(v, Mask) else k for k, v in self.items()]
        return "{" + ",".join(parts) + "}"
175
176
def apply(data, mask, skip=False):
    """
    Apply a fields mask to the data.

    :param data: The data or model to apply mask on
    :param str|Mask mask: the mask (parsed or not) to apply on data
    :param bool skip: If ``True``, missing fields won't appear in result
    :raises MaskError: when unable to apply the mask

    .. note:: A new top-level :class:`Mask` is always built with ``skip``, even
        when ``mask`` is already a :class:`Mask`; any nested masks it contains
        are reused as-is and keep their own ``skip`` setting.
    """
    return Mask(mask, skip=skip).apply(data)