import logging
import math
import re
import timeit
import datetime
from dataclasses import dataclass, field
from typing import Dict, Tuple, List, Optional, Iterable

import pykakasi


class FuzzyFilteredMap:
    """Dict-like container that romanizes its keys and resolves lookups fuzzily.

    Values can be hidden by ``filter_function``. With ``additive_only_filter``
    the filtered view is cached and only refreshed when the map changes or a
    previously hidden value starts passing the filter.
    """

    def __init__(self, filter_function=None, matcher=None, additive_only_filter=True):
        self.filter = filter_function or (lambda n: True)
        self.matcher = matcher or FuzzyMatcher()
        self._values = {}
        self.length_cutoff = 0
        self.logger = logging.getLogger(__name__)
        self._stale = True
        self.additive_only_filter = additive_only_filter

    @property
    def filtered_items(self):
        if not self.additive_only_filter:
            return [item for item in self._values.items() if self.filter(item[1])]
        if self._needs_update:
            self._update_items()
        return self._filtered_items

    @property
    def _needs_update(self):
        return self._stale or any(self.filter(item[1]) for item in self._filtered_out_items)

    def _update_items(self):
        self._filtered_items = [item for item in self._values.items() if self.filter(item[1])]
        self._filtered_out_items = [item for item in self._values.items() if not self.filter(item[1])]
        self._stale = False

    def values(self):
        return FuzzyDictValuesView(self)

    def has_exact(self, key):
        return romanize(key) in self._values

    def has_exact_unprocessed(self, key):
        return key in self._values

    def __delitem__(self, key):
        k = romanize(key)
        self._values.__delitem__(k)
        self._stale = True

    def __setitem__(self, key, value):
        self.set_unprocessed(romanize(key), value)

    def set_unprocessed(self, key, value):
        self._values[key] = value
        # Queries longer than ~1.1x the longest stored key are rejected outright.
        new_cutoff = math.ceil(len(key) * 1.1)
        if new_cutoff > self.length_cutoff:
            self.length_cutoff = new_cutoff
            self.matcher.set_max_length(new_cutoff)
        self._stale = True

    def __getitem__(self, key):
        start_time = timeit.default_timer()
        key = romanize(key)
        if len(key) > self.length_cutoff:
            self.logger.debug(f'Rejected key "{key}" due to length.')
            return None
        try:
            matcher = self.matcher
            # Keep only items scoring <= 0 and return the value of the best (lowest-scoring) one.
            result = min((score, item)
                         for score, item in ((matcher.score(key, item[0]), item)
                                             for item in self.filtered_items)
                         if score <= 0)[1][1]
            self.logger.info(f'Found key "{key}" in time {timeit.default_timer() - start_time}.')
            return result
        except ValueError:
            self.logger.info(f'Found no results for key "{key}" in time {timeit.default_timer() - start_time}.')
            return None

    def get_sorted(self, key: str):
        start_time = timeit.default_timer()
        key = romanize(key)
        if len(key) > self.length_cutoff:
            self.logger.debug(f'Rejected key "{key}" due to length.')
            return []
        values = [item[1] for score, item in sorted(
            (self.matcher.score(key, item[0]), item) for item in self.filtered_items)
            if score <= 0]
        self.logger.info(f'Searched key "{key}" in time {timeit.default_timer() - start_time}.')
        return values


class FuzzyDictValuesView:
    def __init__(self, map: FuzzyFilteredMap):
        self._map = map

    def __contains__(self, item):
        return item in self._map._values.values() and self._map.filter(item)

    def __iter__(self):
        yield from (v for _, v in self._map.filtered_items)


@dataclass
class FuzzyMatchConfig:
    """Weights for FuzzyMatcher. Negative weights reward matches; a total
    score <= 0 is what FuzzyFilteredMap treats as a hit."""
    base_score: float = 0.0
    insertion_weight: float = 0.001
    deletion_weight: float = 1.0
    default_substitution_weight: float = 1.0
    match_weight: float = -0.2
    special_substitution_weights: Dict[Tuple[str, str], float] = field(default_factory=lambda: {
        ('v', 'b'): 0.0,
        ('l', 'r'): 0.0,
        ('c', 'k'): 0.0,
        ('y', 'i'): 0.4,
    })
    word_match_weight: float = -0.2
    whole_match_weight: float = -0.25
    acronym_match_weight: float = -0.3


class FuzzyMatcher:
    """Weighted Levenshtein-style scorer with word, whole-string and acronym
    bonuses. Lower scores are better."""

    def __init__(self, config: FuzzyMatchConfig = None):
        self.config = config or FuzzyMatchConfig()
        self.array: Optional[List[List[float]]] = None
    def set_max_length(self, length: int):
        # Preallocate the DP table, including its first row/column, for keys up to ``length``.
        if not length:
            self.array = None
        else:
            self.array = [[0] * (length + 1) for _ in range(length + 1)]
            for i in range(length + 1):
                self.array[i][0] = i * self.config.deletion_weight
                self.array[0][i] = i * self.config.insertion_weight

    def score(self, source: str, target: str, threshold=0.0):
        """Return a weighted edit-distance score for matching ``source`` against
        ``target``; lower is better. Returns 1 early when the score provably
        cannot reach ``threshold``."""
        if not target:
            return 1
        l_src = len(source)
        l_tgt = len(target)
        a = self.array
        config = self.config
        base_score = config.base_score
        insertion_weight = config.insertion_weight
        deletion_weight = config.deletion_weight
        default_substitution_weight = config.default_substitution_weight
        match_weight = config.match_weight
        special_substitution_weights = config.special_substitution_weights
        word_match_weight = config.word_match_weight
        whole_match_weight = config.whole_match_weight
        acronym_match_weight = config.acronym_match_weight

        if not a:
            # No preallocated table: build one just large enough for this pair.
            a = [[0] * (l_tgt + 1) for _ in range(l_src + 1)]
            for i in range(l_src + 1):
                a[i][0] = i * deletion_weight
            for i in range(l_tgt + 1):
                a[0][i] = i * insertion_weight

        # Bonus for the best of: positionwise character matches against a single
        # word, against a vowel-stripped word, against the whole space-stripped
        # target, or against the target's acronym.
        words = target.split()
        word_bonus = min(
            word_match_weight * max(sum(a == b for a, b in zip(source, w)) for w in words),
            word_match_weight * max(sum(a == b for a, b in zip(source, w[0] + strip_vowels(w[1:]))) for w in words),
            whole_match_weight * sum(a == b for a, b in zip(strip_spaces(source), strip_spaces(target))),
            acronym_match_weight * sum(a == b for a, b in zip(source, ''.join(w[0] for w in words))))
        threshold -= word_bonus + base_score

        for i_src in range(1, l_src + 1):
            for i_tgt in range(1, l_tgt + 1):
                a[i_src][i_tgt] = min(
                    a[i_src - 1][i_tgt - 1] + ((special_substitution_weights.get(
                        (source[i_src - 1], target[i_tgt - 1]), default_substitution_weight))
                        if source[i_src - 1] != target[i_tgt - 1] else match_weight),
                    a[i_src - 1][i_tgt] + deletion_weight,
                    a[i_src][i_tgt - 1] + insertion_weight)

            # There are l_src - i_src source characters remaining. In the best case
            # each match removes one insertion weight and adds one match weight, so
            # this is the largest additional decrease the remaining rows can produce.
            max_additional_score = (l_src - i_src) * (match_weight - insertion_weight)
            if ((a[i_src][l_tgt] + max_additional_score) > threshold
                    and (a[i_src][l_tgt - 1] + max_additional_score) > threshold):
                return 1

        return a[l_src][l_tgt] + word_bonus + base_score


def strip_spaces(s):
    return re.sub(' ', '', s)


def strip_vowels(s):
    return re.sub('[aeoiu]', '', s)


_kks = pykakasi.kakasi()


def romanize(s: str) -> str:
    """Normalize ``s`` to lower-case Hepburn romaji, keeping latin-letter and
    digit runs as separate words."""
    s = str(s)
    s = re.sub('[\'・]', '', s)
    s = re.sub('[A-Za-z]+', lambda ele: f' {ele[0]} ', s)
    s = re.sub('[0-9]+', lambda ele: f' {ele[0]} ', s)
    s = ' '.join(c['hepburn'].strip().lower() for c in _kks.convert(s))
    s = re.sub(r'[^a-zA-Z0-9_ ]+', '', s)
    return ' '.join(s.split())
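

# A minimal usage sketch (illustrative only, not part of the module): the titles
# and values below are made-up examples, and the exact romanization of the
# Japanese key depends on the installed pykakasi dictionaries.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    titles = FuzzyFilteredMap()
    titles['Cowboy Bebop'] = 'cowboy-bebop'       # stored under its romanized key
    titles['カウボーイビバップ'] = 'cowboy-bebop'  # Japanese keys are romanized too

    print(titles.has_exact('cowboy bebop'))   # exact romanized lookup -> True
    print(titles['cowboy bebp'])              # fuzzy lookup tolerates the typo
    print(titles.get_sorted('bebop'))         # all values scoring <= 0, best first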