1import re
2
3from ..utils import str_coercible
4
# Matches a label path: one or more labels made of ASCII letters, digits,
# underscores or hyphens, joined by single dots. ``\Z`` is used instead of
# ``$`` because ``$`` also matches just before a trailing newline, which would
# let a value such as 'a.b\n' pass validation.
path_matcher = re.compile(r'^[A-Za-z0-9_-]+(\.[A-Za-z0-9_-]+)*\Z')
6
7
@str_coercible
class Ltree:
    """
    Ltree class wraps a valid string label path. It provides various
    convenience properties and methods.

    ::

        from sqlalchemy_utils import Ltree

        Ltree('1.2.3').path  # '1.2.3'


    Ltree always validates the given path.

    ::

        Ltree(None)  # raises TypeError

        Ltree('..')  # raises ValueError


    Validator is also available as class method.

    ::

        Ltree.validate('1.2.3')
        Ltree.validate(None)  # raises TypeError


    Ltree supports equality operators.

    ::

        Ltree('Countries.Finland') == Ltree('Countries.Finland')
        Ltree('Countries.Germany') != Ltree('Countries.Finland')


    Ltree objects are hashable.

    ::

        assert hash(Ltree('Finland')) == hash('Finland')


    Ltree objects have length.

    ::

        assert len(Ltree('1.2')) == 2
        assert len(Ltree('some.one.some.where')) == 4


    You can easily find subpath indexes.

    ::

        assert Ltree('1.2.3').index('2.3') == 1
        assert Ltree('1.2.3.4.5').index('3.4') == 2


    Ltree objects can be sliced.

    ::

        assert Ltree('1.2.3')[0:2] == Ltree('1.2')
        assert Ltree('1.2.3')[1:] == Ltree('2.3')


    Finding the lowest common ancestor.

    ::

        assert Ltree('1.2.3.4.5').lca('1.2.3', '1.2.3.4', '1.2.3') == '1.2'
        assert Ltree('1.2.3.4.5').lca('1.2', '1.2.3') == '1'


    Ltree objects can be concatenated.

    ::

        assert Ltree('1.2') + Ltree('1.2') == Ltree('1.2.1.2')


    Ltree objects can be ordered against other Ltrees or strings. Paths are
    compared label by label, so a path sorts before its descendants.

    ::

        assert Ltree('1.2') < Ltree('1.2.3')
        assert Ltree('a.b') < Ltree('a-b')
    """

    def __init__(self, path_or_ltree):
        # An existing Ltree is reused as-is; only raw strings are validated.
        if isinstance(path_or_ltree, Ltree):
            self.path = path_or_ltree.path
        elif isinstance(path_or_ltree, str):
            self.validate(path_or_ltree)
            self.path = path_or_ltree
        else:
            raise TypeError(
                "Ltree() argument must be a string or an Ltree, not '{}'".format(
                    type(path_or_ltree).__name__
                )
            )

    @classmethod
    def validate(cls, path):
        """
        Raise ``ValueError`` if ``path`` is not a valid label path.

        A valid path is one or more labels of ASCII letters, digits,
        underscores or hyphens, joined by single dots. A non-string ``path``
        (e.g. ``None``) makes the regex engine raise ``TypeError``.
        """
        # fullmatch() instead of match(): a pattern anchored with ``$`` also
        # matches just before a trailing newline, which would accept 'a.b\n'.
        if path_matcher.fullmatch(path) is None:
            raise ValueError(f"'{path}' is not a valid ltree path.")

    def __len__(self):
        """Return the number of labels in the path."""
        return len(self.path.split('.'))

    def index(self, other):
        """
        Return the label offset of the first occurrence of subpath ``other``.

        :param other: a string or Ltree; strings are validated first.
        :raises ValueError: if ``other`` does not occur in this path.
        """
        subpath = Ltree(other).path.split('.')
        parts = self.path.split('.')
        width = len(subpath)
        # Only offsets at which the whole subpath still fits can match.
        for index in range(len(parts) - width + 1):
            if parts[index : index + width] == subpath:
                return index
        raise ValueError('subpath not found')

    def descendant_of(self, other):
        """
        is left argument a descendant of right (or equal)?

        ::

            assert Ltree('1.2.3.4.5').descendant_of('1.2.3')
        """
        # Take as many leading labels of self as ``other`` has and compare.
        subpath = self[: len(Ltree(other))]
        return subpath == other

    def ancestor_of(self, other):
        """
        is left argument an ancestor of right (or equal)?

        ::

            assert Ltree('1.2.3').ancestor_of('1.2.3.4.5')
        """
        # Take as many leading labels of ``other`` as self has and compare.
        subpath = Ltree(other)[: len(self)]
        return subpath == self

    def __getitem__(self, key):
        """
        Return the label at integer ``key``, or the sub-path selected by a
        slice, as a new Ltree.

        An out-of-range integer raises ``IndexError``; a slice selecting no
        labels raises ``ValueError`` because the empty path is not valid.
        """
        if isinstance(key, int):
            return Ltree(self.path.split('.')[key])
        elif isinstance(key, slice):
            return Ltree('.'.join(self.path.split('.')[key]))
        raise TypeError(f'Ltree indices must be integers, not {key.__class__.__name__}')

    def lca(self, *others):
        """
        Lowest common ancestor, i.e., longest common prefix of paths.

        The result is always a proper prefix of every path involved, so the
        LCA of a path and one of its descendants is the path's parent.

        ::

            assert Ltree('1.2.3.4.5').lca('1.2.3', '1.2.3.4', '1.2.3') == '1.2'

        :param others: strings or Ltrees; strings are validated first.
        :returns: an Ltree, or ``None`` when the paths share no proper common
            prefix or when no ``others`` are given.
        """
        other_parts = [Ltree(other).path.split('.') for other in others]
        parts = self.path.split('.')
        for index, element in enumerate(parts):
            # Stop at the first differing label, or at the last label of any
            # path (the ancestor must be proper). ``other[index]`` cannot go
            # out of range: the length test on the previous iteration would
            # already have ended the loop.
            if any(
                other[index] != element
                or len(other) <= index + 1
                or len(parts) == index + 1
                for other in other_parts
            ):
                if index == 0:
                    return None
                return Ltree('.'.join(parts[0:index]))
        # Only reached when ``others`` is empty (any() over nothing is False).
        return None

    def __add__(self, other):
        return Ltree(self.path + '.' + Ltree(other).path)

    def __radd__(self, other):
        return Ltree(other) + self

    def __eq__(self, other):
        if isinstance(other, Ltree):
            return self.path == other.path
        elif isinstance(other, str):
            return self.path == other
        else:
            return NotImplemented

    def __hash__(self):
        # Equal to hash(str) of the path, consistent with __eq__ accepting str.
        return hash(self.path)

    def __ne__(self, other):
        # Propagate NotImplemented so Python can try the other operand's
        # __ne__ instead of negating a fallback identity comparison.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __repr__(self):
        return f'{self.__class__.__name__}({self.path!r})'

    def __unicode__(self):
        return self.path

    def __contains__(self, label):
        """Return True if ``label`` is one of this path's labels."""
        return label in self.path.split('.')

    @staticmethod
    def _labels(value):
        """
        Return ``value`` split into its labels, or ``None`` when ``value`` is
        neither an Ltree nor a string and so cannot be ordered against one.
        """
        if isinstance(value, Ltree):
            return value.path.split('.')
        if isinstance(value, str):
            return value.split('.')
        return None

    # Ordering compares label lists rather than raw path strings, intended to
    # match PostgreSQL's label-by-label ltree ordering. The two differ only
    # for hyphens: '-' sorts before '.', so a raw string compare puts 'a-b'
    # before 'a.b', whereas label-wise 'a' < 'a-b' puts 'a.b' first.
    # Unsupported operands yield NotImplemented so Python can try the
    # reflected operation or raise TypeError (instead of AttributeError).

    def __gt__(self, other):
        other_labels = self._labels(other)
        if other_labels is None:
            return NotImplemented
        return self._labels(self) > other_labels

    def __lt__(self, other):
        other_labels = self._labels(other)
        if other_labels is None:
            return NotImplemented
        return self._labels(self) < other_labels

    def __ge__(self, other):
        other_labels = self._labels(other)
        if other_labels is None:
            return NotImplemented
        return self._labels(self) >= other_labels

    def __le__(self, other):
        other_labels = self._labels(other)
        if other_labels is None:
            return NotImplemented
        return self._labels(self) <= other_labels