1from pyrsistent._checked_types import CheckedType, _restore_pickle, InvariantException, store_invariants
2from pyrsistent._field_common import (
3 set_fields, check_type, is_field_ignore_extra_complaint, PFIELD_NO_INITIAL, serialize, check_global_invariants
4)
5from pyrsistent._pmap import PMap, pmap
6
7
class _PRecordMeta(type):
    """
    Metaclass for PRecord: gathers the declared fields and invariants of the class being
    created and derives the lookup tables the record machinery uses at runtime.
    """

    def __new__(mcs, name, bases, dct):
        # Populate dct['_precord_fields'] and dct['_precord_invariants'] from this class
        # body and its bases.
        set_fields(dct, bases, name='_precord_fields')
        store_invariants(dct, bases, '_precord_invariants', '__invariant__')

        declared = dct['_precord_fields']

        # Names of fields flagged as mandatory.
        dct['_precord_mandatory_fields'] = {
            field_name for field_name, pfield in declared.items() if pfield.mandatory
        }

        # Field name -> declared initial value, only for fields that declare one.
        dct['_precord_initial_values'] = {
            field_name: pfield.initial
            for field_name, pfield in declared.items()
            if pfield.initial is not PFIELD_NO_INITIAL
        }

        # Records keep no per-instance __dict__.
        dct['__slots__'] = ()

        return super().__new__(mcs, name, bases, dct)
23
24
class PRecord(PMap, CheckedType, metaclass=_PRecordMeta):
    """
    A PRecord is a PMap with a fixed set of specified fields. Records are declared as python classes inheriting
    from PRecord. Because it is a PMap it has full support for all Mapping methods such as iteration and element
    access using subscript notation.

    More documentation and examples of PRecord usage is available at https://github.com/tobgu/pyrsistent
    """
    def __new__(cls, **kwargs):
        """
        Build a record of this class from keyword arguments, one per declared field.

        Declared initial values are applied first (callable initial values are called to
        produce the value) and are then overridden by ``kwargs``. Every value is assigned
        through a ``_PRecordEvolver`` so that field factories, type checks and invariants run.

        Recognised private keyword arguments:

        * ``_factory_fields`` - forwarded to the evolver; when not None, only fields
          contained in it have their factory applied.
        * ``_ignore_extra`` - forwarded to the evolver, which may pass it on to field
          factories (see ``is_field_ignore_extra_complaint``).
        """
        # Hack total! If both of these special keys are present the caller (the evolver's
        # persistent()) already holds a complete map structure, so wrap it directly instead
        # of going through the evolver again.
        if '_precord_size' in kwargs and '_precord_buckets' in kwargs:
            return super(PRecord, cls).__new__(cls, kwargs['_precord_size'], kwargs['_precord_buckets'])

        factory_fields = kwargs.pop('_factory_fields', None)
        ignore_extra = kwargs.pop('_ignore_extra', False)

        initial_values = kwargs
        if cls._precord_initial_values:
            # Callable initial values are invoked once per record; explicit kwargs take
            # precedence over initial values.
            initial_values = dict((k, v() if callable(v) else v)
                                  for k, v in cls._precord_initial_values.items())
            initial_values.update(kwargs)

        e = _PRecordEvolver(cls, pmap(pre_size=len(cls._precord_fields)),
                            _factory_fields=factory_fields, _ignore_extra=ignore_extra)
        for k, v in initial_values.items():
            e[k] = v

        return e.persistent()

    def set(self, *args, **kwargs):
        """
        Set a field in the record. This set function differs slightly from that in the PMap
        class. First of all it accepts keyword arguments. Second it accepts multiple keyword
        arguments to perform one, atomic, update of multiple fields.

        Call either as ``record.set(key, value)`` or as ``record.set(field1=v1, field2=v2)``.

        :raises TypeError: if positional arguments are given that are not exactly one key
            and one value, or if positional and keyword arguments are combined.
        """
        # The PRecord set() can accept kwargs since all fields that have been declared are
        # valid python identifiers. Also allow multiple fields to be set in one operation.
        if args:
            if len(args) != 2:
                raise TypeError(
                    'set() takes a key and a value as positional arguments ({0} given)'.format(len(args)))
            if kwargs:
                raise TypeError('set() does not accept positional and keyword arguments at the same time')
            return super(PRecord, self).set(args[0], args[1])

        return self.update(kwargs)

    def evolver(self):
        """
        Returns an evolver of this object.
        """
        return _PRecordEvolver(self.__class__, self)

    def __repr__(self):
        # Render as ClassName(field=value, ...) in the record's iteration order.
        return "{0}({1})".format(self.__class__.__name__,
                                 ', '.join('{0}={1}'.format(k, repr(v)) for k, v in self.items()))

    @classmethod
    def create(cls, kwargs, _factory_fields=None, ignore_extra=False):
        """
        Factory method. Will create a new PRecord of the current type and assign the values
        specified in kwargs.

        :param kwargs: A mapping of field names to values. If it already is an instance of
            this class it is returned unchanged.
        :param _factory_fields: Private; forwarded unchanged to the constructor (see ``__new__``).
        :param ignore_extra: A boolean which when set to True will ignore any keys which appear in kwargs that are not
            in the set of fields on the PRecord.
        """
        if isinstance(kwargs, cls):
            return kwargs

        if ignore_extra:
            # Keep only the keys that are declared fields of this record.
            kwargs = {k: kwargs[k] for k in cls._precord_fields if k in kwargs}

        return cls(_factory_fields=_factory_fields, _ignore_extra=ignore_extra, **kwargs)

    def __reduce__(self):
        # Pickling support: rebuild through _restore_pickle from the class and a plain dict
        # copy of the record's items.
        return _restore_pickle, (self.__class__, dict(self),)

    def serialize(self, format=None):
        """
        Serialize the current PRecord using custom serializer functions for fields where
        such have been supplied.

        :param format: Passed through to each field's serializer.
        :returns: A plain dict mapping each field name to its serialized value.
        """
        return dict((k, serialize(self._precord_fields[k].serializer, format, v)) for k, v in self.items())
106
107
class _PRecordEvolver(PMap._Evolver):
    """
    Evolver that produces instances of a PRecord subclass.

    Every value set through it is validated against the field declarations of the
    destination record class: the field factory is applied, the type is checked and the
    field invariant is evaluated. Invariant failures and missing fields are collected and
    reported together as one InvariantException when persistent() is called.
    """
    __slots__ = ('_destination_cls', '_invariant_error_codes', '_missing_fields', '_factory_fields', '_ignore_extra')

    def __init__(self, cls, original_pmap, _factory_fields=None, _ignore_extra=False):
        super(_PRecordEvolver, self).__init__(original_pmap)
        # Record class that persistent() will build.
        self._destination_cls = cls
        # Errors collected by set(); raised together by persistent().
        self._invariant_error_codes = []
        self._missing_fields = []
        # When not None, only fields contained here have their factory applied.
        self._factory_fields = _factory_fields
        self._ignore_extra = _ignore_extra

    def __setitem__(self, key, original_value):
        self.set(key, original_value)

    def set(self, key, original_value):
        """
        Validate ``original_value`` for field ``key`` and stage it in the evolver.

        The field factory is applied (unless ``_factory_fields`` excludes the field), then the
        type is checked with ``check_type`` and the field invariant is evaluated. An
        InvariantException raised by the factory, or a failing field invariant, is recorded
        and reported later by persistent() instead of being raised here.

        NOTE: recorded errors are kept for the lifetime of this evolver, even if ``key`` is
        later set again with a valid value.

        :raises AttributeError: if ``key`` is not a declared field of the destination class.
        :returns: this evolver when the factory failed, otherwise the result of the base
            evolver's set().
        """
        field = self._destination_cls._precord_fields.get(key)
        if field is None:
            raise AttributeError("'{0}' is not among the specified fields for {1}".format(key, self._destination_cls.__name__))

        if self._factory_fields is None or field in self._factory_fields:
            try:
                if is_field_ignore_extra_complaint(PRecord, field, self._ignore_extra):
                    value = field.factory(original_value, ignore_extra=self._ignore_extra)
                else:
                    value = field.factory(original_value)
            except InvariantException as e:
                # Nested record/collection failed to build: remember why, skip this key.
                self._invariant_error_codes += e.invariant_errors
                self._missing_fields += e.missing_fields
                return self
        else:
            value = original_value

        check_type(self._destination_cls, field, key, value)

        is_ok, error_code = field.invariant(value)
        if not is_ok:
            # The value is still stored; the failure surfaces in persistent().
            self._invariant_error_codes.append(error_code)

        return super(_PRecordEvolver, self).set(key, value)

    def persistent(self):
        """
        Build the destination record from the staged values.

        :raises InvariantException: if set() recorded invariant errors or missing fields, or
            if a mandatory field of the destination class has no value. Global invariants are
            then checked through check_global_invariants.
        :returns: an instance of the destination record class.
        """
        cls = self._destination_cls
        is_dirty = self.is_dirty()
        pm = super(_PRecordEvolver, self).persistent()
        # Re-wrap the base map in the record class unless it already is an unchanged
        # instance of it.
        if is_dirty or not isinstance(pm, cls):
            result = cls(_precord_buckets=pm._buckets, _precord_size=pm._size)
        else:
            result = pm

        # Work on a local copy so repeated calls to persistent() neither duplicate the
        # mandatory-field entries nor keep reporting fields that have since been set.
        missing_fields = list(self._missing_fields)
        if cls._precord_mandatory_fields:
            missing_fields.extend('{0}.{1}'.format(cls.__name__, f) for f
                                  in (cls._precord_mandatory_fields - set(result.keys())))

        if self._invariant_error_codes or missing_fields:
            raise InvariantException(tuple(self._invariant_error_codes), tuple(missing_fields),
                                     'Field invariant failed')

        check_global_invariants(result, cls._precord_invariants)

        return result