allow disabling master filters in some channels

pull/1/head
qwewqa 4 years ago
parent 88137751ff
commit ea31a2e25c
  1. 28
      miyu_bot/commands/cogs/event.py
  2. 31
      miyu_bot/commands/cogs/music.py
  3. 11
      miyu_bot/commands/common/fuzzy_matching.py
  4. 49
      miyu_bot/commands/common/master_asset_manager.py

@ -10,18 +10,19 @@ from miyu_bot.commands.common.emoji import attribute_emoji_ids_by_attribute_id,
parameter_bonus_emoji_ids_by_parameter_id, \ parameter_bonus_emoji_ids_by_parameter_id, \
event_point_emoji_id event_point_emoji_id
from miyu_bot.commands.common.formatting import format_info from miyu_bot.commands.common.formatting import format_info
from miyu_bot.commands.common.fuzzy_matching import FuzzyMap, romanize from miyu_bot.commands.common.fuzzy_matching import FuzzyFilteredMap, romanize
from miyu_bot.commands.common.master_asset_manager import MasterAssetManager
class Event(commands.Cog): class Event(commands.Cog):
def __init__(self, bot: commands.Bot): def __init__(self, bot: commands.Bot):
self.bot = bot self.bot = bot
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.events = FuzzyMap( self.events = MasterAssetManager(
lambda e: e.start_datetime < datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=8) asset_manager.event_master,
naming_function=lambda e: e.name,
filter_function=lambda e: e.start_datetime < datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=8),
) )
for e in asset_manager.event_master.values():
self.events[e.name] = e
@commands.command(name='event', @commands.command(name='event',
aliases=['ev'], aliases=['ev'],
@ -32,14 +33,9 @@ class Event(commands.Cog):
event: EventMaster event: EventMaster
if arg: if arg:
try: event = self.events.get(arg, ctx)
event = asset_manager.event_master[int(arg)]
if event not in self.events.values():
event = self.events[arg]
except (ValueError, KeyError):
event = self.events[arg]
else: else:
event = self.get_latest_event() event = self.get_latest_event(ctx)
if not event: if not event:
msg = f'Failed to find event "{arg}".' msg = f'Failed to find event "{arg}".'
@ -107,7 +103,7 @@ class Event(commands.Cog):
description='Displays the time left in the current event', description='Displays the time left in the current event',
help='!timeleft') help='!timeleft')
async def time_left(self, ctx: commands.Context): async def time_left(self, ctx: commands.Context):
latest = self.get_latest_event() latest = self.get_latest_event(ctx)
state = latest.state() state = latest.state()
@ -161,13 +157,13 @@ class Event(commands.Cog):
await ctx.send(files=[logo], embed=embed) await ctx.send(files=[logo], embed=embed)
def get_latest_event(self) -> EventMaster: def get_latest_event(self, ctx: commands.Context) -> EventMaster:
"""Returns the oldest event that has not ended or the newest event otherwise.""" """Returns the oldest event that has not ended or the newest event otherwise."""
try: try:
return min((v for v in self.events.values() if v.state() < EventState.Ended), return min((v for v in self.events.values(ctx) if v.state() < EventState.Ended),
key=lambda e: e.start_datetime) key=lambda e: e.start_datetime)
except ValueError: except ValueError:
return max(self.events.values(), key=lambda v: v.start_datetime) return max(self.events.values(ctx), key=lambda v: v.start_datetime)
def setup(bot): def setup(bot):

@ -13,7 +13,8 @@ from discord.ext import commands
from main import asset_manager from main import asset_manager
from miyu_bot.commands.common.emoji import difficulty_emoji_ids from miyu_bot.commands.common.emoji import difficulty_emoji_ids
from miyu_bot.commands.common.formatting import format_info from miyu_bot.commands.common.formatting import format_info
from miyu_bot.commands.common.fuzzy_matching import romanize, FuzzyMap from miyu_bot.commands.common.fuzzy_matching import romanize, FuzzyFilteredMap
from miyu_bot.commands.common.master_asset_manager import MasterAssetManager
from miyu_bot.commands.common.reaction_message import run_tabbed_message, run_paged_message from miyu_bot.commands.common.reaction_message import run_tabbed_message, run_paged_message
@ -21,10 +22,11 @@ class Music(commands.Cog):
def __init__(self, bot: commands.Bot): def __init__(self, bot: commands.Bot):
self.bot = bot self.bot = bot
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.music = FuzzyMap(lambda m: m.is_released) self.music = MasterAssetManager(
for m in asset_manager.music_master.values(): asset_manager.music_master,
if not self.music.has_exact(f'{m.name} {m.special_unit_name}'): naming_function=lambda m: f'{m.name} {m.special_unit_name}',
self.music[f'{m.name} {m.special_unit_name}'] = m filter_function=lambda m: m.is_released,
)
@property @property
def reaction_emojis(self): def reaction_emojis(self):
@ -52,7 +54,7 @@ class Music(commands.Cog):
async def song(self, ctx: commands.Context, *, arg: str): async def song(self, ctx: commands.Context, *, arg: str):
self.logger.info(f'Searching for song "{arg}".') self.logger.info(f'Searching for song "{arg}".')
song = self.get_song(arg) song = self.music.get(arg, ctx)
if not song: if not song:
msg = f'Failed to find song "{arg}".' msg = f'Failed to find song "{arg}".'
@ -101,7 +103,7 @@ class Music(commands.Cog):
self.logger.info(f'Searching for chart "{arg}".') self.logger.info(f'Searching for chart "{arg}".')
name, difficulty = self.parse_chart_args(arg) name, difficulty = self.parse_chart_args(arg)
song = self.get_song(name) song = self.music.get(name, ctx)
if not song: if not song:
msg = f'Failed to find chart "{name}".' msg = f'Failed to find chart "{name}".'
@ -124,7 +126,7 @@ class Music(commands.Cog):
self.logger.info(f'Searching for chart sections "{arg}".') self.logger.info(f'Searching for chart sections "{arg}".')
name, difficulty = self.parse_chart_args(arg) name, difficulty = self.parse_chart_args(arg)
song = self.get_song(name) song = self.music.get(name, ctx)
if not song: if not song:
msg = f'Failed to find chart "{name}".' msg = f'Failed to find chart "{name}".'
@ -151,13 +153,13 @@ class Music(commands.Cog):
async def songs(self, ctx: commands.Context, *, arg: str = ""): async def songs(self, ctx: commands.Context, *, arg: str = ""):
if arg: if arg:
self.logger.info(f'Searching for songs "{arg}".') self.logger.info(f'Searching for songs "{arg}".')
songs = self.music.get_sorted(arg) songs = self.music.get_sorted(arg, ctx)
listing = [f'{song.name}{" (" + song.special_unit_name + ")" if song.special_unit_name else ""}' for song in listing = [f'{song.name}{" (" + song.special_unit_name + ")" if song.special_unit_name else ""}' for song in
songs] songs]
asyncio.ensure_future(run_paged_message(ctx, f'Song Search "{arg}"', listing)) asyncio.ensure_future(run_paged_message(ctx, f'Song Search "{arg}"', listing))
else: else:
self.logger.info('Listing songs.') self.logger.info('Listing songs.')
songs = sorted(self.music.values(), key=lambda m: -m.default_order) songs = sorted(self.music.values(ctx), key=lambda m: -m.default_order)
songs = [*songs[1:], songs[0]] # lesson is always first songs = [*songs[1:], songs[0]] # lesson is always first
listing = [f'{song.name}{" (" + song.special_unit_name + ")" if song.special_unit_name else ""}' for song in listing = [f'{song.name}{" (" + song.special_unit_name + ")" if song.special_unit_name else ""}' for song in
songs] songs]
@ -271,15 +273,6 @@ class Music(commands.Cog):
arg = ''.join(split_args[:-1]) arg = ''.join(split_args[:-1])
return arg, difficulty return arg, difficulty
def get_song(self, name_or_id: str) -> Optional[MusicMaster]:
try:
song = asset_manager.music_master[int(name_or_id)]
if song not in self.music.values():
song = self.music[name_or_id]
return song
except (KeyError, ValueError):
return self.music[name_or_id]
def get_music_duration(self, music: MusicMaster): def get_music_duration(self, music: MusicMaster):
with contextlib.closing(wave.open(str(music.audio_path.with_name(music.audio_path.name + '.wav')), 'r')) as f: with contextlib.closing(wave.open(str(music.audio_path.with_name(music.audio_path.name + '.wav')), 'r')) as f:
frames = f.getnframes() frames = f.getnframes()

@ -9,9 +9,9 @@ from typing import Dict, Tuple, List, Optional, Iterable
import pykakasi import pykakasi
class FuzzyMap: class FuzzyFilteredMap:
def __init__(self, filter=None, matcher=None, additive_only_filter=True): def __init__(self, filter_function=None, matcher=None, additive_only_filter=True):
self.filter = filter or (lambda n: True) self.filter = filter_function or (lambda n: True)
self.matcher = matcher or FuzzyMatcher() self.matcher = matcher or FuzzyMatcher()
self._values = {} self._values = {}
self.max_length = 0 self.max_length = 0
@ -56,10 +56,10 @@ class FuzzyMap:
def __getitem__(self, key): def __getitem__(self, key):
start_time = timeit.default_timer() start_time = timeit.default_timer()
key = romanize(key)
if len(key) > self.max_length: if len(key) > self.max_length:
self.logger.debug(f'Rejected key "{key}" due to length.') self.logger.debug(f'Rejected key "{key}" due to length.')
return None return None
key = romanize(key)
try: try:
matcher = self.matcher matcher = self.matcher
result = min((score, item) for score, item in result = min((score, item) for score, item in
@ -86,7 +86,7 @@ class FuzzyMap:
class FuzzyDictValuesView: class FuzzyDictValuesView:
def __init__(self, map: FuzzyMap): def __init__(self, map: FuzzyFilteredMap):
self._map = map self._map = map
def __contains__(self, item): def __contains__(self, item):
@ -197,6 +197,7 @@ def strip_vowels(s):
def romanize(s: str) -> str: def romanize(s: str) -> str:
kks = pykakasi.kakasi() kks = pykakasi.kakasi()
s = str(s)
s = re.sub('[\']', '', s) s = re.sub('[\']', '', s)
s = re.sub('[・]', ' ', s) s = re.sub('[・]', ' ', s)
s = re.sub('[A-Za-z]+', lambda ele: f' {ele[0]} ', s) s = re.sub('[A-Za-z]+', lambda ele: f' {ele[0]} ', s)

@ -0,0 +1,49 @@
from typing import Callable, Any
from d4dj_utils.master.master_asset import MasterDict
from discord.ext import commands
from miyu_bot.commands.common.fuzzy_matching import FuzzyFilteredMap
class MasterAssetManager:
def __init__(self, masters: MasterDict, naming_function: Callable[[Any], str], filter_function=lambda _: True):
self.masters = masters
self.fuzzy_map = FuzzyFilteredMap(filter_function)
self.unfiltered_fuzzy_map = FuzzyFilteredMap()
for master in masters.values():
name = naming_function(master)
if self.fuzzy_map.has_exact(name):
continue
self.fuzzy_map[name] = master
self.unfiltered_fuzzy_map[name] = master
def get(self, name_or_id: str, ctx: commands.Context):
if ctx.channel.id in no_filter_channels:
try:
return self.masters[int(name_or_id)]
except (KeyError, ValueError):
return self.unfiltered_fuzzy_map[name_or_id]
else:
try:
master = self.masters[int(name_or_id)]
if master not in self.fuzzy_map.values():
master = self.fuzzy_map[name_or_id]
return master
except (KeyError, ValueError):
return self.fuzzy_map[name_or_id]
def get_sorted(self, name: str, ctx: commands.Context):
if ctx.channel.id in no_filter_channels:
return self.unfiltered_fuzzy_map.get_sorted(name)
else:
return self.fuzzy_map.get_sorted(name)
def values(self, ctx: commands.Context):
if ctx.channel.id in no_filter_channels:
return self.unfiltered_fuzzy_map.values()
else:
return self.fuzzy_map.values()
no_filter_channels = {790033228600705048, 790033272376918027}
Loading…
Cancel
Save