add repeatable arguments to argument parser

pull/1/head
qwewqa 4 years ago
parent 3850820794
commit 3657f8bbaa
  1. 83
      miyu_bot/commands/common/argument_parsing.py

@ -7,6 +7,8 @@ from typing import Dict, List, Optional, Container, Any, Union, Callable
_param_re = re.compile( _param_re = re.compile(
r'(([a-zA-Z]+)(!=|>=|<=|>|<|==|=)(("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\s]+)(,("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\s]+))*))') r'(([a-zA-Z]+)(!=|>=|<=|>|<|==|=)(("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\s]+)(,("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\s]+))*))')
# The intention of having both = and == is that they might have different behavior
# What that means depends on the usage
_param_operator_re = re.compile(r'!=|==|=|>|<|>=|<=') _param_operator_re = re.compile(r'!=|==|=|>|<|>=|<=')
_param_argument_re = re.compile(r'("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\s]+)') _param_argument_re = re.compile(r'("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\s]+)')
_param_string_re = re.compile(r'("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\')') _param_string_re = re.compile(r'("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\')')
@ -50,15 +52,19 @@ class ParsedArguments:
self.named_arguments = named_arguments self.named_arguments = named_arguments
self.used = set() self.used = set()
def single(self, name: str, default: Any = None, allowed_operators: Optional[Container] = None, def single(self, names: Union[List[str], str], default: Any = None, allowed_operators: Optional[Container] = None,
is_list=False, numeric=False, converter: Union[dict, Callable] = lambda n: n): is_list=False, numeric=False, converter: Union[dict, Callable] = lambda n: n):
if allowed_operators is None: if allowed_operators is None:
allowed_operators = {'>', '<', '>=', '<=', '!=', '==', '='} allowed_operators = {'>', '<', '>=', '<=', '!=', '==', '='}
if not isinstance(default, tuple): if not isinstance(default, tuple) and default is not None:
default = ArgumentValue(default, '=') default = ArgumentValue(default, '=')
self.used.add(name) if not isinstance(names, list):
value = self.named_arguments.get(name) names = [names]
if value is None: for name in names:
self.used.add(name)
name = f'{names[0]} ({", ".join(names[1:])})' if len(names) > 1 else names[0]
value = [arg for args in (self.named_arguments.get(name) for name in names) if args for arg in args]
if not value:
return default return default
if len(value) != 1: if len(value) != 1:
raise ArgumentError(f'Expected only one value for parameter "{name}".') raise ArgumentError(f'Expected only one value for parameter "{name}".')
@ -84,17 +90,84 @@ class ParsedArguments:
value = ArgumentValue(value.value[0], value.operator) value = ArgumentValue(value.value[0], value.operator)
return value return value
def repeatable(self, names: Union[List[str], str], default: Any = None, allowed_operators: Optional[Container] = None,
is_list=False, numeric=False, converter: Union[dict, Callable] = lambda n: n):
if allowed_operators is None:
allowed_operators = {'>', '<', '>=', '<=', '!=', '==', '='}
if not isinstance(default, tuple) and default is not None:
default = [ArgumentValue(default, '=')]
if default is None:
default = []
if not isinstance(names, list):
names = [names]
for name in names:
self.used.add(name)
name = f'{names[0]} ({", ".join(names[1:])})' if len(names) > 1 else names[0]
values = [arg for args in (self.named_arguments.get(name) for name in names) if args for arg in args]
if not values:
return default
if any(value.operator not in allowed_operators for value in values):
raise ArgumentError(
f'Allowed operators for parameter "{name}" are {", ".join(str(o) for o in allowed_operators)}.')
if numeric:
try:
values = [ArgumentValue([float(v) for v in value.value], value.operator) for value in values]
except ValueError:
raise ArgumentError(f'Expected numerical arguments for parameter "{name}".')
try:
if isinstance(converter, dict):
values = [ArgumentValue([converter[v] for v in value.value], value.operator) for value in values]
else:
values = [ArgumentValue([converter(v) for v in value.value], value.operator) for value in values]
except Exception:
raise ArgumentError(f'Invalid value for parameter "{name}".')
if not is_list:
if any(len(value.value) != 1 for value in values):
raise ArgumentError(f'List not allowed for parameter "{name}".')
values = [ArgumentValue(value.value[0], value.operator) for value in values]
return values
def has_unused(self): def has_unused(self):
return any(name not in self.used for name in self.named_arguments.keys()) return any(name not in self.used for name in self.named_arguments.keys())
def require_all_arguments_used(self): def require_all_arguments_used(self):
def quote(s): def quote(s):
return f'"{s}"' return f'"{s}"'
if self.has_unused(): if self.has_unused():
raise ArgumentError( raise ArgumentError(
f'Unkown arguments with names {", ".join(quote(v) for v in self.named_arguments.keys() if v not in self.used)}.') f'Unkown arguments with names {", ".join(quote(v) for v in self.named_arguments.keys() if v not in self.used)}.')
_operators = {
'=': lambda a, b: a == b,
'==': lambda a, b: a == b,
'!=': lambda a, b: a != b,
'>': lambda a, b: a > b,
'<': lambda a, b: a < b,
'>=': lambda a, b: a >= b,
'<=': lambda a, b: a <= b,
}
_list_operators = {
'=': lambda a, b: any(a == v for v in b),
'==': lambda a, b: all(a == v for v in b),
'!=': lambda a, b: all(a != v for v in b),
'>': lambda a, b: all(a > v for v in b),
'<': lambda a, b: all(a < v for v in b),
'>=': lambda a, b: all(a >= v for v in b),
'<=': lambda a, b: all(a <= v for v in b),
}
def operator_for(operator: str):
return _operators[operator]
def list_operator_for(operator: str):
return _list_operators[operator]
if __name__ == '__main__': if __name__ == '__main__':
a = (parse_arguments(r'sort=default rating>=13.5 a name="a",b," asf,ds ",\'sdf\',dsf')) a = (parse_arguments(r'sort=default rating>=13.5 a name="a",b," asf,ds ",\'sdf\',dsf'))
print(a) print(a)

Loading…
Cancel
Save