Utility bot for the rhythm game D4DJ. (Note that some dependencies are not public.)
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
miyu-bot/miyu_bot/commands/common/fuzzy_matching.py

210 lines
7.5 KiB

4 years ago
import logging
import math
4 years ago
import re
import timeit
import datetime
4 years ago
from dataclasses import dataclass, field
from typing import Dict, Tuple, List, Optional, Iterable
4 years ago
import pykakasi
class FuzzyFilteredMap:
    """Dictionary-like container whose lookups are fuzzy and filter-aware.

    Keys are normalized with ``romanize()`` on every access. Lookups score each
    stored key against the (romanized) query with ``matcher`` and only consider
    items whose value passes ``filter_function``; a score <= 0 counts as a match
    and lower scores are better.

    When ``additive_only_filter`` is true, the passing/failing item lists are
    cached and only rebuilt when the map changes or when a previously rejected
    value starts passing the filter. A filter that later rejects a previously
    accepted value is therefore not noticed in that mode.
    """

    def __init__(self, filter_function=None, matcher=None, additive_only_filter=True):
        # Default filter accepts every value.
        self.filter = filter_function or (lambda n: True)
        self.matcher = matcher or FuzzyMatcher()
        # romanized key -> value
        self._values = {}
        # Cached (key, value) pairs that pass / fail the filter; rebuilt by _update_items().
        self._filtered_items = []
        self._filtered_out_items = []
        # Longest query length accepted by lookups (longest stored key plus 10% slack).
        self.max_length = 0
        self.logger = logging.getLogger(__name__)
        # True whenever _values changed since the caches were last rebuilt.
        self._stale = True
        self.additive_only_filter = additive_only_filter

    @property
    def filtered_items(self):
        """Return the (key, value) pairs whose value currently passes the filter."""
        if not self.additive_only_filter:
            # The filter may shrink as well as grow, so recompute every time.
            return [item for item in self._values.items() if self.filter(item[1])]
        if self._needs_update:
            self._update_items()
        return self._filtered_items

    @property
    def _needs_update(self):
        # Under an additive-only filter, the cache is only outdated when the map
        # changed or some previously rejected value now passes the filter.
        return self._stale or any(self.filter(item[1]) for item in self._filtered_out_items)

    def _update_items(self):
        """Rebuild the passing/failing item caches, calling the filter once per item."""
        passed = []
        rejected = []
        for item in self._values.items():
            (passed if self.filter(item[1]) else rejected).append(item)
        self._filtered_items = passed
        self._filtered_out_items = rejected
        self._stale = False

    def values(self):
        """Return a live view over the values that pass the filter."""
        return FuzzyDictValuesView(self)

    def has_exact(self, key):
        """Return whether ``romanize(key)`` is stored exactly (the filter is not applied)."""
        return romanize(key) in self._values

    def __delitem__(self, key):
        del self._values[romanize(key)]
        self._stale = True

    def __setitem__(self, key, value):
        k = romanize(key)
        self._values[k] = value
        # Allow 10% slack so queries slightly longer than any stored key are still scored.
        self.max_length = max(self.max_length, math.ceil(len(k) * 1.1))
        self.matcher.set_max_length(self.max_length)
        self._stale = True

    def _matches(self, key):
        """Yield ``(score, (stored_key, value))`` for filtered items scoring <= 0.

        ``key`` must already be romanized.
        """
        matcher = self.matcher
        for item in self.filtered_items:
            score = matcher.score(key, item[0])
            if score <= 0:
                yield score, item

    def __getitem__(self, key):
        """Return the value whose key best matches ``key``, or None.

        None is returned when the romanized key exceeds ``max_length`` or when
        no filtered item scores <= 0.
        """
        start_time = timeit.default_timer()
        key = romanize(key)
        if len(key) > self.max_length:
            self.logger.debug(f'Rejected key "{key}" due to length.')
            return None
        # Ties on score fall back to comparing stored keys, which are unique.
        best = min(self._matches(key), default=None)
        if best is None:
            self.logger.info(f'Found no results for key "{key}" in time {timeit.default_timer() - start_time}.')
            return None
        self.logger.info(f'Found key "{key}" in time {timeit.default_timer() - start_time}.')
        return best[1][1]

    def get_sorted(self, key: str):
        """Return the values of all matching items, best match first.

        Returns an empty list when the romanized key exceeds ``max_length``.
        """
        start_time = timeit.default_timer()
        # Romanize before the length check so it uses the same measure as
        # __getitem__ and max_length (which is computed from romanized keys).
        key = romanize(key)
        if len(key) > self.max_length:
            self.logger.debug(f'Rejected key "{key}" due to length.')
            return []
        values = [item[1] for _, item in sorted(self._matches(key))]
        self.logger.info(f'Searched key "{key}" in time {timeit.default_timer() - start_time}.')
        return values
4 years ago
class FuzzyDictValuesView:
    """Read-only, live view of the values in a FuzzyFilteredMap that pass its filter."""

    def __init__(self, map: FuzzyFilteredMap):
        self._map = map

    def __contains__(self, item):
        # An item is "in" the view only if it is stored AND passes the filter.
        owner = self._map
        if item not in owner._values.values():
            return False
        return owner.filter(item)

    def __iter__(self):
        for _stored_key, value in self._map.filtered_items:
            yield value
4 years ago
def _default_special_substitution_weights() -> Dict[Tuple[str, str], float]:
    """Build a fresh table of cheap (source_char, target_char) substitutions.

    Lookups in FuzzyMatcher use (query char, stored char) order, so each entry
    is one-directional.
    """
    return {
        ('v', 'b'): 0.0,
        ('l', 'r'): 0.0,
        ('c', 'k'): 0.0,
        ('y', 'i'): 0.4,
    }


@dataclass
class FuzzyMatchConfig:
    """Tunable weights for FuzzyMatcher's scoring (lower scores are better)."""

    # Constant added to every score.
    base_score: float = 0.0
    # Cost of consuming a target char without a source char.
    insertion_weight: float = 0.001
    # Cost of consuming a source char without a target char.
    deletion_weight: float = 1.0
    # Cost of a mismatched char pair not listed in special_substitution_weights.
    default_substitution_weight: float = 1.0
    # Contribution of an exactly matching char pair (negative rewards the match).
    match_weight: float = -0.2
    # Per-pair substitution costs overriding default_substitution_weight.
    special_substitution_weights: Dict[Tuple[str, str], float] = field(
        default_factory=_default_special_substitution_weights)
    # Per matching char when the query lines up with a single word of the target.
    word_match_weight: float = -0.2
    # Per matching char when comparing query and target with spaces removed.
    whole_match_weight: float = -0.25
    # Per matching char when comparing the query with the target's initials.
    acronym_match_weight: float = -0.3
4 years ago
4 years ago
class FuzzyMatcher:
    """Weighted edit-distance scorer used for fuzzy key lookup.

    Scores are "lower is better". ``score`` returns 1 as a rejection value when
    the target is empty/blank or when the score can no longer reach
    ``threshold``. A DP table can be preallocated with ``set_max_length`` and is
    then reused across calls.
    """

    def __init__(self, config: FuzzyMatchConfig = None):
        self.config = config or FuzzyMatchConfig()
        # Reusable DP table; None means score() allocates a fresh one per call.
        self.array: Optional[List[List[float]]] = None

    def _build_table(self, n_src: int, n_tgt: int) -> List[List[float]]:
        """Return an (n_src + 1) x (n_tgt + 1) DP table with its borders filled.

        Column 0 holds cumulative deletion costs, row 0 cumulative insertion costs.
        """
        table = [[0] * (n_tgt + 1) for _ in range(n_src + 1)]
        for i in range(n_src + 1):
            table[i][0] = i * self.config.deletion_weight
        for j in range(n_tgt + 1):
            table[0][j] = j * self.config.insertion_weight
        return table

    def set_max_length(self, length: int):
        """Preallocate a square DP table for strings up to ``length`` chars.

        A falsy ``length`` drops the table, so score() allocates one per call.
        """
        self.array = self._build_table(length, length) if length else None

    @staticmethod
    def _count_matches(x: str, y: str) -> int:
        """Count positions where x and y agree, up to the shorter length."""
        return sum(p == q for p, q in zip(x, y))

    def score(self, source: str, target: str, threshold=0.0):
        """Score how well ``source`` (the query) matches ``target``; lower is better.

        Args:
            source: Query string.
            target: Stored string to compare against.
            threshold: Scores above this are reported as the rejection value 1.

        Returns:
            The weighted edit distance plus the best word/whole/acronym bonus and
            ``base_score``, or 1 if the target is empty/blank or the threshold is
            provably exceeded.
        """
        if not target:
            return 1
        words = target.split()
        if not words:
            # Whitespace-only target: previously made max() below raise ValueError.
            return 1

        l_src = len(source)
        l_tgt = len(target)

        config = self.config
        base_score = config.base_score
        insertion_weight = config.insertion_weight
        deletion_weight = config.deletion_weight
        default_substitution_weight = config.default_substitution_weight
        match_weight = config.match_weight
        special_substitution_weights = config.special_substitution_weights
        word_match_weight = config.word_match_weight
        whole_match_weight = config.whole_match_weight
        acronym_match_weight = config.acronym_match_weight

        a = self.array
        if not a or len(a) <= l_src or len(a[0]) <= l_tgt:
            # No preallocated table, or it is too small for these inputs.
            # Borders use the configured weights, matching set_max_length().
            a = self._build_table(l_src, l_tgt)

        count = self._count_matches
        acronym = ''.join(w[0] for w in words)
        # The weights are negative, so min() picks the largest bonus.
        word_bonus = min(
            word_match_weight * max(count(source, w) for w in words),
            word_match_weight * max(count(source, w[0] + strip_vowels(w[1:])) for w in words),
            whole_match_weight * count(strip_spaces(source), strip_spaces(target)),
            acronym_match_weight * count(source, acronym),
        )
        # Compare raw DP values directly against the threshold from here on.
        threshold -= word_bonus + base_score

        for i_src in range(1, l_src + 1):
            src_char = source[i_src - 1]
            prev_row = a[i_src - 1]
            row = a[i_src]
            for i_tgt in range(1, l_tgt + 1):
                tgt_char = target[i_tgt - 1]
                if src_char != tgt_char:
                    step = special_substitution_weights.get((src_char, tgt_char), default_substitution_weight)
                else:
                    step = match_weight
                row[i_tgt] = min(prev_row[i_tgt - 1] + step,
                                 prev_row[i_tgt] + deletion_weight,
                                 row[i_tgt - 1] + insertion_weight)
            # There are l_src - i_src source chars remaining. Each one can at best
            # replace an insertion with a match, so this is the largest decrease
            # still possible.
            # NOTE(review): the bound assumes no substitution weight is below match_weight.
            max_additional_score = (l_src - i_src) * (match_weight - insertion_weight)
            if (row[l_tgt] + max_additional_score > threshold
                    and row[l_tgt - 1] + max_additional_score > threshold):
                return 1
        return a[l_src][l_tgt] + word_bonus + base_score
4 years ago
def strip_spaces(s):
    """Return ``s`` with every ASCII space character removed (other whitespace is kept)."""
    return s.replace(' ', '')
def strip_vowels(s):
return re.sub('[aeoiu]', '', s)
4 years ago
# Shared pykakasi converter, created on first use by _get_kakasi().
_kakasi = None


def _get_kakasi():
    """Return the shared pykakasi converter, constructing it on first call."""
    global _kakasi
    if _kakasi is None:
        _kakasi = pykakasi.kakasi()
    return _kakasi


def romanize(s: str) -> str:
    """Normalize ``s`` into lowercase, space-separated Hepburn romanization.

    Steps: coerce to str, drop apostrophes, treat '・' as a word separator,
    isolate runs of Latin letters and of digits as separate words, convert with
    pykakasi's Hepburn output, drop characters outside [a-zA-Z0-9_ ], and
    collapse whitespace to single spaces.
    """
    # Reuse one converter instead of constructing a new one per call; this
    # function runs on every map insert and lookup.
    kks = _get_kakasi()
    s = str(s)
    s = re.sub('[\']', '', s)
    s = re.sub('[・]', ' ', s)
    s = re.sub('[A-Za-z]+', lambda ele: f' {ele[0]} ', s)
    s = re.sub('[0-9]+', lambda ele: f' {ele[0]} ', s)
    s = ' '.join(c['hepburn'].strip().lower() for c in kks.convert(s))
    s = re.sub(r'[^a-zA-Z0-9_ ]+', '', s)
    return ' '.join(s.split())