1import sqlalchemy as sa
2from sqlalchemy.orm.attributes import InstrumentedAttribute
3from sqlalchemy.util.langhelpers import symbol
4
5from .utils import str_coercible
6
7
@str_coercible
class Path:
    """A separator-delimited path string such as ``'a.b.c'``.

    The raw string is stored in ``path``; ``parts`` re-splits it on
    ``separator`` each time it is accessed.
    """

    def __init__(self, path, separator='.'):
        """Wrap *path*, which may be a string or another :class:`Path`.

        When *path* is a :class:`Path`, only its ``path`` string is copied;
        the separator always comes from this call's ``separator`` argument.
        """
        if isinstance(path, Path):
            self.path = path.path
        else:
            self.path = path
        self.separator = separator

    @property
    def parts(self):
        """Return the path split on ``separator`` as a list of strings."""
        return self.path.split(self.separator)

    def __iter__(self):
        yield from self.parts

    def __len__(self):
        return len(self.parts)

    def __repr__(self):
        # ``!r`` quotes and escapes the path, so the repr stays well-formed
        # even when the path itself contains a quote character.
        return f"{self.__class__.__name__}({self.path!r})"

    def index(self, element):
        """Return the position of *element* in ``parts``.

        Raises ValueError (from ``list.index``) when *element* is absent.
        """
        return self.parts.index(element)

    def __getitem__(self, key):
        """Index or slice the parts.

        An integer key returns a single part (a string); a slice returns a
        new instance of this class joined with the same separator.
        """
        result = self.parts[key]
        if isinstance(result, list):
            return self.__class__(
                self.separator.join(result), separator=self.separator
            )
        return result

    def __eq__(self, other):
        # Unrelated types are not comparable here; returning NotImplemented
        # lets Python fall back to its default instead of raising
        # AttributeError on ``other.path``.
        if not isinstance(other, Path):
            return NotImplemented
        return self.path == other.path and self.separator == other.separator

    def __ne__(self, other):
        return not (self == other)

    def __unicode__(self):
        return self.path
47
48
def get_attr(mixed, attr):
    """Look up *attr* on *mixed*, following instrumented attributes.

    If *mixed* is an ``InstrumentedAttribute``, the lookup is performed on
    the class of the mapper behind its property (the attribute's target
    class); otherwise *attr* is read directly from *mixed*.
    """
    owner = mixed
    if isinstance(owner, InstrumentedAttribute):
        # Hop from the attribute to the class its property's mapper maps.
        owner = owner.property.mapper.class_
    return getattr(owner, attr)
54
55
@str_coercible
class AttrPath:
    """A dotted attribute path resolved against a starting class.

    ``path`` is the :class:`Path` of attribute names and ``parts`` is the
    list of attribute objects obtained by resolving each name in turn with
    :func:`get_attr`, starting from ``class_``.
    """

    def __init__(self, class_, path):
        self.class_ = class_
        self.path = Path(path)
        self.parts = []
        last_attr = class_
        for value in self.path:
            # Each step resolves the next name relative to the previous
            # attribute (get_attr follows instrumented attributes to the
            # class their property's mapper maps).
            last_attr = get_attr(last_attr, value)
            self.parts.append(last_attr)

    def __iter__(self):
        yield from self.parts

    def __invert__(self):
        """Return the reverse path, built from backref names.

        The new path starts at the last part's ``class_`` when that part is
        a column property, or at ``parts[-1].mapper.class_`` otherwise, and
        joins each part's ``backref`` (or ``back_populates``) name in
        reverse order.

        Raises:
            Exception: if some part's property has neither a ``backref`` nor
                ``back_populates`` value.
        """

        def get_backref(part):
            prop = part.property
            backref = prop.backref or prop.back_populates
            if backref is None:
                raise Exception(
                    "Invert failed because property '%s' of class "
                    '%s has no backref.' % (prop.key, prop.parent.class_.__name__)
                )
            # A backref may be given as a (name, kwargs) tuple; only the
            # name is needed to build the path.
            if isinstance(backref, tuple):
                return backref[0]
            return backref

        last = self.parts[-1]
        if isinstance(last.property, sa.orm.ColumnProperty):
            class_ = last.class_
        else:
            class_ = last.mapper.class_

        return self.__class__(class_, '.'.join(map(get_backref, reversed(self.parts))))

    def index(self, element):
        """Return the position of the first part that *is* ``element``.

        Comparison is by identity. Unlike ``Path.index``, this returns
        ``None`` (rather than raising) when ``element`` is not found.
        """
        for index, el in enumerate(self.parts):
            if el is element:
                return index
        return None

    @property
    def direction(self):
        """Combined relationship direction of the whole path.

        MANYTOMANY if any part is many-to-many, or if the path contains both
        a MANYTOONE and a ONETOMANY hop; otherwise the first part's
        direction. Assumes every part's property exposes ``direction``
        (i.e. relationship attributes) — TODO confirm for mixed paths.
        """
        symbols = [part.property.direction for part in self.parts]
        if symbol('MANYTOMANY') in symbols:
            return symbol('MANYTOMANY')
        elif symbol('MANYTOONE') in symbols and symbol('ONETOMANY') in symbols:
            return symbol('MANYTOMANY')
        return symbols[0]

    @property
    def uselist(self):
        """True if any part's property has ``uselist`` set."""
        return any(part.property.uselist for part in self.parts)

    def __getitem__(self, key):
        """Index or slice the resolved parts.

        A non-empty slice returns a new instance rooted at ``class_`` (when
        the slice starts with the first part) or at the owning class of the
        slice's first part. Integer keys and empty slices return the raw
        list element / empty list.
        """
        result = self.parts[key]
        if isinstance(result, list) and result:
            if result[0] is self.parts[0]:
                class_ = self.class_
            else:
                class_ = result[0].parent.class_
            return self.__class__(class_, self.path[key])
        return result

    def __len__(self):
        return len(self.path)

    def __repr__(self):
        return f'{self.__class__.__name__}({self.class_.__name__}, {self.path.path!r})'

    def __eq__(self, other):
        # Returning NotImplemented for foreign types avoids AttributeError on
        # ``other.path`` / ``other.class_`` and lets Python fall back to its
        # default comparison.
        if not isinstance(other, AttrPath):
            return NotImplemented
        return self.path == other.path and self.class_ == other.class_

    def __ne__(self, other):
        return not (self == other)

    def __unicode__(self):
        return str(self.path)