1from __future__ import annotations
2
3# built-in
4from collections import defaultdict
5from itertools import groupby, zip_longest
6from typing import Any, Iterator, Sequence, TypeVar
7
8# app
9from .base import Base as _Base, BaseSimilarity as _BaseSimilarity
10
11
12try:
13 # external
14 import numpy
15except ImportError:
16 numpy = None # type: ignore[assignment]
17
18
# Public API of this module: the two algorithm classes followed by the
# lowercase module-level instances of them.
__all__ = [
    'MRA',
    'Editex',
    'mra',
    'editex',
]
# Generic type variable; it is not referenced in the visible part of this module.
T = TypeVar('T')
24
25
class MRA(_BaseSimilarity):
    """Similarity rating of the Western Airlines Surname Match Rating Algorithm.

    Every input word is first reduced to its MRA code (see `_calc_mra`), then
    the codes are compared position by position and the result is derived from
    how many characters could not be matched.

    https://en.wikipedia.org/wiki/Match_rating_approach
    https://github.com/Yomguithereal/talisman/blob/master/src/metrics/mra.js
    """

    def maximum(self, *sequences: str) -> int:
        """Return the length of the longest MRA code among `sequences`."""
        return max(len(self._calc_mra(word)) for word in sequences)

    def _calc_mra(self, word: str) -> str:
        """Encode `word` as its MRA code.

        The word is upper-cased, vowels ('AEIOU') after the first character
        are dropped, runs of equal adjacent characters are collapsed, and a
        code longer than 6 characters is cut to its first 3 plus last 3
        characters. A falsy `word` is returned unchanged.
        """
        if not word:
            return word
        upper = word.upper()
        head, tail = upper[0], upper[1:]
        stripped = head + ''.join(c for c in tail if c not in 'AEIOU')
        # collapse runs of identical neighbours (like the UNIX `uniq` tool)
        code = ''.join(key for key, _ in groupby(stripped))
        if len(code) <= 6:
            return code
        return code[:3] + code[-3:]

    def __call__(self, *sequences: str) -> int:
        """Return the MRA similarity rating of `sequences`.

        Returns 0 when any input is falsy, or when the code lengths differ by
        more than the number of sequences. Otherwise returns the longest code
        length minus the longest unmatched remainder left after comparing.
        """
        if not all(sequences):
            return 0
        codes = [list(self._calc_mra(word)) for word in sequences]
        sizes = [len(code) for code in codes]
        total = len(sizes)
        longest = max(sizes)
        if abs(longest - min(sizes)) > total:
            return 0

        # one elimination pass per input sequence
        for _ in range(total):
            shortest = min(sizes)
            # keep only the aligned columns whose characters are not all equal
            mismatched = [
                column for column in zip(*codes)
                if not self._ident(*column)
            ]
            # transpose the surviving columns back into per-sequence lists
            leftovers = map(list, zip(*mismatched))
            # when nothing survived, `leftovers` is empty and every sequence
            # is paired with an empty list instead
            pairs: Iterator[tuple[Any, Any]]
            pairs = zip_longest(leftovers, codes, fillvalue=[])
            # re-attach the tail that lay beyond the shortest code
            codes = [kept + code[shortest:] for kept, code in pairs]
            sizes = [len(code) for code in codes]

        if sizes:
            return longest - max(sizes)
        return longest
74
75
class Editex(_Base):
    """Editex phonetic edit distance between two strings.

    Characters are compared case-insensitively (inputs are upper-cased).
    Replacing a character costs `match_cost` when the characters are equal,
    `group_cost` when they share one of the phonetic `groups`, and
    `mismatch_cost` otherwise. Inserting/deleting a character is priced by
    `d_cost`, which additionally charges only `group_cost` when the preceding
    character is in `ungrouped` and differs from the current one.

    https://anhaidgroup.github.io/py_stringmatching/v0.3.x/Editex.html
    http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.14.3856&rep=rep1&type=pdf
    http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.18.2138&rep=rep1&type=pdf
    https://github.com/chrislit/abydos/blob/master/abydos/distance/_editex.py
    https://habr.com/ru/post/331174/ (RUS)
    """
    # Default phonetic groups; a letter may belong to more than one group.
    groups: tuple[frozenset[str], ...] = (
        frozenset('AEIOUY'),
        frozenset('BP'),
        frozenset('CKQ'),
        frozenset('DT'),
        frozenset('LR'),
        frozenset('MN'),
        frozenset('GJ'),
        frozenset('FPV'),
        frozenset('SXZ'),
        frozenset('CSZ'),
    )
    # Letters of the alphabet that are not present in any of `groups`.
    ungrouped = frozenset('HW')

    def __init__(
        self,
        local: bool = False,
        match_cost: int = 0,
        group_cost: int = 1,
        mismatch_cost: int = 2,
        groups: tuple[frozenset[str], ...] | None = None,
        ungrouped: frozenset[str] | None = None,
        external: bool = True,
    ) -> None:
        """Configure costs and (optionally) custom letter groups.

        Args:
            local: when true, the first row and column of the distance matrix
                stay zero instead of being filled with deletion costs.
            match_cost: cost of aligning two equal characters.
            group_cost: cost of aligning two characters of the same group;
                raised to `match_cost` if given smaller.
            mismatch_cost: cost of aligning unrelated characters; raised to
                the effective `group_cost` if given smaller.
            groups: custom phonetic groups replacing the class defaults.
            ungrouped: letters outside every group; required together with
                `groups` and ignored when `groups` is not given.
            external: stored on the instance; it is not read anywhere in
                this class.

        Raises:
            ValueError: if `groups` is given without `ungrouped`.
        """
        # Ensure that match_cost <= group_cost <= mismatch_cost
        self.match_cost = match_cost
        self.group_cost = max(group_cost, self.match_cost)
        self.mismatch_cost = max(mismatch_cost, self.group_cost)
        self.local = local
        self.external = external

        if groups is not None:
            if ungrouped is None:
                raise ValueError('`ungrouped` argument required with `groups`')
            self.groups = groups
            self.ungrouped = ungrouped
        # Start from an empty frozenset so that an empty `groups` tuple (or
        # groups given as plain sets/strings) does not raise TypeError.
        self.grouped = frozenset().union(*self.groups)

    def maximum(self, *sequences: Sequence) -> int:
        """Return the upper bound: longest input length times `mismatch_cost`."""
        return max(map(len, sequences)) * self.mismatch_cost

    def r_cost(self, *elements: str) -> int:
        """Return the cost of aligning (replacing) `elements` with each other."""
        if self._ident(*elements):
            return self.match_cost
        # fast path: a letter outside every group can never share a group
        if any(element not in self.grouped for element in elements):
            return self.mismatch_cost
        for group in self.groups:
            if all(element in group for element in elements):
                return self.group_cost
        return self.mismatch_cost

    def d_cost(self, *elements: str) -> int:
        """Return the cost of inserting/deleting the last of `elements`.

        `elements` is (previous character, current character). When they
        differ and the previous one is in `ungrouped`, the cost is
        `group_cost`; otherwise it is the same as `r_cost`.
        """
        if not self._ident(*elements) and elements[0] in self.ungrouped:
            return self.group_cost
        return self.r_cost(*elements)

    def __call__(self, s1: str, s2: str) -> float:
        """Return the Editex distance between `s1` and `s2`.

        The value is capped at `maximum(s1, s2)` computed on the original
        (not upper-cased) inputs. Apart from an early answer supplied by the
        inherited `quick_answer`, the result is always a plain Python `int`,
        whether or not numpy is available.
        """
        result = self.quick_answer(s1, s2)
        if result is not None:
            return result

        # must do `upper` before getting length because some one-char lowercase glyphs
        # are represented as two chars in uppercase.
        # This might result in a distance that is greater than the maximum
        # input sequence length, though, so we save that maximum first.
        max_length = self.maximum(s1, s2)
        # the leading space provides a "previous character" for index 1
        s1 = ' ' + s1.upper()
        s2 = ' ' + s2.upper()
        len_s1 = len(s1) - 1
        len_s2 = len(s2) - 1
        d_mat: Any
        if numpy:
            d_mat = numpy.zeros((len_s1 + 1, len_s2 + 1), dtype=int)
        else:
            # sparse fallback: missing cells read as 0
            d_mat = defaultdict(lambda: defaultdict(int))

        if not self.local:
            # first column / first row: cumulative cost of deleting a prefix
            for i in range(1, len_s1 + 1):
                d_mat[i][0] = d_mat[i - 1][0] + self.d_cost(s1[i - 1], s1[i])
            for j in range(1, len_s2 + 1):
                d_mat[0][j] = d_mat[0][j - 1] + self.d_cost(s2[j - 1], s2[j])

        for i, (cs1_prev, cs1_curr) in enumerate(zip(s1, s1[1:]), start=1):
            for j, (cs2_prev, cs2_curr) in enumerate(zip(s2, s2[1:]), start=1):
                d_mat[i][j] = min(
                    d_mat[i - 1][j] + self.d_cost(cs1_prev, cs1_curr),
                    d_mat[i][j - 1] + self.d_cost(cs2_prev, cs2_curr),
                    d_mat[i - 1][j - 1] + self.r_cost(cs1_curr, cs2_curr),
                )

        # int() turns a numpy integer scalar into a plain Python int
        distance = int(d_mat[len_s1][len_s2])
        return min(distance, max_length)
176
177
# Module-level instances created with the default constructor arguments;
# both names are listed in `__all__`.
mra = MRA()
editex = Editex()