Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/sqlalchemy_utils/functions/foreign_keys.py: 19%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

88 statements  

1from collections import defaultdict 

2from itertools import groupby 

3 

4import sqlalchemy as sa 

5from sqlalchemy.exc import NoInspectionAvailable 

6from sqlalchemy.orm import object_session 

7from sqlalchemy.schema import ForeignKeyConstraint, MetaData, Table 

8 

9from ..query_chain import QueryChain 

10from .database import has_index 

11from .orm import _get_class_registry, get_column_key, get_mapper, get_tables 

12 

13 

def get_foreign_key_values(fk, obj):
    """
    Map the local columns of ``fk``'s constraint to the values held by ``obj``.

    For every element of ``fk.constraint`` the constrained (local) column at
    the same position becomes a key. Its value is read from ``obj``. The
    attribute named after the referenced column's key is used when ``obj``
    has one; otherwise the mapped property for that column is used.

    :param fk: a ``ForeignKey`` whose constraint is inspected
    :param obj: the (referenced) entity to read the values from
    :return: dict of ``{local column: value}``
    """
    mapper = get_mapper(obj)
    constraint = fk.constraint
    local_columns = constraint.columns.values()

    result = {}
    for position, element in enumerate(constraint.elements):
        local_column = local_columns[position]
        referenced = element.column
        if hasattr(obj, referenced.key):
            attribute = referenced.key
        else:
            # The attribute name differs from the column key; ask the mapper.
            attribute = mapper.get_property_by_column(referenced).key
        result[local_column] = getattr(obj, attribute)
    return result

22 

23 

def group_foreign_keys(foreign_keys):
    """
    Return a groupby iterator that groups given foreign keys by table.

    Groups are ordered by table name; tables sharing a name are ordered by
    schema. Each group is yielded as ``(table, iterator_of_foreign_keys)``.

    :param foreign_keys: a sequence of foreign keys

    ::

        foreign_keys = get_referencing_foreign_keys(User)

        for table, fks in group_foreign_keys(foreign_keys):
            # do something
            pass


    .. seealso:: :func:`get_referencing_foreign_keys`

    .. versionadded:: 0.26.1
    """
    # ``itertools.groupby`` only merges *consecutive* items, so all foreign
    # keys of one table must end up adjacent after sorting. Sorting by name
    # alone could interleave same-named tables from different schemas and
    # yield one table several times, hence the schema tie-breaker.
    foreign_keys = sorted(
        foreign_keys,
        key=lambda key: (
            key.constraint.table.name,
            key.constraint.table.schema or '',
        ),
    )
    return groupby(foreign_keys, lambda key: key.constraint.table)

46 

47 

def get_referencing_foreign_keys(mixed):
    """
    Returns referencing foreign keys for given Table object or declarative
    class.

    Every table registered in ``mixed.metadata`` is scanned, except the
    target tables themselves. All foreign keys of its
    ``ForeignKeyConstraint`` objects that reference any target table are
    collected. Foreign keys declared *on* a target table are not included.

    :param mixed:
        SA Table object or SA declarative class
    :return: a ``set`` of ``ForeignKey`` objects

    ::

        get_referencing_foreign_keys(User)  # set([ForeignKey('user.id')])

        get_referencing_foreign_keys(User.__table__)


    This function also understands inheritance. This means it returns
    all foreign keys that reference any table in the class inheritance tree.

    Let's say you have three classes which use joined table inheritance,
    namely TextItem, Article and BlogPost with Article and BlogPost inheriting
    TextItem.

    ::

        # This will check all foreign keys that reference either article table
        # or textitem table.
        get_referencing_foreign_keys(Article)

    .. seealso:: :func:`get_tables`
    """
    target_tables = [mixed] if isinstance(mixed, sa.Table) else get_tables(mixed)

    found = set()
    for candidate in mixed.metadata.tables.values():
        if candidate in target_tables:
            continue
        for constraint in candidate.constraints:
            if not isinstance(constraint, sa.sql.schema.ForeignKeyConstraint):
                continue
            found.update(
                fk
                for fk in constraint.elements
                if any(fk.references(target) for target in target_tables)
            )
    return found

93 

94 

def merge_references(from_, to, foreign_keys=None):
    """
    Merge the references of an entity into another entity.

    Every row that references ``from_`` through one of the given foreign keys
    is updated to reference ``to`` instead. The UPDATE statements are
    executed through the session of ``from_``; committing is left to the
    caller.

    Consider the following models::

        class User(self.Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String(255))

            def __repr__(self):
                return 'User(name=%r)' % self.name

        class BlogPost(self.Base):
            __tablename__ = 'blog_post'
            id = sa.Column(sa.Integer, primary_key=True)
            title = sa.Column(sa.String(255))
            author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))

            author = sa.orm.relationship(User)


    Now lets add some data::

        john = self.User(name='John')
        jack = self.User(name='Jack')
        post = self.BlogPost(title='Some title', author=john)
        post2 = self.BlogPost(title='Other title', author=jack)
        self.session.add_all([
            john,
            jack,
            post,
            post2
        ])
        self.session.commit()


    If we wanted to merge all John's references to Jack it would be as easy as
    ::

        merge_references(john, jack)
        self.session.commit()

        post.author     # User(name='Jack')
        post2.author    # User(name='Jack')



    :param from_: an entity to merge into another entity
    :param to: an entity to merge another entity into
    :param foreign_keys: A sequence of foreign keys. By default this is None
        indicating all referencing foreign keys should be used.
    :raises TypeError: if ``from_`` and ``to`` have different ``__tablename__``

    .. seealso:: :func:`dependent_objects`

    .. versionadded:: 0.26.1

    .. versionchanged:: 0.40.0

        Removed possibility for old-style synchronize_session merging. Only
        SQL based merging supported for now.
    """
    if from_.__tablename__ != to.__tablename__:
        raise TypeError('The tables of given arguments do not match.')

    session = object_session(from_)
    # Honour an explicitly supplied sequence; only fall back to every
    # referencing foreign key when none was given.
    if foreign_keys is None:
        foreign_keys = get_referencing_foreign_keys(from_)

    visited_constraints = []
    for fk in foreign_keys:
        constraint = fk.constraint
        # The values computed below depend only on the constraint, so every
        # element of a composite constraint would issue the same UPDATE.
        if constraint in visited_constraints:
            continue
        visited_constraints.append(constraint)

        old_values = get_foreign_key_values(fk, from_)
        new_values = get_foreign_key_values(fk, to)
        criteria = (
            getattr(constraint.table.c, key.key) == value
            for key, value in old_values.items()
        )
        query = (
            constraint.table.update()
            .where(sa.and_(*criteria))
            .values({key.key: value for key, value in new_values.items()})
        )
        session.execute(query)

177 

178 

def dependent_objects(obj, foreign_keys=None):
    """
    Return a :class:`~sqlalchemy_utils.query_chain.QueryChain` that iterates
    through all dependent objects for given SQLAlchemy object.

    The foreign keys are grouped by the table they live on. For each table,
    every class in the declarative class registry that maps that table gets
    one query. Classes whose parent mapper already maps the table are
    skipped. The query filters on any of the group's constraints matching
    ``obj``.

    Consider a User object is referenced in various articles and also in
    various orders. Getting all these dependent objects is as easy as::

        from sqlalchemy_utils import dependent_objects


        dependent_objects(user)


    If you expect an object to have lots of dependent_objects it might be good
    to limit the results::


        dependent_objects(user).limit(5)



    The common use case is checking for all restrict dependent objects before
    deleting parent object and inform the user if there are dependent objects
    with ondelete='RESTRICT' foreign keys. If this kind of checking is not used
    it will lead to nasty IntegrityErrors being raised.

    In the following example we delete given user if it doesn't have any
    foreign key restricted dependent objects::


        from sqlalchemy_utils import get_referencing_foreign_keys


        user = session.query(User).get(some_user_id)


        deps = list(
            dependent_objects(
                user,
                (
                    fk for fk in get_referencing_foreign_keys(User)
                    # On most databases RESTRICT is the default mode hence we
                    # check for None values also
                    if fk.ondelete == 'RESTRICT' or fk.ondelete is None
                )
            ).limit(5)
        )

        if deps:
            # Do something to inform the user
            pass
        else:
            session.delete(user)


    :param obj: SQLAlchemy declarative model object
    :param foreign_keys:
        A sequence of foreign keys to use for searching the dependent_objects
        for given object. By default this is None, indicating that all foreign
        keys referencing the object will be used.

    .. note::
        This function does not support exotic mappers that use multiple tables

    .. seealso:: :func:`get_referencing_foreign_keys`
    .. seealso:: :func:`merge_references`

    .. versionadded:: 0.26.0
    """
    if foreign_keys is None:
        foreign_keys = get_referencing_foreign_keys(obj)

    session = object_session(obj)

    chain = QueryChain([])
    registry = _get_class_registry(obj.__class__)

    for table, fk_group in group_foreign_keys(foreign_keys):
        fk_group = list(fk_group)
        # ``values()`` is re-read for every group: the registry may hand out a
        # one-shot iterator (e.g. a WeakValueDictionary).
        for candidate in registry.values():
            try:
                candidate_mapper = sa.inspect(candidate)
            except NoInspectionAvailable:
                # Registry entries that are not inspectable classes are ignored.
                continue
            parent = candidate_mapper.inherits
            # Only query through the class that first maps this table, not
            # through subclasses that inherit it from their parent mapper.
            if table not in candidate_mapper.tables or (
                parent and table in parent.tables
            ):
                continue
            criteria = _get_criteria(fk_group, candidate, obj)
            chain.queries.append(
                session.query(candidate).filter(sa.or_(*criteria))
            )
    return chain

273 

274 

def _get_criteria(keys, class_, obj):
    """
    Build one ``AND`` clause per distinct foreign key constraint in ``keys``.

    Each clause compares the attributes of ``class_`` mapped to the
    constraint's local columns with ``obj``'s values for the referenced
    columns. Keys that share a (composite) constraint contribute one clause.

    :param keys: foreign keys, possibly several from the same constraint
    :param class_: the mapped class whose rows are being filtered
    :param obj: the referenced entity supplying the comparison values
    :return: list of SQL ``AND`` clauses
    """
    seen_constraints = []
    clauses = []
    for fk in keys:
        constraint = fk.constraint
        if constraint in seen_constraints:
            continue
        seen_constraints.append(constraint)

        conditions = [
            getattr(class_, get_column_key(class_, local_column))
            == getattr(
                obj,
                sa.inspect(type(obj)).get_property_by_column(element.column).key,
            )
            for local_column, element in zip(
                constraint.columns, constraint.elements
            )
        ]
        clauses.append(sa.and_(*conditions))
    return clauses

295 

296 

def non_indexed_foreign_keys(metadata, engine=None):
    """
    Finds all non indexed foreign keys from all tables of given MetaData.

    Very useful for optimizing postgresql database and finding out which
    foreign keys need indexes.

    Every table of ``metadata`` is reflected from the database into a fresh
    ``MetaData``. Each reflected ``ForeignKeyConstraint`` is then checked
    with ``has_index``.

    :param metadata: MetaData object to inspect tables from
    :param engine:
        Engine or connection used for reflection. It is only needed when
        ``metadata`` has no ``bind``; a ``bind`` takes precedence.
    :return:
        dict mapping the table key (``name``, or ``schema.name`` for
        schema-qualified tables) to a list of unindexed foreign key
        constraints. Tables without such constraints are absent.
    :raises Exception: if neither a bound metadata nor ``engine`` is given
    """
    bind = getattr(metadata, 'bind', None)
    if bind is None and engine is None:
        raise Exception(
            'Either pass a metadata object with bind or '
            'pass engine as a second parameter'
        )
    connectable = bind or engine

    reflected_metadata = MetaData()
    constraints = defaultdict(list)

    for source_table in metadata.tables.values():
        # Reflect by bare name plus schema. The keys of ``metadata.tables``
        # are ``"schema.name"`` for schema-qualified tables, which reflection
        # would otherwise look up as a literal (non-existent) table name.
        table = Table(
            source_table.name,
            reflected_metadata,
            schema=source_table.schema,
            autoload_with=connectable,
        )

        for constraint in table.constraints:
            if not isinstance(constraint, ForeignKeyConstraint):
                continue

            if not has_index(constraint):
                # ``table.key`` equals ``table.name`` for schema-less tables.
                constraints[table.key].append(constraint)

    return dict(constraints)

328 

329 

def get_fk_constraint_for_columns(table, *columns):
    """
    Return the foreign key constraint of ``table`` over exactly ``columns``.

    The constraint's columns must match ``columns`` in the given order.

    :param table: SA Table object to search
    :param columns: the constrained (local) columns, in constraint order
    :return: the matching ``ForeignKeyConstraint``, or ``None`` if none matches
    """
    wanted = list(columns)
    for constraint in table.constraints:
        # ``table.constraints`` also holds primary key / unique / check
        # constraints. For example, a primary key over the same columns as a
        # foreign key (joined table inheritance) must not be returned.
        if not isinstance(constraint, ForeignKeyConstraint):
            continue
        if list(constraint.columns.values()) == wanted:
            return constraint
    return None