import re
from collections.abc import Iterable

import sqlalchemy as sa
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import attributes, class_mapper, ColumnProperty
from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
from sqlalchemy.orm.session import _state_session
from sqlalchemy.util import set_creation_order

from .exceptions import ImproperlyConfigured
from .functions import identity
from .functions.orm import _get_class_registry
13
14
class GenericAttributeImpl(attributes.ScalarAttributeImpl):
    """Attribute implementation used by ``GenericRelationshipProperty``.

    Reading the attribute resolves the related object from the owning
    instance's discriminator (the target's class name) and id attribute(s).
    Writing it stores the object and copies its class name and primary key
    into those attributes.
    """

    @staticmethod
    def _parse_version(version):
        """Return the numeric release segment of ``version`` as an int tuple.

        Pre-release, post-release and dev suffixes are ignored. For example,
        ``'2.1.0b1'`` gives ``(2, 1, 0)`` and ``'2.0.23.dev0'`` gives
        ``(2, 0, 23)``. An empty tuple is returned if ``version`` does not
        start with a digit.
        """
        match = re.match(r'\d+(?:\.\d+)*', version)
        if match is None:
            return ()
        return tuple(int(part) for part in match.group().split('.'))

    def __init__(self, *args, **kwargs):
        """
        The constructor of attributes.AttributeImpl changed in SQLAlchemy 2.0.22,
        adding a 'default_function' required positional argument before 'dispatch'.
        This adjustment ensures compatibility across versions by inserting None for
        'default_function' in versions >= 2.0.22.

        Arguments received: (class, key, dispatch)
        Required by AttributeImpl: (class, key, default_function, dispatch)
        Setting None as default_function here.
        """
        # Compare only the numeric release segment. Calling int() on every
        # dot-separated part crashes on versions such as '2.1.0b1'.
        if self._parse_version(sa.__version__) >= (2, 0, 22):
            args = (*args[:2], None, *args[2:])

        super().__init__(*args, **kwargs)

    def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
        """Return the related object, loading it lazily when necessary.

        A value already present in ``dict_`` (for example one assigned through
        ``set``) is returned as-is. Otherwise the target class is looked up by
        the stored discriminator, and the object is fetched with
        ``Session.get`` using the stored id value(s).

        ``None`` is returned in any of these cases:

        * the instance has no session;
        * the discriminator is not found via ``_get_class_registry``;
        * every id value is ``None``;
        * no row matches.

        ``passive`` is accepted for signature compatibility but not consulted.
        """
        if self.key in dict_:
            return dict_[self.key]

        # A lazy load needs the session the instance is bound to.
        session = _state_session(state)
        if session is None:
            return None

        # Find class for discriminator.
        # TODO: Perhaps optimize with some sort of lookup?
        discriminator = self.get_state_discriminator(state)
        target_class = _get_class_registry(state.class_).get(discriminator)
        if target_class is None:
            # Unknown (or unset) discriminator; return nothing.
            return None

        pk_values = self.get_state_id(state)
        if all(value is None for value in pk_values):
            # A fully NULL key cannot identify a row. Skip the query (and the
            # warning Session.get emits for it).
            return None

        # Return the found target, or None when no such row exists.
        return session.get(target_class, pk_values)

    def get_state_discriminator(self, state):
        """Return the discriminator value stored on ``state``.

        A hybrid-property discriminator is evaluated on the instance.
        Otherwise the loaded attribute value is read from ``state.attrs``.
        """
        discriminator = self.parent_token.discriminator
        if isinstance(discriminator, hybrid_property):
            return getattr(state.obj(), discriminator.__name__)
        return state.attrs[discriminator.key].value

    def get_state_id(self, state):
        """Return the id attribute value(s) stored on ``state`` as a tuple."""
        return tuple(state.attrs[prop.key].value for prop in self.parent_token.id)

    def set(
        self,
        state,
        dict_,
        initiator,
        passive=attributes.PASSIVE_OFF,
        check_old=None,
        pop=False,
    ):
        """Assign a related object (or ``None``) to the attribute.

        Despite its name, ``initiator`` receives the value being assigned;
        SQLAlchemy's ``AttributeImpl.set`` passes ``value`` in this position.

        The object is stored under this attribute's key, and the discriminator
        and id attributes are updated to match: the object's class name and
        primary key, or ``None`` for all of them when clearing.

        ``passive``, ``check_old`` and ``pop`` are accepted for signature
        compatibility but not consulted.
        """
        # Keep the assigned object so get() can return it without a query.
        dict_[self.key] = initiator

        # NOTE(review): the backing columns below are written straight into
        # ``dict_``, bypassing attribute events/history. Confirm that
        # reassignment on an already-persistent instance is flushed.
        if initiator is None:
            # Nullify every column backing the relationship.
            for prop in self.parent_token.id:
                dict_[prop.key] = None
            dict_[self.parent_token.discriminator.key] = None
        else:
            # Get the primary key of the assigned object; element [1] of the
            # identity key is the primary-key tuple.
            class_ = type(initiator)
            mapper = class_mapper(class_)

            pk = mapper.identity_key_from_instance(initiator)[1]

            # The discriminator is the class name, matching what get() and
            # the comparator look up.
            discriminator = class_.__name__

            for index, prop in enumerate(self.parent_token.id):
                dict_[prop.key] = pk[index]
            dict_[self.parent_token.discriminator.key] = discriminator
103
104
class GenericRelationshipProperty(MapperProperty):
    """A generic form of the relationship property.

    Each instance of the parent model refers to a single instance of any
    mapped model. The reference is stored in two places:

    * a discriminator column, holding the referenced object's class name;
    * one or more id columns, holding its primary key.

    :param discriminator:
        Field to discriminate which model we are referring to. This may be a
        column, a column name, or a hybrid property; it holds the target's
        class name.
    :param id:
        Field(s) to point to the model we are referring to. This may be a
        column, a column name, or an iterable of those; it holds the target's
        primary-key value(s).
    :param doc:
        Optional docstring for the instrumented attribute.
    """

    def __init__(self, discriminator, id, doc=None):
        super().__init__()
        # Raw arguments; resolved against the parent mapper in do_init().
        self._discriminator_col = discriminator
        self._id_cols = id
        self._id = None
        self._discriminator = None
        self.doc = doc

        # Assign a ``_creation_order`` sequence to this property.
        set_creation_order(self)

    def _column_to_property(self, column):
        """Return the parent mapper's attribute corresponding to ``column``.

        A hybrid property is looked up by name among the mapper's ORM
        descriptors. Anything else is matched against column properties by
        the name of their first column. Returns ``None`` if nothing matches.
        """
        if isinstance(column, hybrid_property):
            attr_key = column.__name__
            for key, attr in self.parent.all_orm_descriptors.items():
                if key == attr_key:
                    return attr
        else:
            for attr in self.parent.attrs.values():
                if isinstance(attr, ColumnProperty):
                    if attr.columns[0].name == column.name:
                        return attr
        return None

    def do_init(self):
        """Resolve the configured discriminator and id columns.

        Invoked by the inherited ``MapperProperty.init()``, which also keeps
        SQLAlchemy's configure bookkeeping flags up to date. This method:

        * replaces string arguments with the parent mapper's column of that
          name;
        * wraps a single id column in a list;
        * stores the matching mapper attributes on ``self.discriminator``
          and ``self.id``.

        :raises ImproperlyConfigured: if the discriminator or any id column
            has no corresponding mapper attribute.
        """
        def convert_strings(column):
            if isinstance(column, str):
                return self.parent.columns[column]
            return column

        # Convert a bare string *before* the Iterable check below; otherwise
        # the string itself would be iterated character by character.
        self._discriminator_col = convert_strings(self._discriminator_col)
        self._id_cols = convert_strings(self._id_cols)

        if isinstance(self._id_cols, Iterable):
            self._id_cols = list(map(convert_strings, self._id_cols))
        else:
            self._id_cols = [self._id_cols]

        self.discriminator = self._column_to_property(self._discriminator_col)

        if self.discriminator is None:
            raise ImproperlyConfigured('Could not find discriminator descriptor.')

        self.id = list(map(self._column_to_property, self._id_cols))

        # Fail at configuration time instead of with an AttributeError on
        # first attribute access.
        if any(prop is None for prop in self.id):
            raise ImproperlyConfigured('Could not find id descriptor.')

    class Comparator(PropComparator):
        """SQL expression support for the generic relationship attribute."""

        def __init__(self, prop, parentmapper, adapt_to_entity=None):
            # Let PropComparator set up its own state (``prop``,
            # ``_parententity``, ``_adapt_to_entity``). The base class
            # exposes the property as ``self.property``.
            super().__init__(prop, parentmapper, adapt_to_entity)

        def __eq__(self, other):
            """Match rows whose discriminator and id column(s) refer to ``other``.

            Comparing with ``None`` matches rows where the discriminator and
            every id column are NULL. That is the state produced by assigning
            ``None`` to the attribute.
            """
            if other is None:
                q = self.property._discriminator_col.is_(None)
                for column in self.property._id_cols:
                    q &= column.is_(None)
                return q

            discriminator = type(other).__name__
            q = self.property._discriminator_col == discriminator
            other_id = identity(other)
            for index, column in enumerate(self.property._id_cols):
                q &= column == other_id[index]
            return q

        def __ne__(self, other):
            return ~(self == other)

        def is_type(self, other):
            """Match rows referring to ``other`` or any mapped subclass of it.

            Subclasses at every inheritance depth are included, not only
            direct children.
            """
            mapper = sa.inspect(other)
            # self_and_descendants starts with ``mapper`` itself and then walks
            # the whole inheritance tree below it.
            class_names = [
                submapper.class_.__name__
                for submapper in mapper.self_and_descendants
            ]

            return self.property._discriminator_col.in_(class_names)

    def instrument_class(self, mapper):
        """Register the instrumented attribute on the mapped class.

        Attribute access is handled by ``GenericAttributeImpl`` and SQL
        expressions by ``Comparator``.
        """
        attributes.register_attribute(
            mapper.class_,
            self.key,
            comparator=self.Comparator(self, mapper),
            parententity=mapper,
            doc=self.doc,
            impl_class=GenericAttributeImpl,
            parent_token=self,
        )
197
198
def generic_relationship(*args, **kwargs):
    """Create a ``GenericRelationshipProperty``.

    This is a factory function: every positional and keyword argument is
    forwarded unchanged to the ``GenericRelationshipProperty`` constructor.
    """
    prop = GenericRelationshipProperty(*args, **kwargs)
    return prop