1import sqlalchemy as sa
2
3
def inspect_type(mixed):
    """Return the SQLAlchemy column type behind *mixed*.

    Accepts an instrumented (mapped) attribute, a ``ColumnProperty`` or a
    ``Column``. For attributes and properties only the first underlying
    column is consulted. Any other input yields ``None``.
    """
    if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute):
        column = mixed.property.columns[0]
    elif isinstance(mixed, sa.orm.ColumnProperty):
        column = mixed.columns[0]
    elif isinstance(mixed, sa.Column):
        column = mixed
    else:
        return None
    return column.type
11
12
def is_case_insensitive(mixed):
    """Return whether *mixed* is backed by ``CaseInsensitiveComparator``.

    *mixed* may be anything ``inspect_type`` understands (an instrumented
    attribute, a ``ColumnProperty`` or a ``Column``). Any other value,
    including plain Python values, yields ``False``.

    The resolved column type is checked first for a ``comparator``
    instance. If it has no such attribute, its ``comparator_factory`` class
    is checked instead.
    """
    try:
        column_type = inspect_type(mixed)
    except AttributeError:
        # inspect_type can raise this when an instrumented attribute's
        # property has no ``columns`` (presumably e.g. a relationship).
        return False

    try:
        comparator = column_type.comparator
    except AttributeError:
        # inspect_type returned None, or a type without a ``comparator``
        # instance: fall back to the comparator class it would build.
        factory = getattr(column_type, 'comparator_factory', None)
        # Guard issubclass so a non-class factory yields False instead of
        # raising TypeError.
        return isinstance(factory, type) and issubclass(
            factory, CaseInsensitiveComparator
        )
    return isinstance(comparator, CaseInsensitiveComparator)
23
24
class CaseInsensitiveComparator(sa.Unicode.Comparator):
    """Comparator that lowercases the right-hand operand of string operations.

    NOTE(review): presumably used by column types whose stored values are
    already lowercase, so lowercasing only the operand is enough. Confirm
    against the types that set this as their ``comparator_factory``.

    ``in_``/``notin_``/``not_in`` are overridden here. The remaining
    operators named in the module-level ``string_operator_funcs`` list are
    attached to this class via :meth:`lowercase_arg` after the class body.
    """

    @classmethod
    def lowercase_arg(cls, func):
        """Build a comparator method that lowercases its operand.

        :param func: Name of the ``sa.Unicode.Comparator`` method to wrap,
            e.g. ``'__eq__'`` or ``'like'``.
        :returns: A function meant to be attached to this class. When
            called, it wraps ``other`` in ``sa.func.lower`` unless ``other``
            is ``None`` or is itself case-insensitive (per
            ``is_case_insensitive``). It then delegates to the base
            comparator's method of the same name.
        """

        def operation(self, other, **kwargs):
            # Looked up at call time, so a name missing from the installed
            # SQLAlchemy version only fails if it is actually used.
            operator = getattr(sa.Unicode.Comparator, func)
            # Leave None untouched so SQLAlchemy's own None handling
            # (e.g. ``IS NULL``) still applies.
            if other is not None and not is_case_insensitive(other):
                other = sa.func.lower(other)
            return operator(self, other, **kwargs)

        # Give the generated method a meaningful name for debugging.
        operation.__name__ = func
        return operation

    @staticmethod
    def _lower_sequence(other):
        # Lowercase every element of a literal list/tuple. Any other operand
        # (e.g. a subquery or bound parameter) is passed through unchanged.
        if isinstance(other, (list, tuple)):
            return [sa.func.lower(value) for value in other]
        return other

    def in_(self, other):
        """``IN`` with every element of a list/tuple operand lowercased."""
        return sa.Unicode.Comparator.in_(self, self._lower_sequence(other))

    def notin_(self, other):
        """``NOT IN`` with every element of a list/tuple operand lowercased."""
        return sa.Unicode.Comparator.notin_(self, self._lower_sequence(other))

    # SQLAlchemy 1.4 renamed ``notin_`` to ``not_in`` and kept the old name
    # only as an alias, so ``not_in`` must be overridden as well.
    not_in = notin_
47
48
# Names of ``sa.Unicode.Comparator`` methods whose right-hand operand is
# lowercased by ``CaseInsensitiveComparator``. ``not_like``/``not_ilike`` are
# the SQLAlchemy 1.4+ spellings of ``notlike``/``notilike``. There the old
# names survive only as aliases, so both spellings are listed. On older
# SQLAlchemy the new names fail with AttributeError only if actually called.
string_operator_funcs = [
    '__eq__',
    '__ne__',
    '__lt__',
    '__le__',
    '__gt__',
    '__ge__',
    'concat',
    'contains',
    'ilike',
    'like',
    'notlike',
    'notilike',
    'startswith',
    'endswith',
    'not_like',
    'not_ilike',
]


def _install_lowercase_operators():
    # Attach a lowercasing wrapper for each name above. This is done inside a
    # function so the loop variable does not leak into the module namespace.
    # A module-level ``func`` string could shadow ``sqlalchemy.func`` for
    # code that star-imports this module.
    for name in string_operator_funcs:
        setattr(
            CaseInsensitiveComparator,
            name,
            CaseInsensitiveComparator.lowercase_arg(name),
        )


_install_lowercase_operators()