Skip to content

Commit

Permalink
Make autocurry() and friends support kw-only and pos-only arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
Suor committed Mar 11, 2023
1 parent d4b9f60 commit 5dc4547
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 39 deletions.
81 changes: 48 additions & 33 deletions funcy/_inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


ARGS['builtins'] = {
'bool': 'x',
'bool': '*',
'complex': 'real,imag',
'enumerate': 'iterable,start',
'file': 'file-**',
Expand Down Expand Up @@ -91,7 +91,7 @@
}


Spec = namedtuple("Spec", "max_n names req_n req_names kw")
Spec = namedtuple("Spec", "max_n names req_n req_names varkw")


def get_spec(func, _cache={}):
Expand All @@ -108,60 +108,75 @@ def get_spec(func, _cache={}):
req_names = re.findall(r'\w+|\*', required) # a list with dups of *
max_n = len(req_names) + len(optional)
req_n = len(req_names)
spec = Spec(max_n=max_n, names=set(), req_n=req_n, req_names=set(req_names), kw=False)
spec = Spec(max_n=max_n, names=set(), req_n=req_n, req_names=set(req_names), varkw=False)
_cache[func] = spec
return spec
elif isinstance(func, type):
# Old style classes without base
if not hasattr(func, '__init__'):
return Spec(max_n=0, names=set(), req_n=0, req_names=set(), kw=False)
# __init__ inherited from builtin classes
objclass = getattr(func.__init__, '__objclass__', None)
if objclass and objclass is not func:
return get_spec(objclass)
# Introspect constructor and remove self
spec = get_spec(func.__init__)
self_set = set([func.__init__.__code__.co_varnames[0]])
self_set = {func.__init__.__code__.co_varnames[0]}
return spec._replace(max_n=spec.max_n - 1, names=spec.names - self_set,
req_n=spec.req_n - 1, req_names=spec.req_names - self_set)
elif hasattr(func, '__code__'):
return _code_to_spec(func)
else:
# We use signature last to be fully backwards compatible. Also it's slower
try:
defaults_n = len(func.__defaults__)
except (AttributeError, TypeError):
defaults_n = 0
try:
varnames = func.__code__.co_varnames
n = func.__code__.co_argcount
names = set(varnames[:n])
req_n = n - defaults_n
req_names = set(varnames[:req_n])
kw = bool(func.__code__.co_flags & CO_VARKEYWORDS)
# If there are varargs they could be required, but all keywords args can't be
max_n = req_n + 1 if func.__code__.co_flags & CO_VARARGS else n
return Spec(max_n=max_n, names=names, req_n=req_n, req_names=req_names, kw=kw)
except AttributeError:
# We use signature last to be fully backwards compatible. Also it's slower
try:
sig = signature(func)
except (ValueError, TypeError):
raise ValueError('Unable to introspect %s() arguments'
% (getattr(func, '__qualname__', None) or getattr(func, '__name__', func)))
else:
spec = _cache[func] = _sig_to_spec(sig)
return spec
sig = signature(func)
# import ipdb; ipdb.set_trace()
except (ValueError, TypeError):
raise ValueError('Unable to introspect %s() arguments'
% (getattr(func, '__qualname__', None) or getattr(func, '__name__', func)))
else:
spec = _cache[func] = _sig_to_spec(sig)
return spec


def _code_to_spec(func):
code = func.__code__

# Weird function like objects
defaults = getattr(func, '__defaults__', None)
defaults_n = len(defaults) if isinstance(defaults, tuple) else 0

kwdefaults = getattr(func, '__kwdefaults__', None)
if not isinstance(kwdefaults, dict):
kwdefaults = {}

# Python 3.7 and earlier does not have this
posonly_n = getattr(code, 'co_posonlyargcount', 0)

varnames = code.co_varnames
pos_n = code.co_argcount
n = pos_n + code.co_kwonlyargcount
names = set(varnames[posonly_n:n])
req_n = n - defaults_n - len(kwdefaults)
req_names = set(varnames[posonly_n:pos_n - defaults_n] + varnames[pos_n:n]) - set(kwdefaults)
varkw = bool(code.co_flags & CO_VARKEYWORDS)
# If there are varargs they could be required
max_n = n + 1 if code.co_flags & CO_VARARGS else n
return Spec(max_n=max_n, names=names, req_n=req_n, req_names=req_names, varkw=varkw)


def _sig_to_spec(sig):
max_n, names, req_n, req_names, kw = 0, set(), 0, set(), False
max_n, names, req_n, req_names, varkw = 0, set(), 0, set(), False
for name, param in sig.parameters.items():
max_n += 1
if param.kind == param.VAR_KEYWORD:
kw = True
max_n -= 1
varkw = True
elif param.kind == param.VAR_POSITIONAL:
req_n += 1
elif param.kind == param.POSITIONAL_ONLY:
if param.default is param.empty:
req_n += 1
else:
names.add(name)
if param.default is param.empty:
req_n += 1
req_names.add(name)
return Spec(max_n=max_n, names=names, req_n=req_n, req_names=req_names, kw=kw)
return Spec(max_n=max_n, names=names, req_n=req_n, req_names=req_names, varkw=varkw)
3 changes: 1 addition & 2 deletions funcy/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def rcurry(func, n=EMPTY):
return lambda x: rcurry(rpartial(func, x), n - 1)


# TODO: drop `n` in next major release
def autocurry(func, n=EMPTY, _spec=None, _args=(), _kwargs={}):
"""Creates a version of func returning its partial applications
until sufficient arguments are passed."""
Expand All @@ -76,7 +75,7 @@ def autocurried(*a, **kw):
kwargs = _kwargs.copy()
kwargs.update(kw)

if not spec.kw and len(args) + len(kwargs) >= spec.max_n:
if not spec.varkw and len(args) + len(kwargs) >= spec.max_n:
return func(*args, **kwargs)
elif len(args) + len(set(kwargs) & spec.names) >= spec.max_n:
return func(*args, **kwargs)
Expand Down
16 changes: 16 additions & 0 deletions tests/py38_funcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest
from funcy.funcs import autocurry


def test_autocurry_posonly():
at = autocurry(lambda a, /, b: (a, b))
assert at(1)(b=2) == (1, 2)
assert at(b=2)(1) == (1, 2)
with pytest.raises(TypeError): at(a=1)(b=2)

at = autocurry(lambda a, /, **kw: (a, kw))
assert at(a=2)(1) == (1, {'a': 2})

at = autocurry(lambda a=1, /, *, b: (a, b))
assert at(b=2) == (1, 2)
assert at(0)(b=3) == (0, 3)
32 changes: 28 additions & 4 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from operator import __add__, __sub__
import sys
import pytest
from whatever import _

Expand Down Expand Up @@ -94,12 +95,37 @@ def test_autocurry_kwargs():
assert at(b=3, c=9)(1) == (1, 3, 9)
with pytest.raises(TypeError): at(b=2, d=3, e=4)(a=1, c=1)


def test_autocurry_kwonly():
at = autocurry(lambda a, *, b: (a, b))
assert at(1, b=2) == (1, 2)
assert at(1)(b=2) == (1, 2)
assert at(b=2)(1) == (1, 2)

at = autocurry(lambda a, *, b=10: (a, b))
assert at(1) == (1, 10)
assert at(b=2)(1) == (1, 2)

at = autocurry(lambda a=1, *, b: (a, b))
assert at(b=2) == (1, 2)
assert at(0)(b=2) == (0, 2)

at = autocurry(lambda *, a=1, b: (a, b))
assert at(b=2) == (1, 2)
assert at(a=0)(b=2) == (0, 2)

# TODO: move this here once we drop Python 3.7
if sys.version_info >= (3, 8):
pytest.register_assert_rewrite("tests.py38_funcs")
from .py38_funcs import test_autocurry_posonly # noqa


def test_autocurry_builtin():
assert autocurry(complex)(imag=1)(0) == 1j
assert autocurry(lmap)(_ + 1)([1, 2]) == [2, 3]
assert autocurry(int)(base=12)('100') == 144
# Only works in newer Pythons, relies on inspect.signature()
# assert autocurry(str.split)(sep='_')('a_1') == ['a', '1']
if sys.version_info >= (3, 7):
assert autocurry(str.split)(sep='_')('a_1') == ['a', '1']

def test_autocurry_hard():
def required_star(f, *seqs):
Expand Down Expand Up @@ -130,15 +156,13 @@ class I(int): pass
assert autocurry(int)(base=12)('100') == 144

def test_autocurry_docstring():

@autocurry
def f(a, b):
'docstring'

assert f.__doc__ == 'docstring'



def test_compose():
double = _ * 2
inc = _ + 1
Expand Down

0 comments on commit 5dc4547

Please sign in to comment.