1import sqlalchemy as sa
2from sqlalchemy.dialects import postgresql
3from sqlalchemy.ext.compiler import compiles
4from sqlalchemy.sql.expression import ColumnElement, FunctionElement
5from sqlalchemy.sql.functions import GenericFunction
6
7from .functions.orm import quote
8
9
class array_get(FunctionElement):
    """SQL function element that reads a single item from an array expression.

    The construct is expected to carry exactly two clauses: the array
    expression and a literal integer index.  Both conditions are enforced
    by the ``@compiles`` hook registered for this class, which also
    produces the SQL text.
    """

    # NOTE(review): the compile hook writes the index value straight into
    # the SQL string, so opting in to ``inherit_cache`` looks unsafe for
    # this construct -- confirm before enabling it.
    name = 'array_get'
12
13
@compiles(array_get)
def compile_array_get(element, compiler, **kw):
    """Render ``array_get(array, index)`` as ``(array)[index + 1]``.

    The index is incremented by one before it is written into the SQL text
    (presumably translating a 0-based Python index into SQL's 1-based array
    subscripts -- confirm against the target database).

    :param element: the ``array_get`` construct being compiled.
    :param compiler: the active SQL compiler.
    :param kw: compiler keyword arguments; forwarded to the nested
        ``compiler.process`` call.
    :returns: the SQL string for the subscript expression.
    :raises Exception: if the construct does not have exactly two
        arguments, or if the second argument does not expose an integer
        ``value`` attribute.
    """
    args = list(element.clauses)
    if len(args) != 2:
        raise Exception(
            "Function 'array_get' expects two arguments (%d given)." % len(args)
        )

    array_expr, index_arg = args
    # A missing ``value`` attribute and a non-integer value are rejected with
    # the same message, so one getattr-with-default covers both cases.
    index = getattr(index_arg, 'value', None)
    if not isinstance(index, int):
        raise Exception('Second argument should be an integer.')

    # Forward **kw so that flags supplied by the caller (for example
    # ``literal_binds``) also apply to the nested array expression.
    return f'({compiler.process(array_expr, **kw)})[{index + 1}]'
25
26
class row_to_json(GenericFunction):
    """Generic SQL function ``row_to_json`` whose result type is PostgreSQL JSON."""

    name = 'row_to_json'
    type = postgresql.JSON
    # Opt in to SQL compilation caching (SQLAlchemy 1.4+).  This is safe
    # because the construct adds no state beyond its argument clauses and
    # its rendering depends only on the function name and those clauses.
    # Older SQLAlchemy versions simply ignore the attribute.
    inherit_cache = True
30
31
@compiles(row_to_json, 'postgresql')
def compile_row_to_json(element, compiler, **kw):
    """Render ``row_to_json`` on PostgreSQL as ``<name>(<arguments>)``.

    :param element: the ``row_to_json`` construct being compiled.
    :param compiler: the active SQL compiler.
    :param kw: compiler keyword arguments; forwarded to the nested
        ``compiler.process`` call.
    :returns: the SQL string for the function call.
    """
    # Forward **kw so that flags supplied by the caller (for example
    # ``literal_binds``) also apply to the argument clauses.
    rendered_args = compiler.process(element.clauses, **kw)
    return f'{element.name}({rendered_args})'
35
36
class json_array_length(GenericFunction):
    """Generic SQL function ``json_array_length`` whose result type is Integer."""

    name = 'json_array_length'
    type = sa.Integer
    # Opt in to SQL compilation caching (SQLAlchemy 1.4+).  This is safe
    # because the construct adds no state beyond its argument clauses and
    # its rendering depends only on the function name and those clauses.
    # Older SQLAlchemy versions simply ignore the attribute.
    inherit_cache = True
40
41
@compiles(json_array_length, 'postgresql')
def compile_json_array_length(element, compiler, **kw):
    """Render ``json_array_length`` on PostgreSQL as ``<name>(<arguments>)``.

    :param element: the ``json_array_length`` construct being compiled.
    :param compiler: the active SQL compiler.
    :param kw: compiler keyword arguments; forwarded to the nested
        ``compiler.process`` call.
    :returns: the SQL string for the function call.
    """
    # Forward **kw so that flags supplied by the caller (for example
    # ``literal_binds``) also apply to the argument clauses.
    rendered_args = compiler.process(element.clauses, **kw)
    return f'{element.name}({rendered_args})'
45
46
class Asterisk(ColumnElement):
    """Column element that compiles to ``<selectable name>.*``.

    The SQL text is produced by the ``@compiles`` hook registered for this
    class, which reads ``selectable.name``.
    """

    def __init__(self, selectable):
        """Store the selectable whose columns the asterisk refers to.

        :param selectable: object exposing a ``name`` attribute (presumably
            a table or alias -- confirm against callers).
        """
        self.selectable = selectable
50
51
@compiles(Asterisk)
def compile_asterisk(element, compiler, **kw):
    """Render an ``Asterisk`` as the quoted selectable name followed by ``.*``.

    :param element: the ``Asterisk`` construct being compiled.
    :param compiler: the active SQL compiler; its dialect is passed to
        ``quote``.
    :param kw: compiler keyword arguments (unused).
    :returns: the SQL string ``<quoted name>.*``.
    """
    quoted_name = quote(compiler.dialect, element.selectable.name)
    # ``!s`` mirrors the ``%s`` conversion: the value is rendered via str().
    return f'{quoted_name!s}.*'