1import six
2
3from .utils import build_from_item
4from .nodes import Node
5
6
class Statement(Node):
    """Base class for the top-level SQL statement nodes in this module.

    Subclasses override ``statement`` with the SQL keyword text that begins
    the statement; ``str()`` of an instance yields exactly that text.
    """

    # Keyword text rendered by __str__; empty on this base class.
    statement = ''

    def __str__(self):
        keyword = self.statement
        return keyword
13
14
class SelectStmt(Statement):
    """A SELECT statement parse node.

    Every clause is copied from the matching key of the parsed ``obj``. Child
    nodes are built with ``build_from_item``; plain scalars are read with
    ``obj.get``. Missing clauses are whatever ``build_from_item`` returns for
    an absent key (presumably ``None``). The truthiness guards in
    ``tables()`` rely on that value being falsy.
    """

    statement = 'SELECT'

    def __init__(self, obj):
        self.distinct_clause = build_from_item(obj, 'distinctClause')
        self.into_clause = build_from_item(obj, 'intoClause')
        self.target_list = build_from_item(obj, 'targetList')
        self.from_clause = build_from_item(obj, 'fromClause')
        self.where_clause = build_from_item(obj, 'whereClause')
        self.group_clause = build_from_item(obj, 'groupClause')
        self.having_clause = build_from_item(obj, 'havingClause')
        self.window_clause = build_from_item(obj, 'windowClause')

        self.values_lists = build_from_item(obj, 'valuesLists')

        self.sort_clause = build_from_item(obj, 'sortClause')
        self.limit_offset = build_from_item(obj, 'limitOffset')
        self.limit_count = build_from_item(obj, 'limitCount')
        self.locking_clause = build_from_item(obj, 'lockingClause')
        self.with_clause = build_from_item(obj, 'withClause')

        # Set-operation fields: ``op``/``all`` are kept verbatim, while
        # ``larg``/``rarg`` are built as the operand nodes.
        self.op = obj.get('op')
        self.all = obj.get('all')
        self.larg = build_from_item(obj, 'larg')
        self.rarg = build_from_item(obj, 'rarg')

    def tables(self):
        """Return the set of tables referenced by this SELECT.

        Collects tables from:

        * the target list and FROM clause
        * the WHERE and HAVING clauses
        * the WITH clause
        * the ``larg``/``rarg`` operands
        """
        _tables = set()
        for item in self.target_list or ():
            _tables |= item.tables()
        for item in self.from_clause or ():
            _tables |= item.tables()
        if self.where_clause:
            _tables |= self.where_clause.tables()
        # HAVING may contain subqueries, e.g.
        # "HAVING count(*) > (SELECT count(*) FROM other)".
        # Skipping it would silently drop those tables.
        if self.having_clause:
            _tables |= self.having_clause.tables()
        if self.with_clause:
            _tables |= self.with_clause.tables()

        if self.larg:
            _tables |= self.larg.tables()
        if self.rarg:
            _tables |= self.rarg.tables()

        return _tables
61
62
class InsertStmt(Statement):
    """An INSERT statement parse node.

    Child clauses are built from the parsed ``obj`` with ``build_from_item``.
    """

    statement = 'INSERT INTO'

    def __init__(self, obj):
        self.relation = build_from_item(obj, 'relation')
        self.cols = build_from_item(obj, 'cols')
        self.select_stmt = build_from_item(obj, 'selectStmt')
        self.on_conflict_clause = build_from_item(obj, 'onConflictClause')
        self.returning_list = build_from_item(obj, 'returningList')
        self.with_clause = build_from_item(obj, 'withClause')

    def tables(self):
        """Return the target table plus every table read by the statement.

        Covers the inserted-into relation, the source query (if any) and the
        WITH clause (if any).
        """
        # Copy so the set returned by ``relation.tables()`` is never mutated.
        _tables = set(self.relation.tables())

        # PostgreSQL leaves selectStmt NULL for "INSERT ... DEFAULT VALUES".
        # There is then no source query to inspect, and calling .tables()
        # on None would raise.
        if self.select_stmt:
            _tables |= self.select_stmt.tables()
        if self.with_clause:
            _tables |= self.with_clause.tables()

        return _tables
82
83
class UpdateStmt(Statement):
    """An UPDATE statement parse node.

    ``tables()`` reports the updated relation together with the tables named
    in the WHERE, FROM and WITH clauses.
    """

    statement = 'UPDATE'

    def __init__(self, obj):
        def child(key):
            return build_from_item(obj, key)

        self.relation = child('relation')
        self.target_list = child('targetList')
        self.where_clause = child('whereClause')
        self.from_clause = child('fromClause')
        self.returning_list = child('returningList')
        self.with_clause = child('withClause')

    def tables(self):
        # Accumulates into the set returned by relation.tables().
        found = self.relation.tables()
        if self.where_clause:
            found |= self.where_clause.tables()
        for source in self.from_clause or ():
            found |= source.tables()
        if self.with_clause:
            found |= self.with_clause.tables()
        return found
108
109
class DeleteStmt(Statement):
    """A DELETE statement parse node.

    ``tables()`` reports the target relation together with tables named in
    the USING, WHERE and WITH clauses.
    """

    statement = 'DELETE FROM'

    def __init__(self, obj):
        def child(key):
            return build_from_item(obj, key)

        self.relation = child('relation')
        self.using_clause = child('usingClause')
        self.where_clause = child('whereClause')
        self.returning_list = child('returningList')
        self.with_clause = child('withClause')

    def tables(self):
        # Accumulates into the set returned by relation.tables().
        found = self.relation.tables()
        for source in self.using_clause or ():
            found |= source.tables()
        for clause in (self.where_clause, self.with_clause):
            if clause:
                found |= clause.tables()
        return found
133
134
class WithClause(Node):
    """A WITH clause holding one or more common table expressions.

    ``ctes`` is a sequence of child nodes built by ``build_from_item``
    (presumably ``CommonTableExpr`` instances; verify against the builder).
    """

    def __init__(self, obj):
        self.ctes = build_from_item(obj, 'ctes')
        self.recursive = obj.get('recursive')
        self.location = obj.get('location')

    def __repr__(self):
        return '<WithClause (%d)>' % len(self.ctes)

    def __str__(self):
        """Render as ``WITH [RECURSIVE] name AS (query), ...``."""
        s = 'WITH '
        if self.recursive:
            s += 'RECURSIVE '
        # ``ctes`` is a sequence of nodes (``tables()`` iterates it that way),
        # not a mapping. ``six.iteritems`` on it raised AttributeError, so
        # read each entry's name and query directly.
        s += ', '.join(
            '%s AS (%s)' % (cte.ctename, cte.ctequery) for cte in self.ctes)
        return s

    def tables(self):
        """Return the union of the tables referenced by every CTE."""
        _tables = set()
        for item in self.ctes:
            _tables |= item.tables()
        return _tables
159
160
class CommonTableExpr(Node):
    """One ``name AS (query)`` entry of a WITH clause.

    Scalars are read with ``obj.get``; child nodes go through
    ``build_from_item``.
    """

    def __init__(self, obj):
        get = obj.get

        def child(key):
            return build_from_item(obj, key)

        self.ctename = get('ctename')
        self.aliascolnames = child('aliascolnames')
        self.ctequery = child('ctequery')
        self.location = get('location')
        self.cterecursive = get('cterecursive')
        self.cterefcount = get('cterefcount')
        self.ctecolnames = child('ctecolnames')
        self.ctecoltypes = child('ctecoltypes')
        self.ctecoltypmods = child('ctecoltypmods')
        self.ctecolcollations = child('ctecolcollations')

    def tables(self):
        """Return the tables referenced by this CTE's query."""
        query = self.ctequery
        return query.tables()
177
178
class RangeSubselect(Node):
    """A subquery used as a FROM-clause item."""

    def __init__(self, obj):
        get = obj.get

        def child(key):
            return build_from_item(obj, key)

        self.lateral = get('lateral')
        self.subquery = child('subquery')
        self.alias = child('alias')

    def tables(self):
        """Return the tables referenced by the wrapped subquery."""
        inner = self.subquery
        return inner.tables()
188
189
class ResTarget(Node):
    """
    Result target.

    In a SELECT target list, ``name`` is the label from an
    ``AS ColumnLabel`` clause (None when absent). ``val`` is the value
    expression itself, and ``indirection`` is not used.

    In an INSERT target-column list, ``name`` is the destination column and
    ``indirection`` holds any subscripts attached to it; ``val`` is not used.

    In an UPDATE target list, ``name`` is the destination column,
    ``indirection`` holds any subscripts attached to it, and ``val`` is the
    expression being assigned.
    """

    def __init__(self, obj):
        get = obj.get

        def child(key):
            return build_from_item(obj, key)

        self.name = get('name')
        self.indirection = child('indirection')
        self.val = child('val')
        self.location = get('location')

    def tables(self):
        # ``val`` may be a list of nodes, a single node, or neither
        # (e.g. unused in an INSERT column list).
        val = self.val
        if isinstance(val, list):
            nodes = val
        elif isinstance(val, Node):
            nodes = [val]
        else:
            nodes = []

        found = set()
        for node in nodes:
            found |= node.tables()
        return found
222
223
class ColumnRef(Node):
    """A column reference; contributes no tables of its own."""

    def __init__(self, obj):
        self.fields, self.location = (
            build_from_item(obj, 'fields'), obj.get('location'))

    def tables(self):
        """Return an empty set."""
        return set()
232
233
class FuncCall(Node):
    """A function or aggregate call expression.

    Scalars are read with ``obj.get``; child nodes go through
    ``build_from_item``.
    """

    def __init__(self, obj):
        get = obj.get

        def child(key):
            return build_from_item(obj, key)

        self.funcname = child('funcname')
        self.args = child('args')
        self.agg_order = child('agg_order')
        self.agg_filter = child('agg_filter')
        self.agg_within_group = get('agg_within_group')
        self.agg_star = get('agg_star')
        self.agg_distinct = get('agg_distinct')
        self.func_variadic = get('func_variadic')
        self.over = child('over')
        self.location = get('location')

    def tables(self):
        """Return the union of the tables referenced by the call arguments."""
        found = set()
        for arg in self.args or ():
            found |= arg.tables()
        return found
254
255
class AStar(Node):
    """A bare ``*`` in a target list or column reference."""

    def __init__(self, obj):
        """Ignore ``obj``: this node keeps no attributes."""

    def tables(self):
        """Return an empty set."""
        return set()
263
264
class AExpr(Node):
    """A (binary or unary) operator expression with ``lexpr``/``rexpr`` operands."""

    def __init__(self, obj):
        get = obj.get

        def child(key):
            return build_from_item(obj, key)

        self.kind = get('kind')
        self.name = child('name')
        self.lexpr = child('lexpr')
        self.rexpr = child('rexpr')
        self.location = get('location')

    def tables(self):
        """Return the tables referenced by either operand.

        Each operand may be a list of nodes (e.g. an IN list), a single node,
        or absent; non-node values are ignored.
        """
        found = set()
        for operand in (self.lexpr, self.rexpr):
            if isinstance(operand, list):
                for item in operand:
                    found |= item.tables()
            elif isinstance(operand, Node):
                found |= operand.tables()
        return found
290
291
class AConst(Node):
    """A literal constant; contributes no tables."""

    def __init__(self, obj):
        self.val, self.location = (
            build_from_item(obj, 'val'), obj.get('location'))

    def tables(self):
        """Return an empty set."""
        return set()
300
301
class TypeCast(Node):
    """A type-cast expression wrapping ``arg`` with target type ``type_name``."""

    def __init__(self, obj):
        self.arg = build_from_item(obj, 'arg')
        self.type_name = build_from_item(obj, 'typeName')
        self.location = obj.get('location')

    def tables(self):
        """Return the tables referenced by the expression being cast.

        Sibling expression nodes (AExpr, FuncCall, ResTarget, ...) recurse
        through ``tables()``, and they can reach a cast in a WHERE clause or
        target list. This node therefore needs the method too. Only ``arg``
        is inspected; ``type_name`` is not.
        """
        if isinstance(self.arg, Node):
            return self.arg.tables()
        return set()
308
309
class TypeName(Node):
    """A type name, as used by casts and column definitions.

    Scalars are read with ``obj.get``; child nodes go through
    ``build_from_item``.
    """

    def __init__(self, obj):
        get = obj.get

        def child(key):
            return build_from_item(obj, key)

        self.names = child('names')
        self.type_oid = get('typeOid')
        self.setof = get('setof')
        self.pct_type = get('pct_type')
        self.typmods = child('typmods')
        self.typemod = get('typemod')
        self.array_bounds = child('arrayBounds')
        self.location = get('location')
321
322
class SortBy(Node):
    """Mirror of the parser's ``SortBy`` node (an ORDER BY item)."""

    def __init__(self, obj):
        get = obj.get

        def child(key):
            return build_from_item(obj, key)

        self.node = child('node')
        self.sortby_dir = get('sortby_dir')
        self.sortby_nulls = get('sortby_nulls')
        self.use_op = child('useOp')
        self.location = get('location')
331
332
class WindowDef(Node):
    """Mirror of the parser's ``WindowDef`` node (a window specification).

    Scalars are read with ``obj.get``; child nodes go through
    ``build_from_item``.
    """

    def __init__(self, obj):
        get = obj.get

        def child(key):
            return build_from_item(obj, key)

        self.name = get('name')
        self.refname = get('refname')
        self.partition_clause = child('partitionClause')
        self.order_clause = child('orderClause')
        self.frame_options = get('frameOptions')
        self.start_offset = child('startOffset')
        self.end_offset = child('endOffset')
        self.location = get('location')
344
345
class LockingClause(Node):
    """Mirror of the parser's ``LockingClause`` node (FOR UPDATE/SHARE ...)."""

    def __init__(self, obj):
        def child(key):
            return build_from_item(obj, key)

        self.locked_rels = child('lockedRels')
        self.strength = child('strength')
        self.wait_policy = obj.get('waitPolicy')
352
353
class RangeFunction(Node):
    """Mirror of the parser's ``RangeFunction`` node (a function in FROM)."""

    def __init__(self, obj):
        get = obj.get

        def child(key):
            return build_from_item(obj, key)

        self.lateral = get('lateral')
        self.ordinality = get('ordinality')
        self.is_rowsfrom = get('is_rowsfrom')
        self.functions = child('functions')
        self.alias = child('alias')
        self.coldeflist = child('coldeflist')
363
364
class AArrayExpr(Node):
    """An ``ARRAY[...]`` constructor whose items are held in ``elements``."""

    def __init__(self, obj):
        self.elements = build_from_item(obj, 'elements')
        self.location = obj.get('location')

    def tables(self):
        """Return the union of the tables referenced by the element expressions.

        AExpr calls ``tables()`` on its operands, and an array constructor
        can be such an operand (e.g. ``x = ANY(ARRAY[...])``). This node
        must therefore support it. An empty or missing ``elements`` yields
        an empty set, and non-node entries are skipped.
        """
        _tables = set()
        for element in self.elements or ():
            if isinstance(element, Node):
                _tables |= element.tables()
        return _tables
370
371
class AIndices(Node):
    """An array subscript or slice with lower (``lidx``) and upper (``uidx``) bounds."""

    def __init__(self, obj):
        self.lidx, self.uidx = (
            build_from_item(obj, 'lidx'), build_from_item(obj, 'uidx'))
376
377
class MultiAssignRef(Node):
    """Mirror of the parser's ``MultiAssignRef`` node.

    It refers to column ``colno`` of ``ncolumns`` in a multi-column
    assignment ``source``.
    """

    def __init__(self, obj):
        get = obj.get
        self.source = build_from_item(obj, 'source')
        self.colno = get('colno')
        self.ncolumns = get('ncolumns')