Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/sqlalchemy_utils/generic.py: 26%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

104 statements  

1from collections.abc import Iterable 

2 

3import sqlalchemy as sa 

4from sqlalchemy.ext.hybrid import hybrid_property 

5from sqlalchemy.orm import attributes, class_mapper, ColumnProperty 

6from sqlalchemy.orm.interfaces import MapperProperty, PropComparator 

7from sqlalchemy.orm.session import _state_session 

8from sqlalchemy.util import set_creation_order 

9 

10from .exceptions import ImproperlyConfigured 

11from .functions import identity 

12from .functions.orm import _get_class_registry 

13 

14 

def _sqlalchemy_version_info(version=None):
    """Return the ``(major, minor, patch)`` integers of a SQLAlchemy version.

    Only the leading numeric components are used, so pre-release, dev and
    local suffixes (``"2.1.0b1"``, ``"2.0.23.dev0"``) parse cleanly instead of
    raising ``ValueError`` as a naive ``int()`` over every dotted component
    would. A missing patch component is treated as ``0``.

    :param version: Version string to parse; defaults to ``sa.__version__``.
    :raises ValueError: If the string does not start with ``MAJOR.MINOR``.
    """
    import re

    if version is None:
        version = sa.__version__
    match = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', version)
    if match is None:
        raise ValueError(f'Unrecognized SQLAlchemy version: {version!r}')
    return tuple(int(part or 0) for part in match.groups())


class GenericAttributeImpl(attributes.ScalarAttributeImpl):
    """Scalar attribute implementation backing a generic relationship.

    The related object is cached in the instance dict under ``self.key``.
    When it is not cached, it is loaded lazily. The class is found by looking
    up the discriminator value in the declarative class registry. The row is
    then fetched with ``Session.get`` using the values of the id attributes.

    ``self.parent_token`` is the owning property. It must expose a
    ``discriminator`` descriptor and an ``id`` list of descriptors.
    """

    def __init__(self, *args, **kwargs):
        """
        The constructor of attributes.AttributeImpl changed in SQLAlchemy 2.0.22,
        adding a 'default_function' required positional argument before 'dispatch'.
        This adjustment ensures compatibility across versions by inserting None for
        'default_function' in versions >= 2.0.22.

        Arguments received: (class, key, dispatch)
        Required by AttributeImpl: (class, key, default_function, dispatch)
        Setting None as default_function here.
        """
        # Parse only the numeric prefix so pre-release builds (e.g. 2.1.0b1)
        # do not crash the comparison.
        if _sqlalchemy_version_info() >= (2, 0, 22):
            args = (*args[:2], None, *args[2:])

        super().__init__(*args, **kwargs)

    def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
        """Return the related object, loading it lazily when not cached.

        :param state: The instance state owning this attribute.
        :param dict_: The instance dict; a cached value under ``self.key`` wins.
        :param passive: Accepted for signature compatibility; not consulted.
        :returns: The related object, or ``None`` when the state has no session,
            the discriminator names no registered class, the stored identity
            is entirely NULL, or no matching row exists.
        """
        if self.key in dict_:
            return dict_[self.key]

        # A session is needed to run the lazy lookup.
        session = _state_session(state)
        if session is None:
            return None

        # TODO: Perhaps optimize with some sort of lookup?
        discriminator = self.get_state_discriminator(state)
        target_class = _get_class_registry(state.class_).get(discriminator)
        if target_class is None:
            # Unknown (or NULL) discriminator; nothing to load.
            return None

        ident = self.get_state_id(state)
        if ident and all(value is None for value in ident):
            # A fully NULL primary key cannot match any row; skip the query.
            return None

        return session.get(target_class, ident)

    def get_state_discriminator(self, state):
        """Return the discriminator value currently stored on ``state``.

        A hybrid-property discriminator is evaluated on the live object. A
        column property is read from the state's loaded attribute value.
        """
        discriminator = self.parent_token.discriminator
        if isinstance(discriminator, hybrid_property):
            return getattr(state.obj(), discriminator.__name__)
        else:
            return state.attrs[discriminator.key].value

    def get_state_id(self, state):
        """Return a tuple of the id attribute values stored on ``state``.

        Values are returned in the order of ``parent_token.id``.
        """
        return tuple(state.attrs[id_prop.key].value for id_prop in self.parent_token.id)

    def set(
        self,
        state,
        dict_,
        initiator,
        passive=attributes.PASSIVE_OFF,
        check_old=None,
        pop=False,
    ):
        """Assign a related object (or ``None``) to the generic relationship.

        Despite its name, ``initiator`` is treated as the value being assigned.
        It is cached in ``dict_`` under this attribute's key. When it is not
        ``None``, its primary key is copied into the id attributes and its
        class name into the discriminator attribute. Assigning ``None`` clears
        the id attributes and the discriminator.

        ``passive``, ``check_old`` and ``pop`` are accepted for signature
        compatibility and are ignored.
        """
        # Set us on the state.
        dict_[self.key] = initiator

        if initiator is None:
            # Nullify relationship args
            for id_prop in self.parent_token.id:
                dict_[id_prop.key] = None
            dict_[self.parent_token.discriminator.key] = None
        else:
            # Get the primary key of the assigned object; class_mapper raises
            # if its class is not mapped, so unsupported values fail here.
            class_ = type(initiator)
            mapper = class_mapper(class_)

            pk = mapper.identity_key_from_instance(initiator)[1]

            # The discriminator is the class name, which get() later resolves
            # through the class registry.
            discriminator = class_.__name__

            for index, id_prop in enumerate(self.parent_token.id):
                dict_[id_prop.key] = pk[index]
            dict_[self.parent_token.discriminator.key] = discriminator

103 

104 

class GenericRelationshipProperty(MapperProperty):
    """A generic form of the relationship property.

    Creates a 1 to many relationship between the parent model
    and any other models using a discriminator column that stores the
    target's class name (``type(obj).__name__``) together with one or more
    id columns that store the target's primary key values.

    :param discriminator:
        Field to discriminate which model we are referring to. May be a
        column, a hybrid property, or a column attribute name (string).
    :param id:
        Field(s) to point to the model we are referring to. May be a single
        column/attribute name or an iterable of them (composite keys).
    :param doc:
        Optional docstring attached to the instrumented attribute.
    """

    def __init__(self, discriminator, id, doc=None):
        super().__init__()
        self._discriminator_col = discriminator
        self._id_cols = id
        self._id = None
        self._discriminator = None
        self.doc = doc

        set_creation_order(self)

    def _column_to_property(self, column):
        """Return the parent mapper's descriptor matching ``column``.

        Hybrid properties are matched by name against
        ``all_orm_descriptors``. Columns are matched by name against the
        first column of each ``ColumnProperty``. Returns ``None`` when no
        match is found.
        """
        if isinstance(column, hybrid_property):
            attr_key = column.__name__
            for key, attr in self.parent.all_orm_descriptors.items():
                if key == attr_key:
                    return attr
        else:
            for attr in self.parent.attrs.values():
                if isinstance(attr, ColumnProperty):
                    if attr.columns[0].name == column.name:
                        return attr

    def init(self):
        """Resolve the configured discriminator/id specs into descriptors.

        String specs are looked up in the parent mapper's columns. The id spec
        is normalised to a list.

        :raises ImproperlyConfigured: If no discriminator descriptor is found.
        """
        def convert_strings(column):
            if isinstance(column, str):
                return self.parent.columns[column]
            return column

        self._discriminator_col = convert_strings(self._discriminator_col)
        self._id_cols = convert_strings(self._id_cols)

        if isinstance(self._id_cols, Iterable):
            self._id_cols = list(map(convert_strings, self._id_cols))
        else:
            self._id_cols = [self._id_cols]

        self.discriminator = self._column_to_property(self._discriminator_col)

        if self.discriminator is None:
            raise ImproperlyConfigured('Could not find discriminator descriptor.')

        self.id = list(map(self._column_to_property, self._id_cols))

    class Comparator(PropComparator):
        """SQL expression comparator for generic relationship attributes."""

        def __init__(self, prop, parentmapper):
            # Let PropComparator initialise its own state (prop,
            # _adapt_to_entity, ...) before pinning the explicit values.
            super().__init__(prop, parentmapper)
            self.property = prop
            self._parententity = parentmapper

        def __eq__(self, other):
            """Match rows pointing at ``other`` (same class name and identity)."""
            discriminator = type(other).__name__
            criterion = self.property._discriminator_col == discriminator
            other_id = identity(other)
            for index, id_col in enumerate(self.property._id_cols):
                criterion &= id_col == other_id[index]
            return criterion

        def __ne__(self, other):
            return ~(self == other)

        def is_type(self, other):
            """Match rows whose target is ``other`` or any mapped subclass.

            ``self_and_descendants`` walks the full inheritance tree, so
            grandchildren and deeper subclasses are included, not only
            direct subclasses.
            """
            mapper = sa.inspect(other)
            class_names = [
                submapper.class_.__name__
                for submapper in mapper.self_and_descendants
            ]

            return self.property._discriminator_col.in_(class_names)

    def instrument_class(self, mapper):
        """Register the instrumented attribute backed by GenericAttributeImpl."""
        attributes.register_attribute(
            mapper.class_,
            self.key,
            comparator=self.Comparator(self, mapper),
            parententity=mapper,
            doc=self.doc,
            impl_class=GenericAttributeImpl,
            parent_token=self,
        )

197 

198 

def generic_relationship(*args, **kwargs):
    """Build a generic relationship property.

    Thin factory: every positional and keyword argument is forwarded
    unchanged to ``GenericRelationshipProperty``.
    """
    prop = GenericRelationshipProperty(*args, **kwargs)
    return prop