From ea31a2e25ca3610761bb5ec0519bdc030c14be9a Mon Sep 17 00:00:00 2001 From: qwewqa <198e559dbd446d973355f415bdfa34@gmail.com> Date: Mon, 28 Dec 2020 20:56:51 -0500 Subject: [PATCH] allow disabling master filters in some channels --- miyu_bot/commands/cogs/event.py | 28 +++++------ miyu_bot/commands/cogs/music.py | 31 +++++------- miyu_bot/commands/common/fuzzy_matching.py | 11 +++-- .../commands/common/master_asset_manager.py | 49 +++++++++++++++++++ 4 files changed, 79 insertions(+), 40 deletions(-) create mode 100644 miyu_bot/commands/common/master_asset_manager.py diff --git a/miyu_bot/commands/cogs/event.py b/miyu_bot/commands/cogs/event.py index 2d4cb7b..1a44605 100644 --- a/miyu_bot/commands/cogs/event.py +++ b/miyu_bot/commands/cogs/event.py @@ -10,18 +10,19 @@ from miyu_bot.commands.common.emoji import attribute_emoji_ids_by_attribute_id, parameter_bonus_emoji_ids_by_parameter_id, \ event_point_emoji_id from miyu_bot.commands.common.formatting import format_info -from miyu_bot.commands.common.fuzzy_matching import FuzzyMap, romanize +from miyu_bot.commands.common.fuzzy_matching import FuzzyFilteredMap, romanize +from miyu_bot.commands.common.master_asset_manager import MasterAssetManager class Event(commands.Cog): def __init__(self, bot: commands.Bot): self.bot = bot self.logger = logging.getLogger(__name__) - self.events = FuzzyMap( - lambda e: e.start_datetime < datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=8) + self.events = MasterAssetManager( + asset_manager.event_master, + naming_function=lambda e: e.name, + filter_function=lambda e: e.start_datetime < datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=8), ) - for e in asset_manager.event_master.values(): - self.events[e.name] = e @commands.command(name='event', aliases=['ev'], @@ -32,14 +33,9 @@ class Event(commands.Cog): event: EventMaster if arg: - try: - event = asset_manager.event_master[int(arg)] - if event not in self.events.values(): - event = self.events[arg] - except (ValueError, KeyError): - event = self.events[arg] + event = self.events.get(arg, ctx) else: - event = self.get_latest_event() + event = self.get_latest_event(ctx) if not event: msg = f'Failed to find event "{arg}".' @@ -107,7 +103,7 @@ class Event(commands.Cog): description='Displays the time left in the current event', help='!timeleft') async def time_left(self, ctx: commands.Context): - latest = self.get_latest_event() + latest = self.get_latest_event(ctx) state = latest.state() @@ -161,13 +157,13 @@ class Event(commands.Cog): await ctx.send(files=[logo], embed=embed) - def get_latest_event(self) -> EventMaster: + def get_latest_event(self, ctx: commands.Context) -> EventMaster: """Returns the oldest event that has not ended or the newest event otherwise.""" try: - return min((v for v in self.events.values() if v.state() < EventState.Ended), + return min((v for v in self.events.values(ctx) if v.state() < EventState.Ended), key=lambda e: e.start_datetime) except ValueError: - return max(self.events.values(), key=lambda v: v.start_datetime) + return max(self.events.values(ctx), key=lambda v: v.start_datetime) def setup(bot): diff --git a/miyu_bot/commands/cogs/music.py b/miyu_bot/commands/cogs/music.py index 73d4990..100e374 100644 --- a/miyu_bot/commands/cogs/music.py +++ b/miyu_bot/commands/cogs/music.py @@ -13,7 +13,8 @@ from discord.ext import commands from main import asset_manager from miyu_bot.commands.common.emoji import difficulty_emoji_ids from miyu_bot.commands.common.formatting import format_info -from miyu_bot.commands.common.fuzzy_matching import romanize, FuzzyMap +from miyu_bot.commands.common.fuzzy_matching import romanize, FuzzyFilteredMap +from miyu_bot.commands.common.master_asset_manager import MasterAssetManager from miyu_bot.commands.common.reaction_message import run_tabbed_message, run_paged_message @@ -21,10 +22,11 @@ class Music(commands.Cog): def __init__(self, bot: commands.Bot): self.bot = bot self.logger = logging.getLogger(__name__) - self.music = FuzzyMap(lambda m: m.is_released) - for m in asset_manager.music_master.values(): - if not self.music.has_exact(f'{m.name} {m.special_unit_name}'): - self.music[f'{m.name} {m.special_unit_name}'] = m + self.music = MasterAssetManager( + asset_manager.music_master, + naming_function=lambda m: f'{m.name} {m.special_unit_name}', + filter_function=lambda m: m.is_released, + ) @property def reaction_emojis(self): @@ -52,7 +54,7 @@ class Music(commands.Cog): async def song(self, ctx: commands.Context, *, arg: str): self.logger.info(f'Searching for song "{arg}".') - song = self.get_song(arg) + song = self.music.get(arg, ctx) if not song: msg = f'Failed to find song "{arg}".' @@ -101,7 +103,7 @@ class Music(commands.Cog): self.logger.info(f'Searching for chart "{arg}".') name, difficulty = self.parse_chart_args(arg) - song = self.get_song(name) + song = self.music.get(name, ctx) if not song: msg = f'Failed to find chart "{name}".' @@ -124,7 +126,7 @@ class Music(commands.Cog): self.logger.info(f'Searching for chart sections "{arg}".') name, difficulty = self.parse_chart_args(arg) - song = self.get_song(name) + song = self.music.get(name, ctx) if not song: msg = f'Failed to find chart "{name}".' @@ -151,13 +153,13 @@ class Music(commands.Cog): async def songs(self, ctx: commands.Context, *, arg: str = ""): if arg: self.logger.info(f'Searching for songs "{arg}".') - songs = self.music.get_sorted(arg) + songs = self.music.get_sorted(arg, ctx) listing = [f'{song.name}{" (" + song.special_unit_name + ")" if song.special_unit_name else ""}' for song in songs] asyncio.ensure_future(run_paged_message(ctx, f'Song Search "{arg}"', listing)) else: self.logger.info('Listing songs.') - songs = sorted(self.music.values(), key=lambda m: -m.default_order) + songs = sorted(self.music.values(ctx), key=lambda m: -m.default_order) songs = [*songs[1:], songs[0]] # lesson is always first listing = [f'{song.name}{" (" + song.special_unit_name + ")" if song.special_unit_name else ""}' for song in songs] @@ -271,15 +273,6 @@ class Music(commands.Cog): arg = ''.join(split_args[:-1]) return arg, difficulty - def get_song(self, name_or_id: str) -> Optional[MusicMaster]: - try: - song = asset_manager.music_master[int(name_or_id)] - if song not in self.music.values(): - song = self.music[name_or_id] - return song - except (KeyError, ValueError): - return self.music[name_or_id] - def get_music_duration(self, music: MusicMaster): with contextlib.closing(wave.open(str(music.audio_path.with_name(music.audio_path.name + '.wav')), 'r')) as f: frames = f.getnframes() diff --git a/miyu_bot/commands/common/fuzzy_matching.py b/miyu_bot/commands/common/fuzzy_matching.py index dee5270..6070ae4 100644 --- a/miyu_bot/commands/common/fuzzy_matching.py +++ b/miyu_bot/commands/common/fuzzy_matching.py @@ -9,9 +9,9 @@ from typing import Dict, Tuple, List, Optional, Iterable import pykakasi -class FuzzyMap: - def __init__(self, filter=None, matcher=None, additive_only_filter=True): - self.filter = filter or (lambda n: True) +class FuzzyFilteredMap: + def __init__(self, filter_function=None, matcher=None, additive_only_filter=True): + self.filter = filter_function or (lambda n: True) self.matcher = matcher or FuzzyMatcher() self._values = {} self.max_length = 0 @@ -56,10 +56,10 @@ class FuzzyMap: def __getitem__(self, key): start_time = timeit.default_timer() + key = romanize(key) if len(key) > self.max_length: self.logger.debug(f'Rejected key "{key}" due to length.') return None - key = romanize(key) try: matcher = self.matcher result = min((score, item) for score, item in @@ -86,7 +86,7 @@ class FuzzyMap: class FuzzyDictValuesView: - def __init__(self, map: FuzzyMap): + def __init__(self, map: FuzzyFilteredMap): self._map = map def __contains__(self, item): @@ -197,6 +197,7 @@ def strip_vowels(s): def romanize(s: str) -> str: kks = pykakasi.kakasi() + s = str(s) s = re.sub('[\']', '', s) s = re.sub('[・]', ' ', s) s = re.sub('[A-Za-z]+', lambda ele: f' {ele[0]} ', s) diff --git a/miyu_bot/commands/common/master_asset_manager.py b/miyu_bot/commands/common/master_asset_manager.py new file mode 100644 index 0000000..18f8148 --- /dev/null +++ b/miyu_bot/commands/common/master_asset_manager.py @@ -0,0 +1,49 @@ +from typing import Callable, Any + +from d4dj_utils.master.master_asset import MasterDict +from discord.ext import commands + +from miyu_bot.commands.common.fuzzy_matching import FuzzyFilteredMap + + +class MasterAssetManager: + def __init__(self, masters: MasterDict, naming_function: Callable[[Any], str], filter_function=lambda _: True): + self.masters = masters + self.fuzzy_map = FuzzyFilteredMap(filter_function) + self.unfiltered_fuzzy_map = FuzzyFilteredMap() + for master in masters.values(): + name = naming_function(master) + if self.fuzzy_map.has_exact(name): + continue + self.fuzzy_map[name] = master + self.unfiltered_fuzzy_map[name] = master + + def get(self, name_or_id: str, ctx: commands.Context): + if ctx.channel.id in no_filter_channels: + try: + return self.masters[int(name_or_id)] + except (KeyError, ValueError): + return self.unfiltered_fuzzy_map[name_or_id] + else: + try: + master = self.masters[int(name_or_id)] + if master not in self.fuzzy_map.values(): + master = self.fuzzy_map[name_or_id] + return master + except (KeyError, ValueError): + return self.fuzzy_map[name_or_id] + + def get_sorted(self, name: str, ctx: commands.Context): + if ctx.channel.id in no_filter_channels: + return self.unfiltered_fuzzy_map.get_sorted(name) + else: + return self.fuzzy_map.get_sorted(name) + + def values(self, ctx: commands.Context): + if ctx.channel.id in no_filter_channels: + return self.unfiltered_fuzzy_map.values() + else: + return self.fuzzy_map.values() + + +no_filter_channels = {790033228600705048, 790033272376918027}