song command
This commit is contained in:
parent
3cb092abb9
commit
31f64a2799
4
main.py
4
main.py
@ -9,9 +9,11 @@ logging.basicConfig(level=logging.INFO)
|
||||
|
||||
with open('config.json') as f:
|
||||
bot_token = json.load(f)['token']
|
||||
|
||||
bot = commands.Bot(command_prefix='!', case_insensitive=True)
|
||||
asset_manager = AssetManager('assets')
|
||||
bot.load_extension('miyu_bot.commands.cogs.chart')
|
||||
bot.load_extension('miyu_bot.commands.cogs.music')
|
||||
bot.load_extension('miyu_bot.commands.cogs.utility')
|
||||
|
||||
|
||||
@bot.event
|
||||
|
@ -7,17 +7,17 @@ from d4dj_utils.master.music_master import MusicMaster
|
||||
from discord.ext import commands
|
||||
|
||||
from main import asset_manager
|
||||
from miyu_bot.commands.common.fuzzy_matching import romanize, FuzzyMatcher
|
||||
from miyu_bot.commands.common.fuzzy_matching import romanize, FuzzyMap
|
||||
|
||||
|
||||
class Charts(commands.Cog):
|
||||
def __init__(self, bot):
|
||||
class Music(commands.Cog):
|
||||
def __init__(self, bot: commands.Bot):
|
||||
self.bot = bot
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.music = self.get_music()
|
||||
|
||||
def get_music(self):
|
||||
music = FuzzyMatcher(lambda m: m.is_released)
|
||||
music = FuzzyMap(lambda m: m.is_released)
|
||||
for m in asset_manager.music_master.values():
|
||||
music[f'{m.name} {m.special_unit_name}'] = m
|
||||
return music
|
||||
@ -37,15 +37,62 @@ class Charts(commands.Cog):
|
||||
'es': ChartDifficulty.Easy,
|
||||
}
|
||||
|
||||
@commands.command()
|
||||
async def chart(self, ctx, *, arg):
|
||||
self.logger.info(f'Searching for chart "{arg}".')
|
||||
@staticmethod
|
||||
def format_info(info_entries: dict):
|
||||
return '\n'.join(f'{k}: {v}' for k, v in info_entries.items() if v)
|
||||
|
||||
arg = arg.strip()
|
||||
@commands.command(name='song',
|
||||
aliases=['music'],
|
||||
description='Finds the song with the given name.',
|
||||
help='!song grgr')
|
||||
async def song(self, ctx: commands.Context, *, arg: str):
|
||||
self.logger.info(f'Searching for song "{arg}".')
|
||||
|
||||
if not arg:
|
||||
await ctx.send('Argument is empty.')
|
||||
song: MusicMaster = self.music[arg]
|
||||
if not song:
|
||||
msg = f'Failed to find song "{arg}".'
|
||||
await ctx.send(msg)
|
||||
self.logger.info(msg)
|
||||
return
|
||||
self.logger.info(f'Found "{song}" ({romanize(song.name)[1]}).')
|
||||
|
||||
thumb = discord.File(song.jacket_path, filename='jacket.png')
|
||||
|
||||
embed = discord.Embed(title=song.name)
|
||||
embed.set_thumbnail(url=f'attachment://jacket.png')
|
||||
|
||||
artist_info = {
|
||||
'Lyricist': song.lyricist,
|
||||
'Composer': song.composer,
|
||||
'Arranger': song.arranger,
|
||||
'Unit': song.unit.name,
|
||||
'Special Unit Name': song.special_unit_name,
|
||||
}
|
||||
|
||||
music_info = {
|
||||
'Category': song.category.name,
|
||||
'BPM': song.bpm,
|
||||
'Section Trend': song.section_trend.name,
|
||||
'Sort Order': song.default_order,
|
||||
'Levels': ', '.join(c.display_level for c in song.charts.values()),
|
||||
'Release Date': song.start_datetime,
|
||||
}
|
||||
|
||||
embed.add_field(name='Artist',
|
||||
value=self.format_info(artist_info),
|
||||
inline=False)
|
||||
embed.add_field(name='Info',
|
||||
value=self.format_info(music_info),
|
||||
inline=False)
|
||||
|
||||
await ctx.send(files=[thumb], embed=embed)
|
||||
|
||||
@commands.command(name='chart',
|
||||
aliases=[],
|
||||
description='Finds the chart with the given name.',
|
||||
help='!chart grgr\n!chart grgr normal')
|
||||
async def chart(self, ctx: commands.Context, *, arg: str):
|
||||
self.logger.info(f'Searching for chart "{arg}".')
|
||||
|
||||
split_args = arg.split()
|
||||
|
||||
@ -72,12 +119,12 @@ class Charts(commands.Cog):
|
||||
thumb = discord.File(song.jacket_path, filename='jacket.png')
|
||||
render = discord.File(chart.image_path, filename='render.png')
|
||||
|
||||
embed = discord.Embed(title=song.name)
|
||||
embed = discord.Embed(title=f'{song.name} [{difficulty.name}]')
|
||||
embed.set_thumbnail(url=f'attachment://jacket.png')
|
||||
embed.set_image(url=f'attachment://render.png')
|
||||
|
||||
embed.add_field(name='Info',
|
||||
value=f'Difficulty: {chart.display_level} ({chart.difficulty.name})\n'
|
||||
value=f'Level: {chart.display_level}\n'
|
||||
f'Unit: {song.special_unit_name or song.unit.name}\n'
|
||||
f'Category: {song.category.name}\n'
|
||||
f'BPM: {song.bpm}',
|
||||
@ -96,11 +143,11 @@ class Charts(commands.Cog):
|
||||
f'SCR: {round(chart.trends[2] * 100, 2)}%\n'
|
||||
f'EFT: {round(chart.trends[3] * 100, 2)}%\n'
|
||||
f'TEC: {round(chart.trends[4] * 100, 2)}%\n',
|
||||
inline=True
|
||||
)
|
||||
inline=True)
|
||||
embed.set_footer(text='1 column = 10 seconds')
|
||||
|
||||
await ctx.send(files=[thumb, render], embed=embed)
|
||||
|
||||
|
||||
def setup(bot):
|
||||
bot.add_cog(Charts(bot))
|
||||
bot.add_cog(Music(bot))
|
23
miyu_bot/commands/cogs/utility.py
Normal file
23
miyu_bot/commands/cogs/utility.py
Normal file
@ -0,0 +1,23 @@
|
||||
import logging
|
||||
|
||||
from discord.ext import commands
|
||||
|
||||
from miyu_bot.commands.common.fuzzy_matching import romanize, FuzzyMatcher
|
||||
|
||||
|
||||
class Utility(commands.Cog):
|
||||
def __init__(self, bot: commands.Bot):
|
||||
self.bot = bot
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
@commands.command(hidden=True)
|
||||
async def romanize(self, ctx: commands.Context, *, arg: str):
|
||||
await ctx.send(romanize(arg))
|
||||
|
||||
@commands.command(hidden=True, ignore_extra=False)
|
||||
async def similarity_score(self, ctx: commands.Context, source: str, target: str):
|
||||
await ctx.send(str(FuzzyMatcher().score(romanize(source), romanize(target))))
|
||||
|
||||
|
||||
def setup(bot):
|
||||
bot.add_cog(Utility(bot))
|
@ -1,76 +1,118 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Tuple, List
|
||||
|
||||
import pykakasi
|
||||
|
||||
|
||||
class FuzzyMatcher:
|
||||
def __init__(self, filter, threshold: float = 1):
|
||||
class FuzzyMap:
|
||||
def __init__(self, filter=lambda: True, matcher=None):
|
||||
self.filter = filter or (lambda n: True)
|
||||
self.threshold = threshold
|
||||
self.values = {}
|
||||
self.matcher = matcher or FuzzyMatcher()
|
||||
self._values = {}
|
||||
self.max_length = 0
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def values(self):
|
||||
return (v for v in self._values.values() if self.filter(v))
|
||||
|
||||
def __delitem__(self, key):
|
||||
k = romanize(key)
|
||||
self._values.__delitem__(k)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
k = romanize(key)
|
||||
self.values[k] = value
|
||||
self.max_length = len(k[0])
|
||||
self._values[k] = value
|
||||
self.max_length = len(k)
|
||||
|
||||
def __getitem__(self, key):
|
||||
if len(key) > self.max_length * 1.1:
|
||||
self.logger.debug(f'Rejected key "{key}" due to length.')
|
||||
return None
|
||||
key, _ = romanize(key)
|
||||
result = min((k for k, v in self.values.items() if self.filter(v)),
|
||||
key=lambda v: fuzzy_match_score(key, *v, threshold=self.threshold))
|
||||
if fuzzy_match_score(key, *result, threshold=self.threshold) > self.threshold:
|
||||
key = romanize(key)
|
||||
result = min((k for k, v in self._values.items() if self.filter(v)), key=lambda k: self.matcher.score(key, k))
|
||||
if self.matcher.score(key, result) > 0:
|
||||
return None
|
||||
return self.values[result]
|
||||
return self._values[result]
|
||||
|
||||
|
||||
_insertion_weight = 0.001
|
||||
_deletion_weight = 1
|
||||
_substitution_weight = 1
|
||||
@dataclass
|
||||
class FuzzyMatchConfig:
|
||||
base_score: float = 0.0
|
||||
insertion_weight: float = 0.001
|
||||
deletion_weight: float = 1.0
|
||||
default_substitution_weight: float = 1.0
|
||||
match_weight: float = -0.2
|
||||
special_substitution_weights: Dict[Tuple[str, str], float] = field(default_factory=lambda: {
|
||||
('v', 'b'): 0.0,
|
||||
('l', 'r'): 0.0,
|
||||
})
|
||||
word_match_weight: float = -0.2
|
||||
acronym_match_weight: float = -0.3
|
||||
|
||||
|
||||
def fuzzy_match_score(source: str, target: str, words, threshold: float) -> float:
|
||||
m = len(source)
|
||||
n = len(target)
|
||||
a = [[0] * (n + 1) for _ in range(m + 1)]
|
||||
class FuzzyMatcher:
|
||||
def __init__(self, config: FuzzyMatchConfig = None):
|
||||
self.config = config or FuzzyMatchConfig()
|
||||
|
||||
for i in range(m + 1):
|
||||
a[i][0] = i
|
||||
def score(self, source: str, target: str):
|
||||
l_src = len(source)
|
||||
l_tgt = len(target)
|
||||
a: List[List[float]] = [[0] * (l_tgt + 1) for _ in range(l_src + 1)]
|
||||
|
||||
for i in range(n + 1):
|
||||
a[0][i] = i * _insertion_weight
|
||||
for i in range(l_src + 1):
|
||||
a[i][0] = i
|
||||
|
||||
def strip_vowels(s):
|
||||
return re.sub('[aeoiu]', '', s)
|
||||
for i in range(l_tgt + 1):
|
||||
a[0][i] = i * self.config.insertion_weight
|
||||
|
||||
word_match_bonus = 0.1 * max(max(sum(a == b for a, b in zip(source, w)) for w in words),
|
||||
max(sum(a == b for a, b in
|
||||
zip(source[0] + strip_vowels(source[1:]), w[0] + strip_vowels(w[1:]))) for w in
|
||||
words),
|
||||
sum(a == b for a, b in zip(source, ''.join(w[0] for w in words))))
|
||||
def strip_vowels(s):
|
||||
return re.sub('[aeoiu]', '', s)
|
||||
|
||||
for i in range(1, m + 1):
|
||||
for j in range(1, n + 1):
|
||||
a[i][j] = min(a[i - 1][j - 1] + _substitution_weight if source[i - 1] != target[j - 1] else a[i - 1][j - 1],
|
||||
a[i - 1][j] + _deletion_weight,
|
||||
a[i][j - 1] + _insertion_weight)
|
||||
if j == n and (a[i][j] - (m - i) * _insertion_weight - word_match_bonus) > threshold:
|
||||
return 9999
|
||||
words = target.split()
|
||||
word_bonus = min(self.config.word_match_weight * max(sum(a == b for a, b in zip(source, w)) for w in words),
|
||||
self.config.word_match_weight * max(sum(a == b for a, b in
|
||||
zip(source, w[0] + strip_vowels(w[1:]))) for w in
|
||||
words),
|
||||
self.config.acronym_match_weight * sum(
|
||||
a == b for a, b in zip(source, ''.join(w[0] for w in words))))
|
||||
|
||||
return a[m][n] - word_match_bonus
|
||||
def sub_weight_at(n, m):
|
||||
if source[n - 1] != target[m - 1]:
|
||||
return self.config.special_substitution_weights.get(
|
||||
(source[n - 1], target[m - 1]),
|
||||
self.config.default_substitution_weight
|
||||
)
|
||||
else:
|
||||
return self.config.match_weight
|
||||
|
||||
for i_src in range(1, l_src + 1):
|
||||
for i_tgt in range(1, l_tgt + 1):
|
||||
a[i_src][i_tgt] = min(a[i_src - 1][i_tgt - 1] + sub_weight_at(i_src, i_tgt),
|
||||
a[i_src - 1][i_tgt] + self.config.deletion_weight,
|
||||
a[i_src][i_tgt - 1] + self.config.insertion_weight)
|
||||
|
||||
# there are l_scr - i_src source chars remaining
|
||||
# each match removes the insertion weight then adds the match weight
|
||||
# (l_src - i_src) * (self.config.match_weight - self.config.insertion_weight)
|
||||
# is the max difference that can make
|
||||
max_additional_score = ((l_src - i_src) * (self.config.match_weight - self.config.insertion_weight) +
|
||||
word_bonus + self.config.base_score)
|
||||
if i_tgt == l_tgt and (
|
||||
a[i_src][i_tgt] + max_additional_score) > 0 and \
|
||||
(a[i_src][i_tgt - 1] + max_additional_score) > 0:
|
||||
return 1
|
||||
|
||||
return a[l_src][l_tgt] + word_bonus + self.config.base_score
|
||||
|
||||
|
||||
def romanize(s: str) -> Tuple[str, Tuple[str]]:
|
||||
def romanize(s: str) -> str:
|
||||
kks = pykakasi.kakasi()
|
||||
s = re.sub('[\']', '', s)
|
||||
s = re.sub('[・]', ' ', s)
|
||||
s = re.sub('[A-Za-z]+', lambda ele: f' {ele[0]} ', s)
|
||||
s = re.sub('[0-9]+', lambda ele: f' {ele[0]} ', s)
|
||||
s = ' '.join(c['hepburn'].strip().lower() for c in kks.convert(s))
|
||||
s = re.sub(r'[^a-zA-Z0-9_ ]+', '', s)
|
||||
words = tuple(s.split())
|
||||
return ''.join(words), words
|
||||
return ' '.join(s.split())
|
||||
|
Loading…
x
Reference in New Issue
Block a user