Utility bot for the rhythm game D4DJ. (Note that some dependencies are not public.)
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
miyu-bot/miyu_bot/commands/common/fuzzy_matching.py

217 lines
7.6 KiB

import logging
import math
import re
import timeit
import datetime
from dataclasses import dataclass, field
from typing import Dict, Tuple, List, Optional, Iterable
import pykakasi
class FuzzyFilteredMap:
    """Dictionary-like container looked up by fuzzy matching on romanized keys.

    ``m[key] = value`` stores ``value`` under ``romanize(key)``; ``set_unprocessed``
    stores it under the key exactly as given. ``m[query]`` romanizes the query,
    scores it against every stored key whose value passes the filter, and returns
    the value of the lowest-scoring key among those scoring ``<= 0`` (``None`` if
    there is none).

    :param filter_function: predicate on stored values; values it rejects are
        hidden from lookups and from ``values()``. Defaults to accepting all.
    :param matcher: scorer exposing ``score(source, target)`` and
        ``set_max_length(length)``; defaults to a new ``FuzzyMatcher()``.
    :param additive_only_filter: if true, cache the filtered items and rebuild the
        cache only after a mutation or when a previously rejected value now passes
        the filter. This assumes the filter never starts rejecting a value it once
        accepted. If false, the filter is re-evaluated on every access.
    """

    def __init__(self, filter_function=None, matcher=None, additive_only_filter=True):
        self.filter = filter_function or (lambda n: True)
        self.matcher = matcher or FuzzyMatcher()
        self._values = {}
        # Largest ceil(1.1 * len(key)) over all keys ever stored (never lowered on
        # deletion). Longer queries are rejected without scoring, and the matcher
        # is sized to this length.
        self.length_cutoff = 0
        self.logger = logging.getLogger(__name__)
        # Accepted / rejected (key, value) caches used when additive_only_filter is
        # true. They always exist and start stale so the first access rebuilds them.
        self._filtered_items = []
        self._filtered_out_items = []
        self._stale = True
        self.additive_only_filter = additive_only_filter

    @property
    def filtered_items(self):
        """List of ``(key, value)`` pairs whose value passes the filter."""
        if not self.additive_only_filter:
            return [item for item in self._values.items() if self.filter(item[1])]
        if self._needs_update:
            self._update_items()
        return self._filtered_items

    @property
    def _needs_update(self):
        # Rebuild after any mutation. Otherwise rebuild only if a previously
        # rejected value now passes, which is the additive-only assumption.
        return self._stale or any(self.filter(value) for _, value in self._filtered_out_items)

    def _update_items(self):
        """Rebuild the accepted/rejected caches, evaluating the filter once per item."""
        accepted = []
        rejected = []
        for item in self._values.items():
            if self.filter(item[1]):
                accepted.append(item)
            else:
                rejected.append(item)
        self._filtered_items = accepted
        self._filtered_out_items = rejected
        self._stale = False

    def values(self):
        """Return a live view over the values that pass the filter."""
        return FuzzyDictValuesView(self)

    def has_exact(self, key):
        """True if ``romanize(key)`` is stored exactly (filter not applied)."""
        return romanize(key) in self._values

    def has_exact_unprocessed(self, key):
        """True if ``key`` is stored exactly as given (filter not applied)."""
        return key in self._values

    def __delitem__(self, key):
        """Delete the entry stored under ``romanize(key)``; raises KeyError if absent."""
        del self._values[romanize(key)]
        self._stale = True

    def __setitem__(self, key, value):
        """Store ``value`` under ``romanize(key)``."""
        self.set_unprocessed(romanize(key), value)

    def set_unprocessed(self, key, value):
        """Store ``value`` under ``key`` without romanizing it.

        Grows ``length_cutoff`` (and the matcher's preallocated size) when this key
        needs more room than any earlier key.
        """
        self._values[key] = value
        new_cutoff = math.ceil(len(key) * 1.1)
        if new_cutoff > self.length_cutoff:
            self.length_cutoff = new_cutoff
            self.matcher.set_max_length(new_cutoff)
        self._stale = True

    def __getitem__(self, key):
        """Return the value of the best-matching key, or ``None`` if nothing matches.

        Only keys scoring ``<= 0`` count as matches. Ties on score are broken by
        comparing the stored keys.
        """
        start_time = timeit.default_timer()
        key = romanize(key)
        if len(key) > self.length_cutoff:
            self.logger.debug(f'Rejected key "{key}" due to length.')
            return None
        score = self.matcher.score
        # An empty candidate set yields None instead of raising. Errors raised by
        # the filter or matcher are no longer mistaken for "no results".
        best = min(
            (candidate for candidate in
             ((score(key, item[0]), item) for item in self.filtered_items)
             if candidate[0] <= 0),
            default=None,
        )
        if best is None:
            self.logger.info(f'Found no results for key "{key}" in time {timeit.default_timer() - start_time}.')
            return None
        self.logger.info(f'Found key "{key}" in time {timeit.default_timer() - start_time}.')
        return best[1][1]

    def get_sorted(self, key: str):
        """Return the values of all keys scoring ``<= 0`` against ``key``, best first."""
        start_time = timeit.default_timer()
        # Romanize before the length check, as __getitem__ does. Romanization can
        # change the query's length, and the cutoff must guard the string that is
        # actually scored, since the matcher's table is sized from length_cutoff.
        key = romanize(key)
        if len(key) > self.length_cutoff:
            self.logger.debug(f'Rejected key "{key}" due to length.')
            return []
        score = self.matcher.score
        matches = sorted(
            (s, item) for s, item in
            ((score(key, item[0]), item) for item in self.filtered_items)
            if s <= 0
        )
        values = [item[1] for _, item in matches]
        self.logger.info(f'Searched key "{key}" in time {timeit.default_timer() - start_time}.')
        return values
class FuzzyDictValuesView:
    """Read-only view over the values of a FuzzyFilteredMap that pass its filter.

    The view keeps a reference to the map and reads it on every call, so it
    reflects later insertions, deletions and filter changes.
    """

    def __init__(self, map: FuzzyFilteredMap):
        # Parameter name kept as-is for keyword callers, even though it shadows the builtin.
        self._map = map

    def __contains__(self, item):
        """True if ``item`` is a stored value and the map's filter accepts it."""
        owner = self._map
        if item not in owner._values.values():
            return False
        return owner.filter(item)

    def __iter__(self):
        """Yield each accepted value, in the order of the map's filtered items."""
        for _key, value in self._map.filtered_items:
            yield value
@dataclass
class FuzzyMatchConfig:
    """Weights consumed by FuzzyMatcher.score.

    Lower scores mean better matches, so negative weights reward a feature and
    positive weights penalize it. The field order below is the positional
    constructor signature, so new fields must only be appended.
    """

    # Constant added to every final score.
    base_score: float = 0.0
    # Cost of a target character that has no counterpart in the source.
    insertion_weight: float = 0.001
    # Cost of a source character that has no counterpart in the target.
    deletion_weight: float = 1.0
    # Cost of replacing one character with a different one when no special weight applies.
    default_substitution_weight: float = 1.0
    # Score contribution when the source and target characters are equal.
    match_weight: float = -0.2
    # Substitution costs keyed by (source_char, target_char). The lookup is
    # directional, so e.g. ('b', 'v') falls back to default_substitution_weight.
    special_substitution_weights: Dict[Tuple[str, str], float] = field(
        default_factory=lambda: dict([
            (('v', 'b'), 0.0),
            (('l', 'r'), 0.0),
            (('c', 'k'), 0.0),
            (('y', 'i'), 0.4),
        ])
    )
    # Weight per positional character agreement between the source and a single
    # target word (also tried with that word's non-initial vowels stripped).
    word_match_weight: float = -0.2
    # Weight per positional character agreement with all spaces removed from both strings.
    whole_match_weight: float = -0.25
    # Weight per positional character agreement against the target's word initials.
    acronym_match_weight: float = -0.3
class FuzzyMatcher:
    """Scores a query string against a target string; lower scores are better.

    The score is a weighted edit distance (weights from FuzzyMatchConfig) plus
    the best of several word-level bonuses plus ``config.base_score``.
    """

    def __init__(self, config: Optional[FuzzyMatchConfig] = None):
        self.config = config or FuzzyMatchConfig()
        # Optional reusable DP table of size (length + 1) x (length + 1) with its
        # first row/column pre-seeded; see set_max_length. None means score()
        # builds a table per call.
        self.array: Optional[List[List[float]]] = None

    @staticmethod
    def _new_table(rows, cols, deletion_weight, insertion_weight):
        """Return a rows x cols DP table.

        The first column and first row hold the cumulative deletion and insertion
        costs. Interior cells start at 0.
        """
        table = [[0] * cols for _ in range(rows)]
        for i in range(rows):
            table[i][0] = i * deletion_weight
        for j in range(cols):
            table[0][j] = j * insertion_weight
        return table

    def set_max_length(self, length: int):
        """Preallocate a DP table for strings of up to ``length`` characters.

        A falsy ``length`` drops the table, so ``score`` allocates one per call.
        The borders are seeded once because ``score`` only writes interior cells.
        """
        if not length:
            self.array = None
        else:
            self.array = self._new_table(length + 1, length + 1,
                                         self.config.deletion_weight,
                                         self.config.insertion_weight)

    def score(self, source: str, target: str, threshold=0.0):
        """Return the match score of ``source`` (the query) against ``target`` (a key).

        Lower is better. Returns ``1`` as a rejection sentinel when ``target`` is
        empty or whitespace-only, or as soon as the partially filled table shows
        the final score must exceed ``threshold``.
        """
        words = target.split()
        if not words:
            # Empty or whitespace-only target: nothing to match, and the word
            # bonuses below would call max() on an empty word list.
            return 1
        l_src = len(source)
        l_tgt = len(target)
        config = self.config
        base_score = config.base_score
        insertion_weight = config.insertion_weight
        deletion_weight = config.deletion_weight
        default_substitution_weight = config.default_substitution_weight
        match_weight = config.match_weight
        special_substitution_weights = config.special_substitution_weights
        word_match_weight = config.word_match_weight
        whole_match_weight = config.whole_match_weight
        acronym_match_weight = config.acronym_match_weight

        table = self.array
        # Reuse the cached table only if it is large enough for these strings.
        # Otherwise build a per-call table seeded with the same configured border
        # costs that set_max_length uses.
        if not table or len(table) <= l_src or len(table[0]) <= l_tgt:
            table = self._new_table(l_src + 1, l_tgt + 1, deletion_weight, insertion_weight)

        # Each bonus is a weight times the number of positions where the query
        # agrees with some view of the target. min() keeps the lowest (best) one.
        word_bonus = min(
            word_match_weight * max(sum(x == y for x, y in zip(source, w)) for w in words),
            word_match_weight * max(sum(x == y for x, y in zip(source, w[0] + strip_vowels(w[1:])))
                                    for w in words),
            whole_match_weight * sum(x == y for x, y in zip(strip_spaces(source), strip_spaces(target))),
            acronym_match_weight * sum(x == y for x, y in zip(source, ''.join(w[0] for w in words))),
        )
        # The edit-distance part alone must stay at or below this for the total to pass.
        limit = threshold - (word_bonus + base_score)

        for i_src in range(1, l_src + 1):
            s_char = source[i_src - 1]
            prev_row = table[i_src - 1]
            row = table[i_src]
            for i_tgt in range(1, l_tgt + 1):
                t_char = target[i_tgt - 1]
                if s_char == t_char:
                    step = match_weight
                else:
                    step = special_substitution_weights.get((s_char, t_char), default_substitution_weight)
                row[i_tgt] = min(prev_row[i_tgt - 1] + step,
                                 prev_row[i_tgt] + deletion_weight,
                                 row[i_tgt - 1] + insertion_weight)
            # Early exit. l_src - i_src source characters remain. Each one can lower
            # the score by at most turning a would-be insertion into a match, i.e. by
            # (match_weight - insertion_weight). This assumes match_weight is the
            # cheapest per-character step, as with the default config.
            max_additional_score = (l_src - i_src) * (match_weight - insertion_weight)
            if (row[l_tgt] + max_additional_score > limit
                    and row[l_tgt - 1] + max_additional_score > limit):
                return 1
        return table[l_src][l_tgt] + word_bonus + base_score
def strip_spaces(s):
    """Return ``s`` with every ASCII space removed; tabs and other whitespace are kept."""
    return s.replace(' ', '')
# Translation table that deletes the lowercase ASCII vowels.
_VOWEL_DELETION_TABLE = str.maketrans('', '', 'aeiou')


def strip_vowels(s):
    """Return ``s`` without the lowercase vowels a, e, i, o and u; uppercase vowels are kept."""
    return s.translate(_VOWEL_DELETION_TABLE)
# Converter instance built once at import time and shared by every romanize() call.
_kks = pykakasi.kakasi()


def romanize(s: str) -> str:
    """Normalize ``s`` into lowercase, single-space-separated Hepburn romanization.

    Non-string input is converted with ``str()`` first. Then:

    1. Apostrophes and '・' are deleted.
    2. Every run of ASCII letters and every run of digits is padded with spaces.
    3. The text is passed through pykakasi, and the lowercased, stripped Hepburn
       reading of each returned segment is joined with spaces.
    4. Anything other than ASCII letters, digits, underscore and space is deleted.
    5. Runs of whitespace are collapsed to single spaces and the ends are trimmed.
    """
    text = str(s).translate({ord("'"): None, ord('・'): None})
    # Presumably this keeps Latin words and numbers as separate pykakasi segments.
    text = re.sub('[A-Za-z]+|[0-9]+', lambda match: f' {match[0]} ', text)
    readings = (segment['hepburn'].strip().lower() for segment in _kks.convert(text))
    text = ' '.join(readings)
    text = re.sub(r'[^a-zA-Z0-9_ ]+', '', text)
    return ' '.join(text.split())