1import sqlalchemy as sa
2
3
class ProxyDict:
    """Dictionary-style facade over a parent's query-like child collection.

    Children are addressed by the value of one of their attributes
    (``mapping_attr``).  Lookups run ``filter_by`` on the collection, so a
    single child can be fetched without loading the whole collection
    (presumably a SQLAlchemy ``lazy="dynamic"`` relationship -- confirm).

    Results are memoised in ``self.cache``: a key maps to the child object,
    or to ``None`` when the key was looked up and not found.

    Reading a missing key does NOT raise ``KeyError``.  Instead a new child
    with that key is created and appended to the collection.
    """

    def __init__(self, parent, collection_name, mapping_attr):
        """
        :param parent: object that owns the collection.
        :param collection_name: name of the collection attribute on ``parent``.
        :param mapping_attr: attribute used as the dictionary key.  Its
            ``class_`` is taken as the child class and its ``key`` as the
            name of the key attribute on that class.
        """
        self.parent = parent
        self.collection_name = collection_name
        self.child_class = mapping_attr.class_
        self.key_name = mapping_attr.key
        # key -> child object, or None for a key known to be absent.
        self.cache = {}

    @property
    def collection(self):
        """The parent's collection, re-read from the parent on every access."""
        return getattr(self.parent, self.collection_name)

    def keys(self):
        """Return the key values of all children, as selected from the collection.

        Only the key column is selected; the cache is neither read nor updated.
        """
        descriptor = getattr(self.child_class, self.key_name)
        # Query.values() is deprecated since SQLAlchemy 1.4; with_entities()
        # is the equivalent way to select just the key column.
        return [row[0] for row in self.collection.with_entities(descriptor)]

    def __contains__(self, key):
        """Return True if a child exists for *key* (cached or fetched)."""
        return self._lookup(key) is not None

    def has_key(self, key):
        """Legacy alias for ``key in self``."""
        return self.__contains__(key)

    def _lookup(self, key):
        """Return the child for *key* (or None) without ever creating one."""
        if key in self.cache:
            return self.cache[key]
        return self.fetch(key)

    def fetch(self, key):
        """Query the collection for the child whose key attribute equals *key*.

        The result (possibly ``None``) is stored in ``self.cache``.  If the
        parent has no session, or has no identity (has not been persisted),
        nothing is queried or cached and ``None`` is returned.
        """
        session = sa.orm.object_session(self.parent)
        if session is None or not sa.orm.util.has_identity(self.parent):
            return None
        obj = self.collection.filter_by(**{self.key_name: key}).first()
        self.cache[key] = obj
        return obj

    def create_new_instance(self, key):
        """Create a child with its key attribute set to *key*.

        The child is appended to the collection, cached, and returned.
        """
        value = self.child_class(**{self.key_name: key})
        self.collection.append(value)
        self.cache[key] = value
        return value

    def __getitem__(self, key):
        """Return the child for *key*, creating and appending it if absent."""
        value = self._lookup(key)
        # Compare against None: an ORM object may define __bool__/__len__.
        if value is not None:
            return value
        return self.create_new_instance(key)

    def __setitem__(self, key, value):
        """Store *value* under *key*, replacing any existing child for *key*.

        The previous child (if any) is removed from the collection, then
        *value* is appended.  The key attribute of *value* is not modified,
        so it should already match *key*.
        """
        # Look up without auto-creating.  Going through self[key] would
        # create, append and immediately remove a throwaway child.
        existing = self._lookup(key)
        if existing is not None:
            self.collection.remove(existing)
        self.collection.append(value)
        self.cache[key] = value
60
61
def proxy_dict(parent, collection_name, mapping_attr):
    """Return the ProxyDict for ``parent.<collection_name>``, creating it once.

    ProxyDicts are memoised in a ``_proxy_dicts`` dict stored on *parent*
    itself, keyed by collection name.  Repeated calls therefore return the
    same ProxyDict, which shares its lookup cache.  ``mapping_attr`` is only
    used when that ProxyDict is first created.
    """
    if not hasattr(parent, '_proxy_dicts'):
        parent._proxy_dicts = {}
    registry = parent._proxy_dicts
    if collection_name not in registry:
        registry[collection_name] = ProxyDict(
            parent, collection_name, mapping_attr
        )
    return registry[collection_name]
75
76
def expire_proxy_dicts(target, context):
    """Event listener that forgets every ProxyDict memoised on *target*.

    Rebinds ``target._proxy_dicts`` to a fresh empty dict, so the next
    ``proxy_dict()`` call builds a new ProxyDict with an empty lookup cache.
    Objects that never had a ``_proxy_dicts`` attribute are left untouched.

    ``context`` is unused.  NOTE(review): SQLAlchemy's ``expire`` instance
    event passes the expired attribute names in this position -- confirm
    against the SQLAlchemy version in use.
    """
    try:
        target._proxy_dicts
    except AttributeError:
        return
    target._proxy_dicts = {}
80
81
# Register expire_proxy_dicts for the 'expire' event on every mapper, so
# that memoised ProxyDicts (and their lookup caches) are discarded whenever
# a mapped instance is expired.
sa.event.listen(
    sa.orm.Mapper,
    'expire',
    expire_proxy_dicts,
)