1from sqlalchemy import types
2from sqlalchemy.dialects.postgresql import ARRAY
3from sqlalchemy.dialects.postgresql.base import ischema_names, PGTypeCompiler
4from sqlalchemy.sql import expression
5
6from ..primitives import Ltree
7from .scalar_coercible import ScalarCoercible
8
9
class LtreeType(types.Concatenable, types.UserDefinedType, ScalarCoercible):
    """Postgresql LtreeType type.

    The LtreeType datatype can be used for representing labels of data stored
    in hierarchical tree-like structure. For more detailed information please
    refer to https://www.postgresql.org/docs/current/ltree.html

    ::

        from sqlalchemy_utils import LtreeType, Ltree


        class DocumentSection(Base):
            __tablename__ = 'document_section'
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            path = sa.Column(LtreeType)


        section = DocumentSection(path=Ltree('Countries.Finland'))
        session.add(section)
        session.commit()

        section.path  # Ltree('Countries.Finland')

    Bound values may be given either as :class:`Ltree` instances or as plain
    strings (e.g. ``DocumentSection.path.lquery('*.Finland.*')``); strings are
    sent to the database unchanged.

    .. note::
        Using :class:`LtreeType`, :class:`LQUERY` and :class:`LTXTQUERY` types
        may require installation of Postgresql ltree extension on the server
        side. Please visit https://www.postgresql.org/ for details.
    """

    cache_ok = True

    class comparator_factory(types.Concatenable.Comparator):
        """Adds ltree-specific operators to LtreeType column expressions."""

        def ancestor_of(self, other):
            """Render ``self @> other``.

            A list is cast to ``ltree[]`` first, so the expression then tests
            the column against every element of the array.
            """
            if isinstance(other, list):
                return self.op('@>')(expression.cast(other, ARRAY(LtreeType)))
            else:
                return self.op('@>')(other)

        def descendant_of(self, other):
            """Render ``self <@ other``.

            A list is cast to ``ltree[]`` first, so the expression then tests
            the column against every element of the array.
            """
            if isinstance(other, list):
                return self.op('<@')(expression.cast(other, ARRAY(LtreeType)))
            else:
                return self.op('<@')(other)

        def lquery(self, other):
            """Match the column against lquery pattern(s).

            A single pattern renders ``self ~ other``; a list is cast to
            ``lquery[]`` and renders ``self ? other`` (match any pattern).
            """
            if isinstance(other, list):
                return self.op('?')(expression.cast(other, ARRAY(LQUERY)))
            else:
                return self.op('~')(other)

        def ltxtquery(self, other):
            """Render ``self @ other`` (match against an ltxtquery)."""
            return self.op('@')(other)

    def bind_processor(self, dialect):
        """Return a callable converting Python values to DB-API parameters.

        :class:`Ltree` instances are sent as their ``path`` string. Plain
        strings can also reach this type. Examples are a string operand
        compared against an LtreeType column and the elements of a list cast
        to ``ARRAY(LtreeType)``. Such strings are passed through unchanged
        instead of failing on a missing ``.path`` attribute. Falsy values
        (``None``, ``''``) are sent as ``None``.
        """

        def process(value):
            if not value:
                return None
            if isinstance(value, Ltree):
                return value.path
            # Already a string (or other DB-API-adaptable value).
            return value

        return process

    def result_processor(self, dialect, coltype):
        """Return a callable wrapping fetched values via :meth:`_coerce`."""

        def process(value):
            return self._coerce(value)

        return process

    def literal_processor(self, dialect):
        """Return a callable rendering a value as an inline SQL string literal.

        Accepts :class:`Ltree` instances as well as strings. An Ltree has no
        ``replace`` method, so it is unwrapped to its ``path`` first. Embedded
        single quotes are doubled so the value cannot terminate the literal
        early.
        """

        def process(value):
            if isinstance(value, Ltree):
                value = value.path
            value = value.replace("'", "''")
            return "'%s'" % value

        return process

    __visit_name__ = 'LTREE'

    def _coerce(self, value):
        """Wrap a truthy value in :class:`Ltree`; return ``None`` otherwise."""
        if value:
            return Ltree(value)
90
91
class LQUERY(types.TypeEngine):
    """Postgresql LQUERY type.

    The ``lquery`` pattern type of the PostgreSQL ltree extension, used
    e.g. when casting lists passed to :meth:`LtreeType.comparator_factory.lquery`.
    See :class:`LtreeType` for details.
    """

    __visit_name__ = 'LQUERY'
98
99
class LTXTQUERY(types.TypeEngine):
    """Postgresql LTXTQUERY type.

    The ``ltxtquery`` full-text-search-like query type of the PostgreSQL
    ltree extension. See :class:`LtreeType` for details.
    """

    __visit_name__ = 'LTXTQUERY'
106
107
# Map PostgreSQL's ltree-family type names to the classes defined above so
# that reflected columns of these types resolve to them.
ischema_names.update(
    ltree=LtreeType,
    lquery=LQUERY,
    ltxtquery=LTXTQUERY,
)
111
112
def visit_LTREE(self, type_, **kw):
    """Return the DDL type name emitted for :class:`LtreeType` columns."""
    ddl_name = 'LTREE'
    return ddl_name
115
116
def visit_LQUERY(self, type_, **kw):
    """Return the DDL type name emitted for :class:`LQUERY` columns."""
    ddl_name = 'LQUERY'
    return ddl_name
119
120
def visit_LTXTQUERY(self, type_, **kw):
    """Return the DDL type name emitted for :class:`LTXTQUERY` columns."""
    ddl_name = 'LTXTQUERY'
    return ddl_name
123
124
# Attach the visitors to PostgreSQL's type compiler, so that types whose
# __visit_name__ is 'LTREE', 'LQUERY' or 'LTXTQUERY' can be compiled to DDL.
setattr(PGTypeCompiler, 'visit_LTREE', visit_LTREE)
setattr(PGTypeCompiler, 'visit_LQUERY', visit_LQUERY)
setattr(PGTypeCompiler, 'visit_LTXTQUERY', visit_LTXTQUERY)