add paged song search command

pull/1/head
qwewqa 4 years ago
parent a1123a0582
commit 7b05d66f4a
  1. 22
      miyu_bot/commands/cogs/music.py
  2. 69
      miyu_bot/commands/common/reaction_message.py

@ -14,7 +14,7 @@ from main import asset_manager
from miyu_bot.commands.common.emoji import difficulty_emoji_ids from miyu_bot.commands.common.emoji import difficulty_emoji_ids
from miyu_bot.commands.common.formatting import format_info from miyu_bot.commands.common.formatting import format_info
from miyu_bot.commands.common.fuzzy_matching import romanize, FuzzyMap from miyu_bot.commands.common.fuzzy_matching import romanize, FuzzyMap
from miyu_bot.commands.common.reaction_message import run_tabbed_message from miyu_bot.commands.common.reaction_message import run_tabbed_message, run_paged_message
class Music(commands.Cog): class Music(commands.Cog):
@ -145,12 +145,24 @@ class Music(commands.Cog):
asyncio.ensure_future(run_tabbed_message(ctx, message, self.reaction_emojis, embeds)) asyncio.ensure_future(run_tabbed_message(ctx, message, self.reaction_emojis, embeds))
@commands.command(name='songs', @commands.command(name='songs',
aliases=['search_songs'], aliases=['songsearch', 'song_search'],
description='Finds songs matching the given name.', description='Finds songs matching the given name.',
help='!songs grgr') help='!songs grgr')
async def songs(self, ctx: commands.Context, *, arg: str): async def songs(self, ctx: commands.Context, *, arg: str = ""):
self.logger.info(f'Searching for songs sections "{arg}".') if arg:
songs = self.music.get_sorted(arg) self.logger.info(f'Searching for songs "{arg}".')
songs = self.music.get_sorted(arg)
listing = [f'{song.name}{" (" + song.special_unit_name + ")" if song.special_unit_name else ""}' for song in
songs]
asyncio.ensure_future(run_paged_message(ctx, f'Song Search "{arg}"', listing))
else:
self.logger.info('Listing songs.')
songs = sorted(self.music.values(), key=lambda m: -m.default_order)
songs = [*songs[1:], songs[0]] # lesson is always first
listing = [f'{song.name}{" (" + song.special_unit_name + ")" if song.special_unit_name else ""}' for song in
songs]
asyncio.ensure_future(run_paged_message(ctx, f'All Songs', listing))
return return
def get_chart_embed_info(self, song): def get_chart_embed_info(self, song):

@ -1,10 +1,12 @@
import asyncio import asyncio
from typing import List, Callable, Awaitable from typing import List, Callable, Awaitable, Union, Optional
import discord import discord
from discord import Message, Embed, Emoji from discord import Message, Embed, Emoji, PartialEmoji
from discord.ext.commands import Context from discord.ext.commands import Context
AnyEmoji = Union[str, Emoji]
async def run_tabbed_message(ctx: Context, message: Message, emojis: List[Emoji], embeds: List[Embed], timeout=300): async def run_tabbed_message(ctx: Context, message: Message, emojis: List[Emoji], embeds: List[Embed], timeout=300):
async def callback(emoji, _ctx, _message): async def callback(emoji, _ctx, _message):
@ -13,8 +15,65 @@ async def run_tabbed_message(ctx: Context, message: Message, emojis: List[Emoji]
await run_reaction_message(ctx, message, emojis, callback, timeout) await run_reaction_message(ctx, message, emojis, callback, timeout)
async def run_reaction_message(ctx: Context, message: Message, emojis: List[Emoji], async def run_paged_message(ctx: Context, title: str, content: List[str], page_size: int = 15,
callback: Callable[[Emoji, Context, Message], Awaitable[None]], timeout=300): timeout=300, double_arrow_threshold=4):
if not content:
embed = discord.Embed(title=title).set_footer(text='Page 0/0')
await ctx.send(embed=embed)
return
page_contents = [content[i:i + page_size] for i in range(0, len(content), page_size)]
item_number = 0
def format_item(item):
nonlocal item_number
item_number += 1
return f'{item_number}. {item}'
embeds = [
discord.Embed(title=title, description='\n'.join((format_item(i) for i in page))).set_footer(
text=f'Page {i + 1}/{len(page_contents)}')
for i, page in enumerate(page_contents)]
message = await ctx.send(embed=embeds[0])
if len(embeds) == 1:
return
double_left_arrow = ''
double_right_arrow = ''
left_arrow = ''
right_arrow = ''
if 0 < double_arrow_threshold <= len(embeds):
arrows = [double_left_arrow, left_arrow, right_arrow, double_right_arrow]
else:
arrows = [left_arrow, right_arrow]
index = 0
async def callback(emoji, _ctx, _message):
nonlocal index
start_index = index
if emoji == double_left_arrow:
index = 0
elif emoji == left_arrow:
index -= 1
elif emoji == right_arrow:
index += 1
elif emoji == double_right_arrow:
index = len(embeds) - 1
index = min(len(embeds) - 1, max(0, index))
if index != start_index:
await message.edit(embed=embeds[index])
await run_reaction_message(ctx, message, arrows, callback, timeout)
async def run_reaction_message(ctx: Context, message: Message, emojis: List[AnyEmoji],
callback: Callable[[AnyEmoji, Context, Message], Awaitable[None]], timeout=300):
for emoji in emojis: for emoji in emojis:
await message.add_reaction(emoji) await message.add_reaction(emoji)
@ -28,5 +87,5 @@ async def run_reaction_message(ctx: Context, message: Message, emojis: List[Emoj
await message.remove_reaction(reaction, user) await message.remove_reaction(reaction, user)
except asyncio.TimeoutError: except asyncio.TimeoutError:
for emoji in emojis: for emoji in emojis:
await message.remove_reaction(ctx.bot.get_emoji(emoji), ctx.bot.user) await message.remove_reaction(emoji, ctx.bot.user)
break break

Loading…
Cancel
Save