diff --git a/main.py b/main.py index 768218b..3889942 100644 --- a/main.py +++ b/main.py @@ -9,9 +9,11 @@ logging.basicConfig(level=logging.INFO) with open('config.json') as f: bot_token = json.load(f)['token'] + bot = commands.Bot(command_prefix='!', case_insensitive=True) asset_manager = AssetManager('assets') -bot.load_extension('miyu_bot.commands.cogs.chart') +bot.load_extension('miyu_bot.commands.cogs.music') +bot.load_extension('miyu_bot.commands.cogs.utility') @bot.event diff --git a/miyu_bot/commands/cogs/chart.py b/miyu_bot/commands/cogs/music.py similarity index 61% rename from miyu_bot/commands/cogs/chart.py rename to miyu_bot/commands/cogs/music.py index 6c66c39..1bb4455 100644 --- a/miyu_bot/commands/cogs/chart.py +++ b/miyu_bot/commands/cogs/music.py @@ -7,17 +7,17 @@ from d4dj_utils.master.music_master import MusicMaster from discord.ext import commands from main import asset_manager -from miyu_bot.commands.common.fuzzy_matching import romanize, FuzzyMatcher +from miyu_bot.commands.common.fuzzy_matching import romanize, FuzzyMap -class Charts(commands.Cog): - def __init__(self, bot): +class Music(commands.Cog): + def __init__(self, bot: commands.Bot): self.bot = bot self.logger = logging.getLogger(__name__) self.music = self.get_music() def get_music(self): - music = FuzzyMatcher(lambda m: m.is_released) + music = FuzzyMap(lambda m: m.is_released) for m in asset_manager.music_master.values(): music[f'{m.name} {m.special_unit_name}'] = m return music @@ -37,15 +37,62 @@ class Charts(commands.Cog): 'es': ChartDifficulty.Easy, } - @commands.command() - async def chart(self, ctx, *, arg): - self.logger.info(f'Searching for chart "{arg}".') + @staticmethod + def format_info(info_entries: dict): + return '\n'.join(f'{k}: {v}' for k, v in info_entries.items() if v) - arg = arg.strip() + @commands.command(name='song', + aliases=['music'], + description='Finds the song with the given name.', + help='!song grgr') + async def song(self, ctx: commands.Context, *, arg: str): + self.logger.info(f'Searching for song "{arg}".') - if not arg: - await ctx.send('Argument is empty.') + song: MusicMaster = self.music[arg] + if not song: + msg = f'Failed to find song "{arg}".' + await ctx.send(msg) + self.logger.info(msg) return + self.logger.info(f'Found "{song}" ({romanize(song.name)[1]}).') + + thumb = discord.File(song.jacket_path, filename='jacket.png') + + embed = discord.Embed(title=song.name) + embed.set_thumbnail(url=f'attachment://jacket.png') + + artist_info = { + 'Lyricist': song.lyricist, + 'Composer': song.composer, + 'Arranger': song.arranger, + 'Unit': song.unit.name, + 'Special Unit Name': song.special_unit_name, + } + + music_info = { + 'Category': song.category.name, + 'BPM': song.bpm, + 'Section Trend': song.section_trend.name, + 'Sort Order': song.default_order, + 'Levels': ', '.join(c.display_level for c in song.charts.values()), + 'Release Date': song.start_datetime, + } + + embed.add_field(name='Artist', + value=self.format_info(artist_info), + inline=False) + embed.add_field(name='Info', + value=self.format_info(music_info), + inline=False) + + await ctx.send(files=[thumb], embed=embed) + + @commands.command(name='chart', + aliases=[], + description='Finds the chart with the given name.', + help='!chart grgr\n!chart grgr normal') + async def chart(self, ctx: commands.Context, *, arg: str): + self.logger.info(f'Searching for chart "{arg}".') split_args = arg.split() @@ -72,12 +119,12 @@ class Charts(commands.Cog): thumb = discord.File(song.jacket_path, filename='jacket.png') render = discord.File(chart.image_path, filename='render.png') - embed = discord.Embed(title=song.name) + embed = discord.Embed(title=f'{song.name} [{difficulty.name}]') embed.set_thumbnail(url=f'attachment://jacket.png') embed.set_image(url=f'attachment://render.png') embed.add_field(name='Info', - value=f'Difficulty: {chart.display_level} ({chart.difficulty.name})\n' + value=f'Level: {chart.display_level}\n' f'Unit: {song.special_unit_name or song.unit.name}\n' f'Category: {song.category.name}\n' f'BPM: {song.bpm}', @@ -96,11 +143,11 @@ class Charts(commands.Cog): f'SCR: {round(chart.trends[2] * 100, 2)}%\n' f'EFT: {round(chart.trends[3] * 100, 2)}%\n' f'TEC: {round(chart.trends[4] * 100, 2)}%\n', - inline=True - ) + inline=True) + embed.set_footer(text='1 column = 10 seconds') await ctx.send(files=[thumb, render], embed=embed) def setup(bot): - bot.add_cog(Charts(bot)) + bot.add_cog(Music(bot)) diff --git a/miyu_bot/commands/cogs/utility.py b/miyu_bot/commands/cogs/utility.py new file mode 100644 index 0000000..39decfa --- /dev/null +++ b/miyu_bot/commands/cogs/utility.py @@ -0,0 +1,23 @@ +import logging + +from discord.ext import commands + +from miyu_bot.commands.common.fuzzy_matching import romanize, FuzzyMatcher + + +class Utility(commands.Cog): + def __init__(self, bot: commands.Bot): + self.bot = bot + self.logger = logging.getLogger(__name__) + + @commands.command(hidden=True) + async def romanize(self, ctx: commands.Context, *, arg: str): + await ctx.send(romanize(arg)) + + @commands.command(hidden=True, ignore_extra=False) + async def similarity_score(self, ctx: commands.Context, source: str, target: str): + await ctx.send(str(FuzzyMatcher().score(romanize(source), romanize(target)))) + + +def setup(bot): + bot.add_cog(Utility(bot)) diff --git a/miyu_bot/commands/common/fuzzy_matching.py b/miyu_bot/commands/common/fuzzy_matching.py index 890d2b4..f77092d 100644 --- a/miyu_bot/commands/common/fuzzy_matching.py +++ b/miyu_bot/commands/common/fuzzy_matching.py @@ -1,76 +1,118 @@ import logging import re -from typing import Tuple +from dataclasses import dataclass, field +from typing import Dict, Tuple, List import pykakasi -class FuzzyMatcher: - def __init__(self, filter, threshold: float = 1): +class FuzzyMap: + def __init__(self, filter=lambda: True, matcher=None): self.filter = filter or (lambda n: True) - self.threshold = threshold - self.values = {} + self.matcher = matcher or FuzzyMatcher() + self._values = {} self.max_length = 0 self.logger = logging.getLogger(__name__) + def values(self): + return (v for v in self._values.values() if self.filter(v)) + + def __delitem__(self, key): + k = romanize(key) + self._values.__delitem__(k) + def __setitem__(self, key, value): k = romanize(key) - self.values[k] = value - self.max_length = len(k[0]) + self._values[k] = value + self.max_length = len(k) def __getitem__(self, key): if len(key) > self.max_length * 1.1: self.logger.debug(f'Rejected key "{key}" due to length.') return None - key, _ = romanize(key) - result = min((k for k, v in self.values.items() if self.filter(v)), - key=lambda v: fuzzy_match_score(key, *v, threshold=self.threshold)) - if fuzzy_match_score(key, *result, threshold=self.threshold) > self.threshold: + key = romanize(key) + result = min((k for k, v in self._values.items() if self.filter(v)), key=lambda k: self.matcher.score(key, k)) + if self.matcher.score(key, result) > 0: return None - return self.values[result] - - -_insertion_weight = 0.001 -_deletion_weight = 1 -_substitution_weight = 1 - + return self._values[result] -def fuzzy_match_score(source: str, target: str, words, threshold: float) -> float: - m = len(source) - n = len(target) - a = [[0] * (n + 1) for _ in range(m + 1)] - for i in range(m + 1): - a[i][0] = i +@dataclass +class FuzzyMatchConfig: + base_score: float = 0.0 + insertion_weight: float = 0.001 + deletion_weight: float = 1.0 + default_substitution_weight: float = 1.0 + match_weight: float = -0.2 + special_substitution_weights: Dict[Tuple[str, str], float] = field(default_factory=lambda: { + ('v', 'b'): 0.0, + ('l', 'r'): 0.0, + }) + word_match_weight: float = -0.2 + acronym_match_weight: float = -0.3 - for i in range(n + 1): - a[0][i] = i * _insertion_weight - def strip_vowels(s): - return re.sub('[aeoiu]', '', s) - - word_match_bonus = 0.1 * max(max(sum(a == b for a, b in zip(source, w)) for w in words), - max(sum(a == b for a, b in - zip(source[0] + strip_vowels(source[1:]), w[0] + strip_vowels(w[1:]))) for w in - words), - sum(a == b for a, b in zip(source, ''.join(w[0] for w in words)))) - - for i in range(1, m + 1): - for j in range(1, n + 1): - a[i][j] = min(a[i - 1][j - 1] + _substitution_weight if source[i - 1] != target[j - 1] else a[i - 1][j - 1], - a[i - 1][j] + _deletion_weight, - a[i][j - 1] + _insertion_weight) - if j == n and (a[i][j] - (m - i) * _insertion_weight - word_match_bonus) > threshold: - return 9999 - - return a[m][n] - word_match_bonus - - -def romanize(s: str) -> Tuple[str, Tuple[str]]: +class FuzzyMatcher: + def __init__(self, config: FuzzyMatchConfig = None): + self.config = config or FuzzyMatchConfig() + + def score(self, source: str, target: str): + l_src = len(source) + l_tgt = len(target) + a: List[List[float]] = [[0] * (l_tgt + 1) for _ in range(l_src + 1)] + + for i in range(l_src + 1): + a[i][0] = i + + for i in range(l_tgt + 1): + a[0][i] = i * self.config.insertion_weight + + def strip_vowels(s): + return re.sub('[aeoiu]', '', s) + + words = target.split() + word_bonus = min(self.config.word_match_weight * max(sum(a == b for a, b in zip(source, w)) for w in words), + self.config.word_match_weight * max(sum(a == b for a, b in + zip(source, w[0] + strip_vowels(w[1:]))) for w in + words), + self.config.acronym_match_weight * sum( + a == b for a, b in zip(source, ''.join(w[0] for w in words)))) + + def sub_weight_at(n, m): + if source[n - 1] != target[m - 1]: + return self.config.special_substitution_weights.get( + (source[n - 1], target[m - 1]), + self.config.default_substitution_weight + ) + else: + return self.config.match_weight + + for i_src in range(1, l_src + 1): + for i_tgt in range(1, l_tgt + 1): + a[i_src][i_tgt] = min(a[i_src - 1][i_tgt - 1] + sub_weight_at(i_src, i_tgt), + a[i_src - 1][i_tgt] + self.config.deletion_weight, + a[i_src][i_tgt - 1] + self.config.insertion_weight) + + # there are l_scr - i_src source chars remaining + # each match removes the insertion weight then adds the match weight + # (l_src - i_src) * (self.config.match_weight - self.config.insertion_weight) + # is the max difference that can make + max_additional_score = ((l_src - i_src) * (self.config.match_weight - self.config.insertion_weight) + + word_bonus + self.config.base_score) + if i_tgt == l_tgt and ( + a[i_src][i_tgt] + max_additional_score) > 0 and \ + (a[i_src][i_tgt - 1] + max_additional_score) > 0: + return 1 + + return a[l_src][l_tgt] + word_bonus + self.config.base_score + + +def romanize(s: str) -> str: kks = pykakasi.kakasi() s = re.sub('[\']', '', s) + s = re.sub('[・]', ' ', s) s = re.sub('[A-Za-z]+', lambda ele: f' {ele[0]} ', s) + s = re.sub('[0-9]+', lambda ele: f' {ele[0]} ', s) s = ' '.join(c['hepburn'].strip().lower() for c in kks.convert(s)) s = re.sub(r'[^a-zA-Z0-9_ ]+', '', s) - words = tuple(s.split()) - return ''.join(words), words + return ' '.join(s.split())