From 98d6e374056ee75f3f69099d0e644d835e003c42 Mon Sep 17 00:00:00 2001
From: Zhufeng Pan
Date: Thu, 24 Aug 2023 10:48:54 -0700
Subject: [PATCH] Add initial support for positional args within
`fdl.Config`.
To access positional args:
```python
v = config[:] # the full list
v = config[-1] # normal index
v = config[:3] # slice index
```
To modify positional args:
```python
config[:] = [1, 2] # assign to a new list
config[1] = 2
```
PiperOrigin-RevId: 559802927
---
fiddle/__init__.py | 1 +
fiddle/_src/building.py | 42 +++-
fiddle/_src/config.py | 277 ++++++++++++++++++++++---
fiddle/_src/config_test.py | 214 +++++++++++++++++--
fiddle/_src/signatures.py | 72 +++++--
fiddle/config.py | 1 +
fiddle/examples/colabs/basic_api.ipynb | 45 +---
7 files changed, 525 insertions(+), 127 deletions(-)
diff --git a/fiddle/__init__.py b/fiddle/__init__.py
index 33146324..0ab98d0f 100644
--- a/fiddle/__init__.py
+++ b/fiddle/__init__.py
@@ -26,6 +26,7 @@
from fiddle._src.config import NO_VALUE
from fiddle._src.config import ordered_arguments
from fiddle._src.config import update_callable
+from fiddle._src.config import VARARGS
from fiddle._src.materialize import materialize_defaults
from fiddle._src.partial import ArgFactory
from fiddle._src.partial import Partial
diff --git a/fiddle/_src/building.py b/fiddle/_src/building.py
index 6212a6fb..9324bdd4 100644
--- a/fiddle/_src/building.py
+++ b/fiddle/_src/building.py
@@ -19,7 +19,7 @@
import functools
import logging
import threading
-from typing import Any, Callable, Dict, TypeVar, overload
+from typing import Any, Callable, Dict, Sequence, TypeVar, overload
from fiddle._src import config as config_lib
from fiddle._src import daglish
@@ -60,8 +60,12 @@ def _format_arg(arg: Any) -> str:
return f''
-def _make_message(current_path: daglish.Path, buildable: config_lib.Buildable,
- arguments: Dict[str, Any]) -> str:
+def _make_message(
+ current_path: daglish.Path,
+ buildable: config_lib.Buildable,
+ args: Sequence[Any],
+ kwargs: Dict[str, Any],
+) -> str:
"""Returns Fiddle-related debugging information for an exception."""
path_str = '' + daglish.path_str(current_path)
fn_or_cls = config_lib.get_callable(buildable)
@@ -69,11 +73,15 @@ def _make_message(current_path: daglish.Path, buildable: config_lib.Buildable,
fn_or_cls_name = fn_or_cls.__qualname__
except AttributeError:
fn_or_cls_name = str(fn_or_cls) # callable instances, etc.
+ args_str = ', '.join(f'{_format_arg(value)}' for value in args)
kwargs_str = ', '.join(
- f'{name}={_format_arg(value)}' for name, value in arguments.items())
+ f'{name}={_format_arg(value)}' for name, value in kwargs.items()
+ )
tag_information = ''
- bound_args = buildable.__signature_info__.signature.bind_partial(**arguments)
+ bound_args = buildable.__signature_info__.signature.bind_partial(
+ *args, **kwargs
+ )
bound_args.apply_defaults()
unset_arg_tags = []
for param in buildable.__signature_info__.parameters:
@@ -90,7 +98,8 @@ def _make_message(current_path: daglish.Path, buildable: config_lib.Buildable,
return (
'\n\nFiddle context: failed to construct or call '
f'{fn_or_cls_name} at {path_str} '
- f'with arguments ({kwargs_str}){tag_information}'
+ f'with positional arguments: ({args_str}), '
+ f'keyword arguments: ({kwargs_str}){tag_information}.'
)
@@ -100,10 +109,25 @@ def call_buildable(
*,
current_path: daglish.Path,
) -> Any:
- make_message = functools.partial(_make_message, current_path, buildable,
- arguments)
+ """Prepare positional arguments and actually build the buildable."""
+ positional_only, keyword_or_positional, var_positional = (
+ buildable.__signature_info__.get_positional_names()
+ )
+ positional_arguments = []
+ for name in positional_only:
+ if name in arguments:
+ positional_arguments.append(arguments.pop(name))
+ if var_positional is not None:
+ for name in keyword_or_positional:
+ if name in arguments:
+ positional_arguments.append(arguments.pop(name))
+ if var_positional in arguments:
+ positional_arguments.extend(arguments.pop(var_positional))
+ make_message = functools.partial(
+ _make_message, current_path, buildable, positional_arguments, arguments
+ )
with reraised_exception.try_with_lazy_message(make_message):
- return buildable.__build__(**arguments)
+ return buildable.__build__(*positional_arguments, **arguments)
# Define typing overload for `build(Partial[T])`
diff --git a/fiddle/_src/config.py b/fiddle/_src/config.py
index 44d2b6d6..166b4a29 100644
--- a/fiddle/_src/config.py
+++ b/fiddle/_src/config.py
@@ -22,6 +22,7 @@
import copy
import dataclasses
import functools
+import inspect
import types
from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union
@@ -58,6 +59,9 @@ def __copy__(self):
# None or other commonly-used sentinel.
_UNSET_SENTINEL = object()
+# Unique object instance that represents the index where variadic positional
+# arguments start for a Buildable.
+VARARGS = object()
_defaults_aware_traverser_registry = daglish.NodeTraverserRegistry(
use_fallback=True
@@ -247,7 +251,7 @@ def __init__(
)
for name, value in arguments.items():
- setattr(self, name, value)
+ self._setattr(name, value, allow_postional_argument=True)
for name, tags in tag_type.find_tags_from_annotations(fn_or_cls).items():
self.__argument_tags__[name].update(tags)
@@ -258,6 +262,7 @@ def __init__(
def __init_callable__(
self, fn_or_cls: Union['Buildable[T]', TypeOrCallableProducingT[T]]
) -> None:
+ """Save information on `fn_or_cls` to the `Buildable`."""
if isinstance(fn_or_cls, Buildable):
raise ValueError(
'Using the Buildable constructor to convert a buildable to a new '
@@ -273,9 +278,11 @@ def __init_callable__(
super().__setattr__('__fn_or_cls__', fn_or_cls)
super().__setattr__('__arguments__', {})
signature = signatures.get_signature(fn_or_cls)
+ # Several attributes are computed automatically by SignatureInfo during
+ # `__post_init__`.
super().__setattr__(
'__signature_info__',
- signatures.SignatureInfo(signature),
+ signatures.SignatureInfo(signature=signature),
)
def __init_subclass__(cls):
@@ -312,6 +319,14 @@ def __path_elements__(self) -> Tuple[daglish.Attr]:
def __getattr__(self, name: str):
"""Get parameter with given ``name``."""
value = self.__arguments__.get(name, _UNSET_SENTINEL)
+ param = self.__signature_info__.parameters.get(name)
+ if param is not None and (
+ param.kind in (param.POSITIONAL_ONLY, param.VAR_POSITIONAL)
+ ):
+ raise AttributeError(
+ 'Cannot access positional-only or variadic positional arguments '
+ f'{name} on {self!r} by attributes.'
+ )
if value is not _UNSET_SENTINEL:
return value
@@ -323,7 +338,6 @@ def __getattr__(self, name: str):
+ f'{self.__fn_or_cls__.__qualname__}.{name} '
+ 'since it uses a default_factory.'
)
- param = self.__signature_info__.parameters.get(name)
if param is not None and param.default is not param.empty:
return param.default
msg = f"No parameter '{name}' has been set on {self!r}."
@@ -340,10 +354,15 @@ def __getattr__(self, name: str):
)
raise AttributeError(msg)
- def __setattr__(self, name: str, value: Any):
- """Sets parameter ``name`` to ``value``."""
-
- self.__signature_info__.validate_param_name(name, self.__fn_or_cls__)
+ def _setattr(
+ self, name: str, value: Any, allow_postional_argument: bool = False
+ ):
+ """The __setattr__ implementation."""
+ self.__signature_info__.validate_param_name(
+ name,
+ self.__fn_or_cls__,
+ allow_postional_argument=allow_postional_argument,
+ )
if isinstance(value, TaggedValueCls):
tags = value.__argument_tags__.get('value', ())
@@ -360,6 +379,10 @@ def __setattr__(self, name: str, value: Any):
self.__arguments__[name] = value
self.__argument_history__.add_new_value(name, value)
+  def __setattr__(self, name: str, value: Any):
+    """Sets parameter ``name`` to ``value`` via attribute assignment.
+
+    Positional arguments are not allowed through this path; positional-only
+    and ``*args`` parameters are set with item assignment (``config[i] = v``).
+    """
+    self._setattr(name, value, allow_postional_argument=False)
+
def __delattr__(self, name):
"""Unsets parameter ``name``."""
try:
@@ -369,6 +392,90 @@ def __delattr__(self, name):
err = AttributeError(f"No parameter '{name}' has been set on {self!r}")
raise err from None
+  def _get_all_positional_args(self):
+    """Returns the values of the arguments that are passed positionally.
+
+    The list holds the positional-only arguments, followed (only when the
+    callable takes ``*args``) by the positional-or-keyword arguments and the
+    ``*args`` values. It stops at the first named parameter that has not been
+    set (e.g. a positional-or-keyword parameter left at its default) instead of
+    raising ``KeyError``. A new list is built on every call.
+    """
+    positional_only, keyword_or_positional, var_positional = (
+        self.__signature_info__.get_positional_names()
+    )
+    names = positional_only
+    if var_positional:
+      names = names + keyword_or_positional
+    positional_arguments = []
+    for name in names:
+      if name not in self.__arguments__:
+        # Later values could not be passed positionally without this one.
+        return positional_arguments
+      positional_arguments.append(self.__arguments__[name])
+    if var_positional:
+      positional_arguments.extend(self.__arguments__.get(var_positional, []))
+    return positional_arguments
+
+  def _replace_varargs_handle(self, key):
+    """Replaces the ``VARARGS`` handle in an index or slice with its position.
+
+    ``VARARGS`` stands for the index of the first variadic positional argument
+    in the flat positional-argument list, so ``config[VARARGS]``,
+    ``config[VARARGS:]`` and ``config[:VARARGS]`` are all supported. Any other
+    key is returned unchanged (slices are rebuilt with equal components).
+    """
+    positional_only, keyword_or_positional, var_positional = (
+        self.__signature_info__.get_positional_names()
+    )
+    # Positional-or-keyword arguments are only part of the positional list
+    # when the callable takes *args (see `_get_all_positional_args`).
+    start = len(positional_only)
+    if var_positional:
+      start += len(keyword_or_positional)
+    if isinstance(key, slice):
+      return slice(
+          start if key.start is VARARGS else key.start,
+          start if key.stop is VARARGS else key.stop,
+          key.step,
+      )
+    return start if key is VARARGS else key
+
+  def __getitem__(self, key: Any):
+    """Returns the positional argument(s) at ``key`` (``VARARGS`` allowed).
+
+    The lookup is done on a freshly built list, so a slice result can be
+    mutated without affecting this Buildable.
+    """
+    index = self._replace_varargs_handle(key)
+    return self._get_all_positional_args()[index]
+
+  def __setitem__(self, key: Any, value: Any):
+    """Sets positional arguments by index or slice, like ``list.__setitem__``.
+
+    ``key`` may use the ``VARARGS`` handle. The assignment is applied to the
+    flat list of positional arguments, and every entry that changed is written
+    back to its named parameter (or to ``*args``) with a history entry.
+
+    Raises:
+      TypeError: If ``key`` is neither an int nor a slice.
+      IndexError: If ``key`` is an int outside the positional arguments.
+      ValueError: If the result has more values than the callable accepts
+        positionally; ``self`` is left unchanged in that case.
+    """
+    key = self._replace_varargs_handle(key)
+    if not isinstance(key, (int, slice)):
+      raise TypeError(f'Key of __setitem__ must be an int or slice, got {key}.')
+    positional_only, keyword_or_positional, var_positional = (
+        self.__signature_info__.get_positional_names()
+    )
+    positional_names = positional_only
+    if var_positional:
+      positional_names = positional_names + keyword_or_positional
+    old_positional_args = self._get_all_positional_args()
+    # Apply the assignment to a shallow copy and diff it against the original.
+    # Untouched entries stay the very same objects, so an identity check finds
+    # exactly what changed, without deep-copying nested Buildables or calling
+    # `__eq__` on arbitrary argument values.
+    new_positional_args = list(old_positional_args)
+    new_positional_args[key] = value
+
+    # Validate before mutating anything so a failure leaves `self` untouched.
+    if var_positional is None and (
+        len(new_positional_args) > len(positional_names)
+    ):
+      raise ValueError(
+          'Too many arguments are provided. There are only '
+          f'{len(positional_names)} positional arguments but '
+          f'{len(new_positional_args)} are provided to '
+          f'{self.__fn_or_cls__.__qualname__}.'
+      )
+
+    # Handle non-variadic positional arguments.
+    for index, name in enumerate(positional_names):
+      if index < len(new_positional_args):
+        new_value = new_positional_args[index]
+        # The list may have grown past the previously set arguments.
+        if (
+            index >= len(old_positional_args)
+            or old_positional_args[index] is not new_value
+        ):
+          self.__arguments__[name] = new_value
+          self.__argument_history__.add_new_value(name, new_value)
+      elif name in self.__arguments__:
+        del self.__arguments__[name]
+        self.__argument_history__.add_deleted_value(name)
+
+    # Handle variadic positional arguments.
+    if var_positional is None:
+      return
+    if len(new_positional_args) <= len(positional_names):
+      if var_positional in self.__arguments__:
+        del self.__arguments__[var_positional]
+        self.__argument_history__.add_deleted_value(var_positional)
+      return
+    new_var_positional_arg = new_positional_args[len(positional_names) :]
+    if new_var_positional_arg != self.__arguments__.get(var_positional, []):
+      self.__arguments__[var_positional] = new_var_positional_arg
+      self.__argument_history__.add_new_value(
+          var_positional, new_var_positional_arg
+      )
+
def __dir__(self) -> Collection[str]:
"""Provide a useful list of attribute names, optimized for Jupyter/Colab.
@@ -488,9 +595,7 @@ def __getstate__(self):
Dict of serialized state.
"""
result = dict(self.__dict__)
- result['__signature_info__'] = signatures.SignatureInfo( # pytype: disable=wrong-arg-types
- None, result['__signature_info__'].has_var_keyword
- )
+ result['__signature_info__'] = signatures.SignatureInfo(None) # pytype: disable=wrong-arg-types
return result
def __setstate__(self, state) -> None:
@@ -503,8 +608,10 @@ def __setstate__(self, state) -> None:
"""
self.__dict__.update(state) # Support unpickle.
if self.__signature_info__.signature is None:
- self.__signature_info__.signature = signatures.get_signature(
- self.__fn_or_cls__
+ signature = signatures.get_signature(self.__fn_or_cls__)
+ super().__setattr__(
+ '__signature_info__',
+ signatures.SignatureInfo(signature=signature),
)
@@ -637,11 +744,123 @@ def _field_uses_default_factory(dataclass_type: Type[Any], field_name: str):
return False
+def _update_positional_args(
+    buildable: Buildable,
+    original_signature: inspect.Signature,
+    new_signature: inspect.Signature,
+    drop_invalid_args: bool = False,
+) -> None:
+  """Remaps the positional arguments of ``buildable`` onto a new signature.
+
+  Mapping positional arguments one by one is hard: positional parameters may
+  be renamed, and ``*args`` values may need to move onto concrete positional
+  parameters (or vice versa), depending on whether the original and the new
+  callable accept variadic positional arguments.
+
+  Instead, this builds a flat list of the original positional argument values
+  and assigns them, by index, to the positional parameters (and ``*args``) of
+  the new callable. ``buildable.__arguments__`` and its argument history are
+  updated in place.
+
+  Args:
+    buildable: A ``Buildable`` (e.g. a ``fdl.Config``) to update.
+    original_signature: Signature of the original callable.
+    new_signature: Signature of the new callable.
+    drop_invalid_args: Whether positional arguments that cannot be mapped
+      should be dropped. Dropping is not implemented yet, so this only selects
+      which error is raised.
+
+  Raises:
+    TypeError: If the new callable has no ``*args`` and fewer positional
+      parameters than there are positional values, and ``drop_invalid_args``
+      is False. Nothing is modified in that case.
+    NotImplementedError: In the same situation when ``drop_invalid_args`` is
+      True.
+  """
+  positional_argument_names = []
+  positional_argument_values = []
+  keyword_or_positional_argument_names = []
+  keyword_or_positional_argument_values = []
+  var_positional_name = None
+
+  # Flatten the original positional arguments into parallel name/value lists.
+  # Values that came from *args get the name None.
+  for param in original_signature.parameters.values():
+    if param.name not in buildable.__arguments__:
+      break
+    value = buildable.__arguments__[param.name]
+    if param.kind == param.POSITIONAL_ONLY:
+      positional_argument_names.append(param.name)
+      positional_argument_values.append(value)
+    elif param.kind == param.POSITIONAL_OR_KEYWORD:
+      keyword_or_positional_argument_names.append(param.name)
+      keyword_or_positional_argument_values.append(value)
+    elif param.kind == param.VAR_POSITIONAL:
+      var_positional_name = param.name
+      if value:
+        # If *args is non-empty, the positional-or-keyword arguments before it
+        # are passed positionally as well.
+        positional_argument_names.extend(keyword_or_positional_argument_names)
+        positional_argument_values.extend(keyword_or_positional_argument_values)
+        positional_argument_names.extend([None for _ in value])
+        positional_argument_values.extend(value)
+      break
+
+  # Check before mutating anything: unless the new callable takes *args, each
+  # positional value needs its own positional parameter. Extra values would
+  # otherwise land on keyword-only parameters / **kwargs, or be dropped.
+  new_params = list(new_signature.parameters.values())
+  num_positional_slots = sum(
+      p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) for p in new_params
+  )
+  if len(positional_argument_values) > num_positional_slots and not any(
+      p.kind == p.VAR_POSITIONAL for p in new_params
+  ):
+    if drop_invalid_args:
+      raise NotImplementedError(
+          'Drop invalid positional arguments are not supported yet.'
+      )
+    raise TypeError(
+        f'Fail to match buildable {buildable} from signature '
+        f'{original_signature} to {new_signature}.'
+    )
+
+  new_param_names = [p.name for p in new_params]
+  for index, param in enumerate(new_params):
+    if index >= len(positional_argument_values):
+      break
+    new_name = param.name
+    old_name = positional_argument_names[index]
+    # Names (re)assigned by earlier iterations; these must not be deleted even
+    # if they also appear as old names (e.g. reordered or renamed params).
+    assigned_names = new_param_names[:index]
+    if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
+      new_value = positional_argument_values[index]
+      if old_name == new_name and (
+          buildable.__arguments__[old_name] is new_value
+      ):
+        continue
+      buildable.__arguments__[new_name] = new_value
+      buildable.__argument_history__.add_new_value(new_name, new_value)
+      if (
+          old_name is not None
+          and old_name != new_name
+          and old_name not in assigned_names
+      ):
+        del buildable.__arguments__[old_name]
+        buildable.__argument_history__.add_deleted_value(old_name)
+      # The old *args values are being redistributed, so drop the old entry.
+      if var_positional_name in buildable.__arguments__:
+        del buildable.__arguments__[var_positional_name]
+        buildable.__argument_history__.add_deleted_value(var_positional_name)
+    elif param.kind == param.VAR_POSITIONAL:
+      # All remaining positional values are matched to *args, so delete the
+      # named positional arguments they came from.
+      new_value = positional_argument_values[index:]
+      for name in positional_argument_names[index:]:
+        if name and name in buildable.__arguments__ and (
+            name not in assigned_names
+        ):
+          del buildable.__arguments__[name]
+          buildable.__argument_history__.add_deleted_value(name)
+      if var_positional_name and var_positional_name in buildable.__arguments__:
+        if new_name != var_positional_name:
+          del buildable.__arguments__[var_positional_name]
+          buildable.__argument_history__.add_deleted_value(var_positional_name)
+        elif new_value == buildable.__arguments__[var_positional_name]:
+          break  # Same *args name and values: nothing to update.
+      buildable.__arguments__[new_name] = new_value
+      buildable.__argument_history__.add_new_value(new_name, new_value)
+      break  # Every remaining positional value went into *args.
+
+
def update_callable(
buildable: Buildable,
new_callable: TypeOrCallableProducingT,
drop_invalid_args: bool = False,
-):
+) -> None:
"""Updates ``config`` to build ``new_callable`` instead.
When extending a base configuration, it can often be useful to swap one class
@@ -666,24 +885,26 @@ def update_callable(
#
# Note: can't call `setattr` on all the args to validate them, because that
# will result in duplicate history entries.
- original_args = buildable.__arguments__
- signature = signatures.get_signature(new_callable)
- if any(
- param.kind == param.VAR_POSITIONAL
- for param in signature.parameters.values()
- ):
- raise NotImplementedError(
- 'Variable positional arguments (aka `*args`) not supported.'
- )
- signature_info = signatures.SignatureInfo(signature)
- object.__setattr__(
+ new_signature = signatures.get_signature(new_callable)
+ # Update the signature early so that we can set arguments by position.
+ # Otherwise, parameter validation would complain that the argument name
+ # does not exist.
+ object.__setattr__(buildable, '__signature__', new_signature)
+ new_signature_info = signatures.SignatureInfo(signature=new_signature)
+ original_signature_info = buildable.__signature_info__
+ object.__setattr__(buildable, '__signature_info__', new_signature_info)
+ _update_positional_args(
buildable,
- '__signature_info__',
- signature_info,
+ original_signature_info.signature,
+ new_signature_info.signature,
+ drop_invalid_args,
)
- if not signature_info.has_var_keyword:
+
+ if not new_signature_info.has_var_keyword:
invalid_args = [
- arg for arg in original_args.keys() if arg not in signature.parameters
+ arg
+ for arg in buildable.__arguments__.keys()
+ if arg not in new_signature.parameters
]
if invalid_args:
if drop_invalid_args:
diff --git a/fiddle/_src/config_test.py b/fiddle/_src/config_test.py
index d03e21da..39e3e5cb 100644
--- a/fiddle/_src/config_test.py
+++ b/fiddle/_src/config_test.py
@@ -72,6 +72,14 @@ def fn_with_var_args_and_kwargs(arg1, *args, kwarg1=None, **kwargs): # pylint:
return locals()
+def fn_with_args_and_kwargs_only(*args, **kwargs):
+ return args, kwargs
+
+
+def fn_with_position_args(a, b, /, c=1, *args): # pylint: disable=keyword-arg-before-vararg, unused-argument
+ return locals()
+
+
def make_typed_config() -> fdl.Config[SampleClass]:
"""Helper function which returns a fdl.Config whose type is known."""
return fdl.Config(SampleClass, arg1=1, arg2=2)
@@ -195,14 +203,109 @@ def test_config_for_functions_with_var_args_and_kwargs(self):
fn_args = fdl.build(fn_config)
self.assertEqual(fn_args['arg1'], 'arg1')
- fn_config.args = 'kwarg_called_arg'
fn_config.kwargs = 'kwarg_called_kwarg'
fn_args = fdl.build(fn_config)
self.assertEqual(fn_args['kwargs'], {
- 'args': 'kwarg_called_arg',
'kwargs': 'kwarg_called_kwarg'
})
+ def test_postional_args_access(self):
+ fn_config = fdl.Config(fn_with_position_args, 1, 2, 3, 4, 5)
+ self.assertEqual(fn_config[0], 1)
+ self.assertEqual(fn_config[-1], 5)
+ self.assertSequenceEqual(fn_config[3:], [4, 5])
+ self.assertSequenceEqual(fn_config[:], [1, 2, 3, 4, 5])
+
+ def test_positional_args_modification(self):
+ fn_config = fdl.Config(fn_with_position_args, 1, 2, 3, 4, 5)
+ fn_config[0] = 0
+ self.assertSequenceEqual(fn_config[:], [0, 2, 3, 4, 5])
+ fn_config[:3] = [5, 6, 7]
+ self.assertSequenceEqual(fn_config[:], [5, 6, 7, 4, 5])
+ fn_config[3:] = [5, 6, 7]
+ self.assertSequenceEqual(fn_config[:], [5, 6, 7, 5, 6, 7])
+ fn_config[:] = [1, 2, 3]
+ self.assertSequenceEqual(fn_config[:], [1, 2, 3])
+ fn_config[:] += [4, 5]
+ self.assertSequenceEqual(fn_config[:], [1, 2, 3, 4, 5])
+
+ def test_varargs_index_handle(self):
+ fn_config = fdl.Config(fn_with_position_args, 1, 2, 3, 4, 5)
+ with self.subTest('access'):
+ self.assertEqual(fn_config[fdl.VARARGS], 4)
+ self.assertSequenceEqual(fn_config[fdl.VARARGS :], [4, 5])
+ with self.subTest('modify'):
+ fn_config[fdl.VARARGS :] = []
+ self.assertSequenceEqual(fn_config[:], [1, 2, 3])
+ fn_config[fdl.VARARGS :] = [7, 8, 9]
+ self.assertSequenceEqual(fn_config[:], [1, 2, 3, 7, 8, 9])
+ fn_config[fdl.VARARGS] = 0
+ self.assertSequenceEqual(fn_config[:], [1, 2, 3, 0, 8, 9])
+
+ def test_modification_when_var_args_are_empty(self):
+ fn_config = fdl.Config(fn_with_position_args, 1, 2, 3)
+ self.assertEmpty(fn_config[fdl.VARARGS :])
+ fn_config[:] = ['a', 'b', 'c', 'd', 'e']
+ self.assertSequenceEqual(fn_config[:], ['a', 'b', 'c', 'd', 'e'])
+
+ def test_positional_args_direct_access_is_forbidden(self):
+ fn_config = fdl.Config(fn_with_position_args, 1, 2, 3, 4, 5)
+ with self.assertRaisesRegex(
+ AttributeError,
+ 'Cannot access positional-only or variadic positional arguments',
+ ):
+ _ = fn_config.args
+
+ with self.assertRaisesRegex(
+ AttributeError,
+ 'Cannot access positional-only or variadic positional arguments',
+ ):
+ _ = fn_config.a
+
+ def test_positional_args_direct_modification_is_forbidden(self):
+ fn_config = fdl.Config(fn_with_position_args, 1, 2, 3, 4, 5)
+ with self.assertRaisesRegex(
+ AttributeError, 'Cannot access VAR_POSITIONAL parameter'
+ ):
+ fn_config.args = [0]
+
+ with self.assertRaisesRegex(
+ AttributeError, 'Cannot access POSITIONAL_ONLY parameter'
+ ):
+ fn_config.a = 0
+
+ def test_positional_or_keyword_args_have_consistent_values(self):
+ fn_config = fdl.Config(fn_with_position_args, 1, 2, 3, 4, 5)
+ fn_config[2] = 'arg-c'
+ self.assertEqual(fn_config.c, 'arg-c')
+ fn_config.c = 'index-2'
+ self.assertEqual(fn_config[2], 'index-2')
+
+ def test_index_out_of_range(self):
+ fn_config = fdl.Config(fn_with_var_args, 1, 2)
+ self.assertLen(fn_config[:], 2)
+ with self.assertRaisesRegex(
+ IndexError, 'list assignment index out of range'
+ ):
+ fn_config[2] = 'index-2'
+ with self.assertRaisesRegex(IndexError, 'list index out of range'):
+ _ = fn_config[2]
+
+ def test_args_config_shallow_copy(self):
+ fn_config = fdl.Config(fn_with_var_args, 1, 2)
+ self.assertLen(fn_config[:], 2)
+ a_copy = fn_config[:]
+ a_copy.append('3')
+ self.assertLen(fn_config[:], 2)
+ self.assertLen(a_copy, 3)
+
+ def test_args_config_build(self):
+ fn_config = fdl.Config(fn_with_position_args, 1, 2, 3, 4, 5)
+ self.assertEqual(
+ fdl.build(fn_config),
+ {'a': 1, 'b': 2, 'c': 3, 'args': (4, 5)},
+ )
+
def test_config_for_dicts(self):
dict_config = fdl.Config(dict, a=1, b=2)
dict_config.c = 3
@@ -616,18 +719,6 @@ def test_nonexistent_parameter_error(self):
with self.assertRaisesRegex(TypeError, expected_msg):
class_config.nonexistent_arg = 'error!'
- def test_nonexistent_var_args_parameter_error(self):
- fn_config = fdl.Config(fn_with_var_args)
- expected_msg = (r'Variadic arguments \(e.g. \*args\) are not supported\.')
- with self.assertRaisesRegex(TypeError, expected_msg):
- fn_config.args = (1, 2, 3)
-
- def test_unsupported_var_args_error(self):
- expected_msg = (r'Variable positional arguments \(aka `\*args`\) not '
- r'supported\.')
- with self.assertRaisesRegex(NotImplementedError, expected_msg):
- fdl.Config(fn_with_var_args, 1, 2, 3)
-
def test_build_inside_build(self):
def inner_build(x: int) -> str:
@@ -737,15 +828,21 @@ def test_build_nested_structure(self):
def test_build_raises_nice_error_too_few_args(self):
cfg = fdl.Config(basic_fn, fdl.Config(SampleClass, 1), 2)
- with self.assertRaisesRegex(
- TypeError, r'.*missing 1 required.*\n\n.*\.arg1.*arg1=1'):
+ with self.assertRaises(TypeError) as e:
fdl.build(cfg)
+ self.assertEqual(
+ e.exception.proxy_message, # pytype: disable=attribute-error
+ '\n\nFiddle context: failed to construct or call SampleClass at '
+ '.arg1 with positional arguments: (), keyword arguments: '
+ '(arg1=1).',
+ )
def test_build_raises_exception_on_call(self):
cfg = fdl.Config(raise_error)
msg = (
'My fancy exception\n\nFiddle context: failed to construct or call '
- 'raise_error at with arguments ()'
+ 'raise_error at with positional arguments: (), '
+ 'keyword arguments: ().'
)
with self.assertRaisesWithLiteralMatch(ValueError, msg):
fdl.build(cfg)
@@ -762,7 +859,9 @@ def test_build_error_path(self):
self.assertEqual(
e.exception.proxy_message, # pytype: disable=attribute-error
'\n\nFiddle context: failed to construct or call basic_fn at .'
- "arg1[1]['c'] with arguments (arg1=1)")
+ "arg1[1]['c'] with positional arguments: (), "
+ 'keyword arguments: (arg1=1).',
+ )
def test_multithreaded_build(self):
"""Two threads can each invoke build.build without interfering."""
@@ -926,11 +1025,80 @@ def test_update_callable_new_kwargs(self):
}
}, fdl.build(cfg))
- def test_update_callable_varargs(self):
- cfg = fdl.Config(fn_with_var_kwargs, 1, 2)
- with self.assertRaisesRegex(NotImplementedError,
- 'Variable positional arguments'):
- fdl.update_callable(cfg, fn_with_var_args_and_kwargs)
+ # For `update_callable` calls involving variadic positional arguments, we
+ # test the four patterns below.
+ # Pattern 1: *args -> *args
+ def test_original_and_new_callable_have_var_positaionl(self):
+ cfg = fdl.Config(fn_with_var_args, 1, 2, kwarg1=3)
+ fdl.update_callable(cfg, fn_with_var_args_and_kwargs)
+ self.assertEqual(cfg.__arguments__, {'arg1': 1, 'args': [2], 'kwarg1': 3})
+ self.assertEqual(
+ fdl.build(cfg),
+ {'arg1': 1, 'args': (2,), 'kwarg1': 3, 'kwargs': {}},
+ )
+
+ def test_original_var_args_are_empty(self):
+ def foo(a, b, c, /, d=0, *args): # pylint: disable=keyword-arg-before-vararg, unused-argument
+ return locals()
+
+ def bar(a, /, b=0, *args): # pylint: disable=keyword-arg-before-vararg, unused-argument
+ return locals()
+
+ cfg = fdl.Config(foo, 1, 2, 3)
+ fdl.update_callable(cfg, bar)
+ self.assertEqual(cfg.__arguments__, {'a': 1, 'b': 2, 'args': [3]})
+ self.assertEqual(
+ fdl.build(cfg),
+ {'a': 1, 'b': 2, 'args': (3,)},
+ )
+
+ def test_update_args_kwargs(self):
+ cfg = fdl.Config(fn_with_args_and_kwargs_only, 1, 2, 3, kwarg1=4, kwarg2=5)
+ cfg[0] = 10
+ cfg.kwarg1 = 40
+ config_lib.update_callable(cfg, fn_with_var_args_and_kwargs)
+ self.assertEqual(
+ cfg.__arguments__,
+ {'arg1': 10, 'args': [2, 3], 'kwarg1': 40, 'kwarg2': 5},
+ )
+ self.assertEqual(
+ fdl.build(cfg),
+ {'arg1': 10, 'args': (2, 3), 'kwarg1': 40, 'kwargs': {'kwarg2': 5}},
+ )
+
+ # Pattern 2: *args -> no *args
+ def test_original_callable_has_var_positaionl(self):
+ def positional_fn(a, b, c, /, kwarg1): # pylint: disable=keyword-arg-before-vararg, unused-argument
+ return locals()
+
+ cfg = fdl.Config(fn_with_var_args, 1, 2, 3, kwarg1=4)
+ fdl.update_callable(cfg, positional_fn)
+ cfg[1] = 22
+ self.assertEqual(cfg.__arguments__, {'a': 1, 'b': 22, 'c': 3, 'kwarg1': 4})
+ self.assertEqual(fdl.build(cfg), {'a': 1, 'b': 22, 'c': 3, 'kwarg1': 4})
+
+ # Pattern 3: no *args -> *args
+ def test_new_callable_has_var_positaionl(self):
+ cfg = fdl.Config(basic_fn, 1, 2, kwarg1=3)
+ fdl.update_callable(cfg, fn_with_var_args_and_kwargs)
+ self.assertEqual(cfg.__arguments__, {'arg1': 1, 'arg2': 2, 'kwarg1': 3})
+ self.assertEqual(
+ fdl.build(cfg),
+ {'arg1': 1, 'args': (), 'kwarg1': 3, 'kwargs': {'arg2': 2}},
+ )
+
+ # Pattern 4: no *args -> no *args
+ def test_no_var_positional_w_different_names(self):
+ def foo(x, y, /, z=0): # pylint: disable=keyword-arg-before-vararg, unused-argument
+ return locals()
+
+ def bar(a, b, c=0, /, d='d'): # pylint: disable=keyword-arg-before-vararg, unused-argument
+ return locals()
+
+ cfg = fdl.Config(foo, 1, 2)
+ fdl.update_callable(cfg, bar)
+ self.assertEqual(cfg.__arguments__, {'a': 1, 'b': 2})
+ self.assertEqual(fdl.build(cfg), {'a': 1, 'b': 2, 'c': 0, 'd': 'd'})
def test_get_callable(self):
cfg = fdl.Config(basic_fn)
diff --git a/fiddle/_src/signatures.py b/fiddle/_src/signatures.py
index f975bba2..def523e2 100644
--- a/fiddle/_src/signatures.py
+++ b/fiddle/_src/signatures.py
@@ -17,7 +17,7 @@
import dataclasses
import inspect
-from typing import Any, Callable, Dict, Generic, Mapping, Tuple, Type
+from typing import Any, Callable, Dict, Generic, List, Mapping, Optional, Tuple, Type
import weakref
import typing_extensions
@@ -135,15 +135,33 @@ class SignatureInfo:
"""To store signature related information about the callable."""
signature: inspect.Signature
- has_var_keyword: bool = None
+ has_var_keyword: Optional[bool] = None
+ var_positional_name: Optional[str] = None
+ positional_arg_names: Optional[List[str]] = dataclasses.field(
+ default_factory=list
+ )
def __post_init__(self):
- if self.has_var_keyword is None:
- has_var_keyword = any(
- param.kind == param.VAR_KEYWORD
- for param in self.signature.parameters.values()
- )
- self.has_var_keyword = has_var_keyword
+ # During serialization, signature is set to None so no action is needed.
+ if self.signature is None:
+ return
+
+ # If *args exists, we must pass things before it in positional format. This
+ # list tracks those arguments.
+ maybe_positional_args = []
+ positional_only_args = []
+ for param in self.signature.parameters.values():
+ if param.kind == param.POSITIONAL_ONLY:
+ positional_only_args.append(param.name)
+ elif param.kind == param.POSITIONAL_OR_KEYWORD:
+ maybe_positional_args.append(param.name)
+ elif param.kind == param.VAR_POSITIONAL:
+ positional_only_args.extend(maybe_positional_args)
+ if not self.var_positional_name:
+ self.var_positional_name = param.name
+ elif param.kind == param.VAR_KEYWORD and self.has_var_keyword is None:
+ self.has_var_keyword = True
+ self.positional_arg_names = positional_only_args
@staticmethod
def signature_binding(fn_or_cls, *args, **kwargs) -> Any:
@@ -152,29 +170,37 @@ def signature_binding(fn_or_cls, *args, **kwargs) -> Any:
arguments = signature.bind_partial(*args, **kwargs).arguments
for name in list(arguments.keys()): # Make a copy in case we mutate.
param = signature.parameters[name]
- if param.kind == param.VAR_POSITIONAL:
- # TODO(b/197367863): Add *args support.
- err_msg = (
- 'Variable positional arguments (aka `*args`) not supported. '
- f'Found param `{name}` in `{fn_or_cls}`.'
- )
- raise NotImplementedError(err_msg)
- elif param.kind == param.VAR_KEYWORD:
+ if param.kind == param.VAR_KEYWORD:
arguments.update(arguments.pop(param.name))
return arguments
- def validate_param_name(self, name, fn_or_cls) -> None:
+  def get_positional_names(
+      self,
+  ) -> Tuple[List[str], List[str], Optional[str]]:
+    """Returns the names of the parameters that can be passed positionally.
+
+    Returns:
+      A ``(positional_only, keyword_or_positional, var_positional)`` tuple: the
+      names of the POSITIONAL_ONLY parameters and of the POSITIONAL_OR_KEYWORD
+      parameters (both in signature order), and the ``*args`` parameter name,
+      or None if the callable has no ``*args``. The lists are built fresh on
+      every call, so callers may modify them.
+    """
+    parameters = self.signature.parameters.values()
+    positional_only = [p.name for p in parameters if p.kind == p.POSITIONAL_ONLY]
+    keyword_or_positional = [
+        p.name for p in parameters if p.kind == p.POSITIONAL_OR_KEYWORD
+    ]
+    return positional_only, keyword_or_positional, self.var_positional_name
+
+ def validate_param_name(
+ self, name, fn_or_cls, allow_postional_argument=False
+ ) -> None:
"""Raises an error if ``name`` is not a valid parameter name."""
param = self.signature.parameters.get(name)
if param is not None:
- if param.kind == param.POSITIONAL_ONLY:
- # TODO(b/197367863): Add positional-only arg support.
- raise NotImplementedError(
- 'Positional only arguments not supported. '
- f'Tried to set {name!r} on {fn_or_cls}'
+ if param.kind == param.POSITIONAL_ONLY and not allow_postional_argument:
+ raise AttributeError(
+ f'Cannot access POSITIONAL_ONLY parameter {name!r} on {fn_or_cls}'
)
- elif param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
+ elif param.kind == param.VAR_POSITIONAL and not allow_postional_argument:
+ raise AttributeError(
+ f'Cannot access VAR_POSITIONAL parameter {name!r} on {fn_or_cls}'
+ )
+ elif param.kind == param.VAR_KEYWORD:
# Just pretend it doesn't correspond to a valid parameter name... below
# a TypeError will be thrown unless there is a **kwargs parameter.
param = None
diff --git a/fiddle/config.py b/fiddle/config.py
index 1735025b..cc68b606 100644
--- a/fiddle/config.py
+++ b/fiddle/config.py
@@ -22,5 +22,6 @@
from fiddle._src.config import Config
from fiddle._src.config import NO_VALUE
from fiddle._src.config import NoValue
+from fiddle._src.config import VARARGS
from fiddle._src.partial import ArgFactory
from fiddle._src.partial import Partial
diff --git a/fiddle/examples/colabs/basic_api.ipynb b/fiddle/examples/colabs/basic_api.ipynb
index 9e13ad38..f783b8ee 100644
--- a/fiddle/examples/colabs/basic_api.ipynb
+++ b/fiddle/examples/colabs/basic_api.ipynb
@@ -401,50 +401,7 @@
"id": "G3IVzfktqAIu"
},
"source": [
- "but `*args` are currently unsupported,"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "metadata": {
- "colab": {
- "height": 34
- },
- "executionInfo": {
- "elapsed": 4,
- "status": "ok",
- "timestamp": 1692835092749,
- "user": {
- "displayName": "",
- "userId": ""
- },
- "user_tz": 420
- },
- "id": "p-r9vED0qNib",
- "outputId": "603b56d2-862b-4962-fc5c-10f03547aa75"
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "\u003cspan style=\"color: red\"\u003eNotImplementedError: Variable positional arguments (aka `*args`) not supported. Found param `args` in `\u003cfunction args_and_kwargs at 0x7f4ba5bd4670\u003e`.\u003c/span\u003e"
- ],
- "text/plain": [
- "\u003cIPython.core.display.HTML object\u003e"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "try:\n",
- " fdl.Config(args_and_kwargs, 4, 7)\n",
- "except NotImplementedError as e:\n",
- " display(HTML(f'\u003cspan style=\"color: red\"\u003eNotImplementedError: {e}\u003c/span\u003e'))\n",
- "else:\n",
- " raise AssertionError(\"This should raise an error!\")"
+ "# TODO(b/288893692): Update docs for positional args."
]
},
{