allow disabling master filters in some channels
This commit is contained in:
parent
88137751ff
commit
ea31a2e25c
@ -10,18 +10,19 @@ from miyu_bot.commands.common.emoji import attribute_emoji_ids_by_attribute_id,
|
|||||||
parameter_bonus_emoji_ids_by_parameter_id, \
|
parameter_bonus_emoji_ids_by_parameter_id, \
|
||||||
event_point_emoji_id
|
event_point_emoji_id
|
||||||
from miyu_bot.commands.common.formatting import format_info
|
from miyu_bot.commands.common.formatting import format_info
|
||||||
from miyu_bot.commands.common.fuzzy_matching import FuzzyMap, romanize
|
from miyu_bot.commands.common.fuzzy_matching import FuzzyFilteredMap, romanize
|
||||||
|
from miyu_bot.commands.common.master_asset_manager import MasterAssetManager
|
||||||
|
|
||||||
|
|
||||||
class Event(commands.Cog):
|
class Event(commands.Cog):
|
||||||
def __init__(self, bot: commands.Bot):
|
def __init__(self, bot: commands.Bot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
self.events = FuzzyMap(
|
self.events = MasterAssetManager(
|
||||||
lambda e: e.start_datetime < datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=8)
|
asset_manager.event_master,
|
||||||
|
naming_function=lambda e: e.name,
|
||||||
|
filter_function=lambda e: e.start_datetime < datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=8),
|
||||||
)
|
)
|
||||||
for e in asset_manager.event_master.values():
|
|
||||||
self.events[e.name] = e
|
|
||||||
|
|
||||||
@commands.command(name='event',
|
@commands.command(name='event',
|
||||||
aliases=['ev'],
|
aliases=['ev'],
|
||||||
@ -32,14 +33,9 @@ class Event(commands.Cog):
|
|||||||
|
|
||||||
event: EventMaster
|
event: EventMaster
|
||||||
if arg:
|
if arg:
|
||||||
try:
|
event = self.events.get(arg, ctx)
|
||||||
event = asset_manager.event_master[int(arg)]
|
|
||||||
if event not in self.events.values():
|
|
||||||
event = self.events[arg]
|
|
||||||
except (ValueError, KeyError):
|
|
||||||
event = self.events[arg]
|
|
||||||
else:
|
else:
|
||||||
event = self.get_latest_event()
|
event = self.get_latest_event(ctx)
|
||||||
|
|
||||||
if not event:
|
if not event:
|
||||||
msg = f'Failed to find event "{arg}".'
|
msg = f'Failed to find event "{arg}".'
|
||||||
@ -107,7 +103,7 @@ class Event(commands.Cog):
|
|||||||
description='Displays the time left in the current event',
|
description='Displays the time left in the current event',
|
||||||
help='!timeleft')
|
help='!timeleft')
|
||||||
async def time_left(self, ctx: commands.Context):
|
async def time_left(self, ctx: commands.Context):
|
||||||
latest = self.get_latest_event()
|
latest = self.get_latest_event(ctx)
|
||||||
|
|
||||||
state = latest.state()
|
state = latest.state()
|
||||||
|
|
||||||
@ -161,13 +157,13 @@ class Event(commands.Cog):
|
|||||||
|
|
||||||
await ctx.send(files=[logo], embed=embed)
|
await ctx.send(files=[logo], embed=embed)
|
||||||
|
|
||||||
def get_latest_event(self) -> EventMaster:
|
def get_latest_event(self, ctx: commands.Context) -> EventMaster:
|
||||||
"""Returns the oldest event that has not ended or the newest event otherwise."""
|
"""Returns the oldest event that has not ended or the newest event otherwise."""
|
||||||
try:
|
try:
|
||||||
return min((v for v in self.events.values() if v.state() < EventState.Ended),
|
return min((v for v in self.events.values(ctx) if v.state() < EventState.Ended),
|
||||||
key=lambda e: e.start_datetime)
|
key=lambda e: e.start_datetime)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return max(self.events.values(), key=lambda v: v.start_datetime)
|
return max(self.events.values(ctx), key=lambda v: v.start_datetime)
|
||||||
|
|
||||||
|
|
||||||
def setup(bot):
|
def setup(bot):
|
||||||
|
@ -13,7 +13,8 @@ from discord.ext import commands
|
|||||||
from main import asset_manager
|
from main import asset_manager
|
||||||
from miyu_bot.commands.common.emoji import difficulty_emoji_ids
|
from miyu_bot.commands.common.emoji import difficulty_emoji_ids
|
||||||
from miyu_bot.commands.common.formatting import format_info
|
from miyu_bot.commands.common.formatting import format_info
|
||||||
from miyu_bot.commands.common.fuzzy_matching import romanize, FuzzyMap
|
from miyu_bot.commands.common.fuzzy_matching import romanize, FuzzyFilteredMap
|
||||||
|
from miyu_bot.commands.common.master_asset_manager import MasterAssetManager
|
||||||
from miyu_bot.commands.common.reaction_message import run_tabbed_message, run_paged_message
|
from miyu_bot.commands.common.reaction_message import run_tabbed_message, run_paged_message
|
||||||
|
|
||||||
|
|
||||||
@ -21,10 +22,11 @@ class Music(commands.Cog):
|
|||||||
def __init__(self, bot: commands.Bot):
|
def __init__(self, bot: commands.Bot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
self.music = FuzzyMap(lambda m: m.is_released)
|
self.music = MasterAssetManager(
|
||||||
for m in asset_manager.music_master.values():
|
asset_manager.music_master,
|
||||||
if not self.music.has_exact(f'{m.name} {m.special_unit_name}'):
|
naming_function=lambda m: f'{m.name} {m.special_unit_name}',
|
||||||
self.music[f'{m.name} {m.special_unit_name}'] = m
|
filter_function=lambda m: m.is_released,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def reaction_emojis(self):
|
def reaction_emojis(self):
|
||||||
@ -52,7 +54,7 @@ class Music(commands.Cog):
|
|||||||
async def song(self, ctx: commands.Context, *, arg: str):
|
async def song(self, ctx: commands.Context, *, arg: str):
|
||||||
self.logger.info(f'Searching for song "{arg}".')
|
self.logger.info(f'Searching for song "{arg}".')
|
||||||
|
|
||||||
song = self.get_song(arg)
|
song = self.music.get(arg, ctx)
|
||||||
|
|
||||||
if not song:
|
if not song:
|
||||||
msg = f'Failed to find song "{arg}".'
|
msg = f'Failed to find song "{arg}".'
|
||||||
@ -101,7 +103,7 @@ class Music(commands.Cog):
|
|||||||
self.logger.info(f'Searching for chart "{arg}".')
|
self.logger.info(f'Searching for chart "{arg}".')
|
||||||
|
|
||||||
name, difficulty = self.parse_chart_args(arg)
|
name, difficulty = self.parse_chart_args(arg)
|
||||||
song = self.get_song(name)
|
song = self.music.get(name, ctx)
|
||||||
|
|
||||||
if not song:
|
if not song:
|
||||||
msg = f'Failed to find chart "{name}".'
|
msg = f'Failed to find chart "{name}".'
|
||||||
@ -124,7 +126,7 @@ class Music(commands.Cog):
|
|||||||
self.logger.info(f'Searching for chart sections "{arg}".')
|
self.logger.info(f'Searching for chart sections "{arg}".')
|
||||||
|
|
||||||
name, difficulty = self.parse_chart_args(arg)
|
name, difficulty = self.parse_chart_args(arg)
|
||||||
song = self.get_song(name)
|
song = self.music.get(name, ctx)
|
||||||
|
|
||||||
if not song:
|
if not song:
|
||||||
msg = f'Failed to find chart "{name}".'
|
msg = f'Failed to find chart "{name}".'
|
||||||
@ -151,13 +153,13 @@ class Music(commands.Cog):
|
|||||||
async def songs(self, ctx: commands.Context, *, arg: str = ""):
|
async def songs(self, ctx: commands.Context, *, arg: str = ""):
|
||||||
if arg:
|
if arg:
|
||||||
self.logger.info(f'Searching for songs "{arg}".')
|
self.logger.info(f'Searching for songs "{arg}".')
|
||||||
songs = self.music.get_sorted(arg)
|
songs = self.music.get_sorted(arg, ctx)
|
||||||
listing = [f'{song.name}{" (" + song.special_unit_name + ")" if song.special_unit_name else ""}' for song in
|
listing = [f'{song.name}{" (" + song.special_unit_name + ")" if song.special_unit_name else ""}' for song in
|
||||||
songs]
|
songs]
|
||||||
asyncio.ensure_future(run_paged_message(ctx, f'Song Search "{arg}"', listing))
|
asyncio.ensure_future(run_paged_message(ctx, f'Song Search "{arg}"', listing))
|
||||||
else:
|
else:
|
||||||
self.logger.info('Listing songs.')
|
self.logger.info('Listing songs.')
|
||||||
songs = sorted(self.music.values(), key=lambda m: -m.default_order)
|
songs = sorted(self.music.values(ctx), key=lambda m: -m.default_order)
|
||||||
songs = [*songs[1:], songs[0]] # lesson is always first
|
songs = [*songs[1:], songs[0]] # lesson is always first
|
||||||
listing = [f'{song.name}{" (" + song.special_unit_name + ")" if song.special_unit_name else ""}' for song in
|
listing = [f'{song.name}{" (" + song.special_unit_name + ")" if song.special_unit_name else ""}' for song in
|
||||||
songs]
|
songs]
|
||||||
@ -271,15 +273,6 @@ class Music(commands.Cog):
|
|||||||
arg = ''.join(split_args[:-1])
|
arg = ''.join(split_args[:-1])
|
||||||
return arg, difficulty
|
return arg, difficulty
|
||||||
|
|
||||||
def get_song(self, name_or_id: str) -> Optional[MusicMaster]:
|
|
||||||
try:
|
|
||||||
song = asset_manager.music_master[int(name_or_id)]
|
|
||||||
if song not in self.music.values():
|
|
||||||
song = self.music[name_or_id]
|
|
||||||
return song
|
|
||||||
except (KeyError, ValueError):
|
|
||||||
return self.music[name_or_id]
|
|
||||||
|
|
||||||
def get_music_duration(self, music: MusicMaster):
|
def get_music_duration(self, music: MusicMaster):
|
||||||
with contextlib.closing(wave.open(str(music.audio_path.with_name(music.audio_path.name + '.wav')), 'r')) as f:
|
with contextlib.closing(wave.open(str(music.audio_path.with_name(music.audio_path.name + '.wav')), 'r')) as f:
|
||||||
frames = f.getnframes()
|
frames = f.getnframes()
|
||||||
|
@ -9,9 +9,9 @@ from typing import Dict, Tuple, List, Optional, Iterable
|
|||||||
import pykakasi
|
import pykakasi
|
||||||
|
|
||||||
|
|
||||||
class FuzzyMap:
|
class FuzzyFilteredMap:
|
||||||
def __init__(self, filter=None, matcher=None, additive_only_filter=True):
|
def __init__(self, filter_function=None, matcher=None, additive_only_filter=True):
|
||||||
self.filter = filter or (lambda n: True)
|
self.filter = filter_function or (lambda n: True)
|
||||||
self.matcher = matcher or FuzzyMatcher()
|
self.matcher = matcher or FuzzyMatcher()
|
||||||
self._values = {}
|
self._values = {}
|
||||||
self.max_length = 0
|
self.max_length = 0
|
||||||
@ -56,10 +56,10 @@ class FuzzyMap:
|
|||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
start_time = timeit.default_timer()
|
start_time = timeit.default_timer()
|
||||||
|
key = romanize(key)
|
||||||
if len(key) > self.max_length:
|
if len(key) > self.max_length:
|
||||||
self.logger.debug(f'Rejected key "{key}" due to length.')
|
self.logger.debug(f'Rejected key "{key}" due to length.')
|
||||||
return None
|
return None
|
||||||
key = romanize(key)
|
|
||||||
try:
|
try:
|
||||||
matcher = self.matcher
|
matcher = self.matcher
|
||||||
result = min((score, item) for score, item in
|
result = min((score, item) for score, item in
|
||||||
@ -86,7 +86,7 @@ class FuzzyMap:
|
|||||||
|
|
||||||
|
|
||||||
class FuzzyDictValuesView:
|
class FuzzyDictValuesView:
|
||||||
def __init__(self, map: FuzzyMap):
|
def __init__(self, map: FuzzyFilteredMap):
|
||||||
self._map = map
|
self._map = map
|
||||||
|
|
||||||
def __contains__(self, item):
|
def __contains__(self, item):
|
||||||
@ -197,6 +197,7 @@ def strip_vowels(s):
|
|||||||
|
|
||||||
def romanize(s: str) -> str:
|
def romanize(s: str) -> str:
|
||||||
kks = pykakasi.kakasi()
|
kks = pykakasi.kakasi()
|
||||||
|
s = str(s)
|
||||||
s = re.sub('[\']', '', s)
|
s = re.sub('[\']', '', s)
|
||||||
s = re.sub('[・]', ' ', s)
|
s = re.sub('[・]', ' ', s)
|
||||||
s = re.sub('[A-Za-z]+', lambda ele: f' {ele[0]} ', s)
|
s = re.sub('[A-Za-z]+', lambda ele: f' {ele[0]} ', s)
|
||||||
|
49
miyu_bot/commands/common/master_asset_manager.py
Normal file
49
miyu_bot/commands/common/master_asset_manager.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
from typing import Callable, Any
|
||||||
|
|
||||||
|
from d4dj_utils.master.master_asset import MasterDict
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
from miyu_bot.commands.common.fuzzy_matching import FuzzyFilteredMap
|
||||||
|
|
||||||
|
|
||||||
|
class MasterAssetManager:
|
||||||
|
def __init__(self, masters: MasterDict, naming_function: Callable[[Any], str], filter_function=lambda _: True):
|
||||||
|
self.masters = masters
|
||||||
|
self.fuzzy_map = FuzzyFilteredMap(filter_function)
|
||||||
|
self.unfiltered_fuzzy_map = FuzzyFilteredMap()
|
||||||
|
for master in masters.values():
|
||||||
|
name = naming_function(master)
|
||||||
|
if self.fuzzy_map.has_exact(name):
|
||||||
|
continue
|
||||||
|
self.fuzzy_map[name] = master
|
||||||
|
self.unfiltered_fuzzy_map[name] = master
|
||||||
|
|
||||||
|
def get(self, name_or_id: str, ctx: commands.Context):
|
||||||
|
if ctx.channel.id in no_filter_channels:
|
||||||
|
try:
|
||||||
|
return self.masters[int(name_or_id)]
|
||||||
|
except (KeyError, ValueError):
|
||||||
|
return self.unfiltered_fuzzy_map[name_or_id]
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
master = self.masters[int(name_or_id)]
|
||||||
|
if master not in self.fuzzy_map.values():
|
||||||
|
master = self.fuzzy_map[name_or_id]
|
||||||
|
return master
|
||||||
|
except (KeyError, ValueError):
|
||||||
|
return self.fuzzy_map[name_or_id]
|
||||||
|
|
||||||
|
def get_sorted(self, name: str, ctx: commands.Context):
|
||||||
|
if ctx.channel.id in no_filter_channels:
|
||||||
|
return self.unfiltered_fuzzy_map.get_sorted(name)
|
||||||
|
else:
|
||||||
|
return self.fuzzy_map.get_sorted(name)
|
||||||
|
|
||||||
|
def values(self, ctx: commands.Context):
|
||||||
|
if ctx.channel.id in no_filter_channels:
|
||||||
|
return self.unfiltered_fuzzy_map.values()
|
||||||
|
else:
|
||||||
|
return self.fuzzy_map.values()
|
||||||
|
|
||||||
|
|
||||||
|
no_filter_channels = {790033228600705048, 790033272376918027}
|
Loading…
x
Reference in New Issue
Block a user