class FuzzyMatcher:
    """Weighted edit-distance scorer used to pick the closest of several strings.

    Scores are "lower is better": ``closest_match`` only accepts candidates
    whose score is ``<= 0``.  All weights are read from ``self.config``.

    ``self.array`` optionally holds a preallocated dynamic-programming table
    (see ``set_max_length``).  ``score`` writes into that table in place, so
    its interior cells are scratch space that is overwritten on every call.
    """

    # Compiled once; used to drop vowels from everything but a word's first
    # letter when computing the abbreviated-word bonus.
    _VOWELS = re.compile('[aeoiu]')

    def __init__(self, config: Optional['FuzzyMatchConfig'] = None):
        """Create a matcher, using a default ``FuzzyMatchConfig`` if none is given."""
        self.config = config or FuzzyMatchConfig()
        self.array: Optional[List[List[float]]] = None

    def set_max_length(self, length: int):
        """Preallocate a ``(length + 1) x (length + 1)`` table for ``score``.

        Row 0 and column 0 are filled with the cumulative insertion and
        deletion costs; ``score`` never rewrites them.  A falsy ``length``
        discards the table so ``score`` allocates one per call instead.
        """
        if not length:
            self.array = None
            return

        size = length + 1
        deletion_weight = self.config.deletion_weight
        insertion_weight = self.config.insertion_weight

        table = [[0] * size for _ in range(size)]
        for i in range(size):
            table[i][0] = i * deletion_weight
            table[0][i] = i * insertion_weight
        self.array = table

    def closest_match(self, source: str, targets: Iterable[str]) -> Optional[str]:
        """Return the target with the lowest score, or ``None`` if none scores ``<= 0``.

        The best score found so far is passed to ``score`` as the threshold so
        hopeless candidates can be abandoned early.  On equal scores the later
        target wins.
        """
        threshold = 0
        closest = None
        for target in targets:
            score = self.score(source, target, threshold)
            # Compare against the running best, not against 0: score() does
            # not always bail out for a candidate that is worse than the
            # threshold, so a worse (but still non-positive) score must not
            # replace the current best match.
            if score <= threshold:
                threshold = score
                closest = target
        return closest

    @staticmethod
    def _matching_positions(left: str, right: str) -> int:
        """Count the positions at which two strings hold the same character."""
        return sum(x == y for x, y in zip(left, right))

    def score(self, source: str, target: str, threshold=0.0):
        """Score how well ``source`` matches ``target`` (lower is better).

        The result is the weighted edit distance between the two strings plus
        ``config.base_score`` plus a word bonus.  The word bonus is the
        smallest of three products: the word-match weight times the best
        position-wise agreement of ``source`` with any whole word of
        ``target``; the same with every vowel after a word's first letter
        removed; and the acronym weight times the agreement of ``source``
        with the words' initials.

        :param source: the string typed by the user.
        :param target: the candidate string to compare against.
        :param threshold: scores above this are of no interest to the caller.
        :return: the score, or ``1`` as soon as the remaining source
            characters can no longer bring the score down to ``threshold``.
        """
        config = self.config
        base_score = config.base_score
        insertion_weight = config.insertion_weight
        deletion_weight = config.deletion_weight
        default_substitution_weight = config.default_substitution_weight
        match_weight = config.match_weight
        special_substitution_weights = config.special_substitution_weights
        word_match_weight = config.word_match_weight
        acronym_match_weight = config.acronym_match_weight

        l_src = len(source)
        l_tgt = len(target)

        if not l_tgt:
            # Nothing to align against: every source character is deleted and
            # there are no words to earn a bonus from.
            return l_src * deletion_weight + base_score

        a = self.array
        if a is None or l_src >= len(a) or l_tgt >= len(a[0]):
            # No preallocated table, or it is too small for these inputs:
            # build a local one with the same borders set_max_length() uses.
            a = [[0] * (l_tgt + 1) for _ in range(l_src + 1)]
            for i in range(l_src + 1):
                a[i][0] = i * deletion_weight
            for i in range(l_tgt + 1):
                a[0][i] = i * insertion_weight

        words = target.split()
        hits = self._matching_positions
        strip_vowels = self._VOWELS.sub
        # default=0 covers a target made only of whitespace (no words).
        word_bonus = min(
            word_match_weight * max((hits(source, w) for w in words), default=0),
            word_match_weight * max(
                (hits(source, w[0] + strip_vowels('', w[1:])) for w in words),
                default=0),
            acronym_match_weight * hits(source, ''.join(w[0] for w in words)),
        )

        # Express the threshold in terms of the raw table values so the
        # early-exit test below needs no per-row additions.
        threshold -= word_bonus + base_score

        # Each remaining source character can at best turn an insertion into
        # a match, changing the total by this much.
        gain_per_char = match_weight - insertion_weight

        for i_src in range(1, l_src + 1):
            src_char = source[i_src - 1]
            prev_row = a[i_src - 1]
            row = a[i_src]
            for i_tgt in range(1, l_tgt + 1):
                tgt_char = target[i_tgt - 1]
                if src_char != tgt_char:
                    substitution = special_substitution_weights.get(
                        (src_char, tgt_char), default_substitution_weight)
                else:
                    substitution = match_weight
                row[i_tgt] = min(prev_row[i_tgt - 1] + substitution,
                                 prev_row[i_tgt] + deletion_weight,
                                 row[i_tgt - 1] + insertion_weight)

            # There are l_src - i_src source chars remaining; this is the
            # largest change they can still make to the score.
            max_additional_score = (l_src - i_src) * gain_per_char
            if (row[l_tgt] + max_additional_score > threshold and
                    row[l_tgt - 1] + max_additional_score > threshold):
                return 1

        return a[l_src][l_tgt] + word_bonus + base_score