diff --git a/miyu_bot/commands/common/argument_parsing.py b/miyu_bot/commands/common/argument_parsing.py index 7c962ba..9865f2d 100644 --- a/miyu_bot/commands/common/argument_parsing.py +++ b/miyu_bot/commands/common/argument_parsing.py @@ -3,17 +3,18 @@ import re # https://stackoverflow.com/questions/249791/regex-for-quoted-string-with-escaping-quotes # https://stackoverflow.com/questions/21105360/regex-find-comma-not-inside-quotes from collections import namedtuple -from typing import Dict, List, Optional, Container, Any, Union, Callable +from typing import Dict, List, Optional, Container, Any, Union, Callable, Set, Iterable +# The ` ?` is just so it matches the space after during the replace with blank so there's no double spaces _param_re = re.compile( - r'(([a-zA-Z]+)(!=|>=|<=|>|<|==|=)(("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\s]+)(,("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\s]+))*))') -# The intention of having both = and == is that they might have different behavior -# What that means depends on the usage + r'(([a-zA-Z]+)(!=|>=|<=|>|<|==|=)(("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\s]+)(,("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\s]+))*)) ?') +# The intention of having both = and == is that they might have different behavior. +# What that means depends on the usage. _param_operator_re = re.compile(r'!=|==|=|>|<|>=|<=') _param_argument_re = re.compile(r'("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\s]+)') _param_string_re = re.compile(r'("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\')') -_param_re_with_post_space = re.compile( - r'([a-zA-Z]+)(!=|==|=|>|<|>=|<=)(("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\s]+)(,("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\s]+))*) ?') + +_tag_re = re.compile(r'(\$[^\s]+) ?') NamedArgument = namedtuple('NamedArgument', 'name operator value') ArgumentValue = namedtuple('ArgumentValue', 'value operator') @@ -30,13 +31,17 @@ def _parse_named_argument(arg): def parse_arguments(arg): named_arguments_parsed = [_parse_named_argument(na[0]) for na in _param_re.findall(arg)] - text_argument = _param_re_with_post_space.sub('', arg) + arg = _param_re.sub('', arg) + # Technically, the order (named arguments then tags) + # matters because otherwise a fake tag could appear as a value to a named argument + tags = [t[1:] for t in _tag_re.findall(arg)] + arg = _tag_re.sub('', arg) named_arguments = {} for na in named_arguments_parsed: if na.name not in named_arguments: named_arguments[na.name] = [] named_arguments[na.name].append(ArgumentValue(na.value, na.operator)) - return ParsedArguments(text_argument.strip(), named_arguments) + return ParsedArguments(arg.strip(), set(tags), named_arguments) class ArgumentError(Exception): @@ -45,12 +50,35 @@ class ArgumentError(Exception): class ParsedArguments: text_argument: str + tag_arguments: Set[str] named_arguments: Dict[str, List[ArgumentValue]] - def __init__(self, text, named_arguments): + def __init__(self, text, tags, named_arguments): self.text_argument = text + self.tag_arguments = tags self.named_arguments = named_arguments - self.used = set() + self.used_named_arguments = set() + self.used_tags = set() + + def tag(self, name: str): + if name in self.tag_arguments: + self.used_tags.add(name) + return True + return False + + def tags(self, names: Optional[Iterable[str]] = None, aliases: Optional[Dict[str, str]] = None): + results = set() + if names is not None: + for name in names: + if name in self.tag_arguments: + results.add(name) + self.used_tags.add(name) + if aliases is not None: + for alias, value in aliases.items(): + if alias in self.tag_arguments: + results.add(value) + self.used_tags.add(alias) + return results def single(self, names: Union[List[str], str], default: Any = None, allowed_operators: Optional[Container] = None, is_list=False, numeric=False, converter: Union[dict, Callable] = lambda n: n): @@ -61,7 +89,7 @@ class ParsedArguments: if not isinstance(names, list): names = [names] for name in names: - self.used.add(name) + self.used_named_arguments.add(name) name = f'{names[0]} ({", ".join(names[1:])})' if len(names) > 1 else names[0] value = [arg for args in (self.named_arguments.get(name) for name in names) if args for arg in args] if not value: @@ -90,7 +118,8 @@ class ParsedArguments: value = ArgumentValue(value.value[0], value.operator) return value - def repeatable(self, names: Union[List[str], str], default: Any = None, allowed_operators: Optional[Container] = None, + def repeatable(self, names: Union[List[str], str], default: Any = None, + allowed_operators: Optional[Container] = None, is_list=False, numeric=False, converter: Union[dict, Callable] = lambda n: n): if allowed_operators is None: allowed_operators = {'>', '<', '>=', '<=', '!=', '==', '='} @@ -101,7 +130,7 @@ class ParsedArguments: if not isinstance(names, list): names = [names] for name in names: - self.used.add(name) + self.used_named_arguments.add(name) name = f'{names[0]} ({", ".join(names[1:])})' if len(names) > 1 else names[0] values = [arg for args in (self.named_arguments.get(name) for name in names) if args for arg in args] if not values: @@ -128,15 +157,24 @@ class ParsedArguments: return values def has_unused(self): - return any(name not in self.used for name in self.named_arguments.keys()) + return self.has_unused_named_arguments() or self.has_unused_tags() + + def has_unused_named_arguments(self): + return any(name not in self.used_named_arguments for name in self.named_arguments.keys()) + + def has_unused_tags(self): + return any(t not in self.used_tags for t in self.tag_arguments) def require_all_arguments_used(self): def quote(s): return f'"{s}"' - if self.has_unused(): + if self.has_unused_named_arguments(): + raise ArgumentError( + f'Unknown arguments with names {", ".join(quote(v) for v in self.named_arguments.keys() if v not in self.used_named_arguments)}.') + if self.has_unused_tags(): raise ArgumentError( - f'Unkown arguments with names {", ".join(quote(v) for v in self.named_arguments.keys() if v not in self.used)}.') + f'Unknown tags {", ".join(quote(v) for v in self.tag_arguments if v not in self.used_tags)}.') _operators = { @@ -169,5 +207,6 @@ def list_operator_for(operator: str): if __name__ == '__main__': - a = (parse_arguments(r'sort=default rating>=13.5 a name="a",b," asf,ds ",\'sdf\',dsf')) + a = ( + parse_arguments(r'sort=default df rating>=13.5,$asd a fds $foo $bar name="a",b," asf,ds ",\'sdf\',dsf $foobar')) print(a)