Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions fiddle/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ def __init__(self, fn_or_cls: Union['Buildable', TypeOrCallableProducingT],
if metadata:
for tag in metadata.tags:
add_tag(self, field.name, tag)
if metadata.auto_config_factory and field.name not in arguments:
setattr(self, field.name,
metadata.auto_config_factory.as_buildable())

for name, value in arguments.items():
setattr(self, name, value)
Expand Down
43 changes: 36 additions & 7 deletions fiddle/experimental/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
# A single tag, or a collection of tags, as accepted by `field`.
TagOrTags = Union[tag_type.TagType, Collection[tag_type.TagType]]
# Private sentinel used as the key for Fiddle's entry inside a field's
# `metadata` mapping. It is merged with user-supplied metadata, and a fresh
# `object()` cannot collide with any key a user could write.
_FIDDLE_DATACLASS_METADATA_KEY = object()

# In order to avoid a circular dependency, we just use Any as a proxy for
# auto_config.AutoConfig.
AutoConfig = Any


# TODO: Add kw_only=True when available.
@dataclasses.dataclass(frozen=True)
Expand All @@ -40,22 +44,31 @@ class FieldMetadata:

Attributes:
tags: A collection of tags to attach to the field.
auto_config_factory: An auto_config function to use as a factory, in both
normal and config-generating modes.
"""
tags: Collection[tag_type.TagType] = ()
auto_config_factory: Optional[AutoConfig] = None
# TODO: Add additional metadata types here (value validation rules,
# autofill / auto_config settings, etc).
# autofill settings, etc).


def field(*,
tags: Optional[TagOrTags] = None,
metadata: Optional[Mapping[Any, Any]] = None,
tags: TagOrTags = (),
metadata: Mapping[Any, Any] = types.MappingProxyType({}),
auto_config_factory: Optional[AutoConfig] = None,
**kwargs) -> Union[dataclasses.Field[Any], Any]:
"""A wrapper around dataclasses.field to add optional Fiddle metadata.

Args:
tags: One or more tags to attach to the `fdl.Buildable`'s argument
corresponding to the field.
metadata: Any additional metadata to include.
auto_config_factory: A `default_factory` that can also instantiate
sub-configuration. If the this function is a reasonable auto_config
function, then we should still have the property that
`fdl.build(fdl.Config(MyDataclass)) == MyDataclass()`, since the
configuration generated should be equal to the default_factory invocation.
**kwargs: All other kwargs are passed to `dataclasses.field`; see the
documentation on `dataclasses.field` for valid arguments.

Expand All @@ -68,13 +81,29 @@ def field(*,
if isinstance(tags, tag_type.TagType):
tags = (tags,)

metadata: Mapping[Any, Any] = types.MappingProxyType(metadata or {})
metadata = {
**metadata, _FIDDLE_DATACLASS_METADATA_KEY: FieldMetadata(tags=tags)
}
if auto_config_factory is not None:
# Note: A rudimentary type check here, to avoid circular imports.
if not hasattr(auto_config_factory, "as_buildable"):
raise TypeError("Expected `auto_config_factory` to have an "
"as_buildable method.")
# if "default" in kwargs or "default_factory" in kwargs:
# raise TypeError("You cannot create a field that has "
# "`auto_config_factory` and `default`/`default_factory` "
# "set.")
kwargs["default_factory"] = auto_config_factory

fiddle_metadata = FieldMetadata(
tags=tags, auto_config_factory=auto_config_factory)
metadata = {**metadata, _FIDDLE_DATACLASS_METADATA_KEY: fiddle_metadata}
return dataclasses.field(metadata=metadata, **kwargs)


def create_metadata(tags=(), auto_config_factory=None):
  """Builds a `dataclasses.field` metadata mapping carrying Fiddle metadata.

  This is a lower-level alternative to `field` for callers that construct the
  `dataclasses.field` themselves, e.g.
  `dataclasses.field(metadata=create_metadata(...), default_factory=...)`.
  Unlike `field`, this does not set `default_factory`.

  Args:
    tags: One tag or a collection of tags to attach to the `fdl.Buildable`'s
      argument corresponding to the field. A single tag is wrapped in a tuple,
      matching `field`.
    auto_config_factory: Optional auto_config function. If given, it must have
      an `as_buildable` method.

  Returns:
    A dict with one entry, mapping Fiddle's private metadata key to a
    `FieldMetadata` holding `tags` and `auto_config_factory`.

  Raises:
    TypeError: If `auto_config_factory` is given but has no `as_buildable`
      method.
  """
  if isinstance(tags, tag_type.TagType):
    # `FieldMetadata.tags` is declared as a collection; normalize a bare tag
    # exactly as `field` does.
    tags = (tags,)
  if (auto_config_factory is not None and
      not hasattr(auto_config_factory, "as_buildable")):
    # Same rudimentary duck-type check as `field` (avoids importing
    # auto_config here). Failing early beats an AttributeError later, when
    # `as_buildable` is invoked on the stored factory.
    raise TypeError("Expected `auto_config_factory` to have an "
                    "as_buildable method.")
  fiddle_metadata = FieldMetadata(
      tags=tags, auto_config_factory=auto_config_factory)
  return {_FIDDLE_DATACLASS_METADATA_KEY: fiddle_metadata}


def field_metadata(
field_object: dataclasses.Field[Any]) -> Optional[FieldMetadata]:
"""Retrieves the Fiddle-specific metadata (if present) on `field`."""
Expand Down
61 changes: 61 additions & 0 deletions fiddle/experimental/dataclasses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@

import dataclasses
import types
from typing import Any, Dict

from absl.testing import absltest
import fiddle as fdl
from fiddle.experimental import auto_config
from fiddle.experimental import dataclasses as fdl_dc


Expand All @@ -38,6 +40,40 @@ class ATaggedType:
double_tagged: str = fdl_dc.field(
tags=(AdditionalTag, SampleTag), default_factory=lambda: 'other_field')

@classmethod
@auto_config.auto_config
def defaults(cls):
  """auto_config alternate constructor setting only `untagged`.

  Every other field keeps its dataclass default.
  """
  return cls(untagged='untagged_foo')


def test_fn():
  """Returns 1; used as a leaf call inside the `nested_structure` fixture.

  This is a helper, not a test case, despite its `test_` prefix.
  """
  return 1


# This module's filename matches pytest's default `*_test.py` pattern, so the
# `test_` prefix would make pytest collect and run this helper as a test.
# `__test__ = False` opts it out while keeping the name unchanged.
test_fn.__test__ = False


@auto_config.auto_config
def nested_structure():
  """auto_config fixture: a dict nesting a `test_fn()` call in a list.

  The list also holds the tuple `(2, 3)`, so tests can check edits to nested,
  non-call values.
  """
  pair = (2, 3)
  return {'foo': [test_fn(), pair]}


@dataclasses.dataclass
class AnAutoconfigType:
  """Dataclass whose defaults come from auto_config functions.

  Both fields pass an auto_config function as a plain
  `dataclasses.field(default_factory=...)`; no Fiddle metadata is attached.

  NOTE(review): to attach Fiddle metadata instead, pass
  `metadata=fdl_dc.create_metadata((), auto_config_factory=...)` using each
  field's own factory: `ATaggedType.defaults` for `tagged`, and
  `nested_structure` for `another_default`.
  """
  tagged: ATaggedType = dataclasses.field(
      default_factory=ATaggedType.defaults)
  another_default: Dict[str, Any] = dataclasses.field(
      default_factory=nested_structure)


@dataclasses.dataclass
class AncestorType:
  """Dataclass whose only field is populated via `auto_config_factory`."""

  # TODO: This factory definition is verbose; consider a more compact helper.
  # NOTE(review): the meaning of the third positional `AutoConfig` argument
  # (`True`) is not visible here; consider passing it by keyword.
  child: AnAutoconfigType = fdl_dc.field(
      auto_config_factory=auto_config.AutoConfig(
          AnAutoconfigType,
          lambda: fdl.Config(AnAutoconfigType),
          True,
      ))


class DataclassesTest(absltest.TestCase):

Expand All @@ -59,6 +95,31 @@ def test_metadata_passthrough(self):
self.assertIn('something', constructed_field.metadata)
self.assertEqual(4, constructed_field.metadata['something'])

def test_auto_config_basic_equality(self):
  """Building a default `fdl.Config` equals direct default construction."""
  for dataclass_type in (AnAutoconfigType, AncestorType):
    built = fdl.build(fdl.Config(dataclass_type))
    self.assertEqual(built, dataclass_type())

def test_auto_config_override_equality(self):
  """An explicit argument replaces the auto_config-derived default."""
  built = fdl.build(fdl.Config(AnAutoconfigType, another_default={3: 4}))
  expected = AnAutoconfigType(another_default={3: 4})
  self.assertEqual(built, expected)

def test_auto_config_field_init(self):
  """In-place edits to `config.another_default` reach the built object."""
  config = fdl.Config(AnAutoconfigType)
  foo_items = config.another_default['foo']
  foo_items[1] = foo_items[1] + (4,)
  self.assertEqual(
      fdl.build(config).another_default, {'foo': [1, (2, 3, 4)]})

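# TODO: Re-enable once `fdl_dc.field` rejects `auto_config_factory` combined
# with `default`/`default_factory` (that check is currently commented out).
# The third case (a plain lambda without `as_buildable`) is already rejected.
# def test_invalid_definition_with_defaults(self):
# with self.assertRaises(TypeError):
# fdl_dc.field(auto_config_factory=nested_structure, default=4)
# with self.assertRaises(TypeError):
# fdl_dc.field(
# auto_config_factory=nested_structure, default_factory=lambda: 4)
# with self.assertRaises(TypeError):
# fdl_dc.field(auto_config_factory=lambda: 4)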


# Run the absltest test cases when this module is executed as a script.
if __name__ == '__main__':
  absltest.main()