Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 114 additions & 36 deletions fiddle/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,58 @@ def __copy__(self):
_UNSET_SENTINEL = object()


# Separate traverser registry in which Buildable subclasses are registered by
# `_register_buildable_defaults_aware_traversers`, i.e. flattened with
# `include_defaults=True`. It is passed as `registry=` to `daglish.iterate`
# by the DAG-structure check in `compare_buildable`.
_defaults_aware_traverser_registry = daglish.NodeTraverserRegistry(
    use_fallback=True
)


def _buildable_flatten(
    buildable: Buildable, include_defaults: bool = False
) -> Tuple[Tuple[Any, ...], BuildableTraverserMetadata]:
  """Flattens ``buildable`` into argument values plus traverser metadata.

  Shared implementation behind ``Buildable.__flatten__``.

  Args:
    buildable: The buildable whose arguments are flattened.
    include_defaults: Forwarded unchanged to ``ordered_arguments``.

  Returns:
    A ``(values, metadata)`` pair: the tuple of argument values in the order
    given by ``ordered_arguments``, and a ``BuildableTraverserMetadata``
    carrying ``__fn_or_cls__``, the matching argument names, the non-empty
    argument tags and the argument history.
  """
  ordered = ordered_arguments(buildable, include_defaults=include_defaults)
  names = tuple(ordered.keys())
  flat_values = tuple(ordered.values())

  # Freeze the tag sets; arguments whose tag set is empty are left out.
  tags_by_name = {}
  for arg_name, arg_tags in buildable.__argument_tags__.items():
    if arg_tags:
      tags_by_name[arg_name] = frozenset(arg_tags)

  # Freeze each argument's history entries into a tuple.
  history_by_name = {}
  for arg_name, history_entries in buildable.__argument_history__.items():
    history_by_name[arg_name] = tuple(history_entries)

  return flat_values, BuildableTraverserMetadata(
      fn_or_cls=buildable.__fn_or_cls__,
      argument_names=names,
      argument_tags=tags_by_name,
      argument_history=history_by_name,
  )


def _buildable_path_elements(
    buildable: Buildable, include_defaults: bool = True
) -> Tuple[daglish.Attr, ...]:
  """Returns one ``daglish.Attr`` path element per argument of ``buildable``.

  Shared implementation behind ``Buildable.__path_elements__``.

  Args:
    buildable: The buildable whose argument names become path elements.
    include_defaults: Forwarded unchanged to ``ordered_arguments``. Note that
      the default here is ``True``; the defaults-aware traverser registration
      passes this function without overriding it.

  Returns:
    A tuple (of any length) of ``daglish.Attr`` instances, in the order the
    argument names are produced by ``ordered_arguments``.
  """
  arguments = ordered_arguments(buildable, include_defaults=include_defaults)
  # Iterating the mapping yields its keys, i.e. the argument names.
  return tuple(daglish.Attr(name) for name in arguments)


def _register_buildable_defaults_aware_traversers(cls: Type[Buildable]):
  """Adds ``cls`` to the defaults-aware traverser registry.

  The registered traverser flattens with ``include_defaults=True``, rebuilds
  through ``cls.__unflatten__`` and derives path elements with
  ``_buildable_path_elements`` (called with its own default arguments).

  Args:
    cls: The ``Buildable`` subclass to register.
  """
  flatten_with_defaults = functools.partial(
      _buildable_flatten, include_defaults=True
  )
  _defaults_aware_traverser_registry.register_node_traverser(
      cls,
      flatten_fn=flatten_with_defaults,
      unflatten_fn=cls.__unflatten__,
      path_elements_fn=_buildable_path_elements,
  )


class BuildableTraverserMetadata(NamedTuple):
"""Metadata for a Buildable.

Expand Down Expand Up @@ -192,32 +244,15 @@ def __init_subclass__(cls):
unflatten_fn=cls.__unflatten__,
path_elements_fn=lambda x: x.__path_elements__(),
)
_register_buildable_defaults_aware_traversers(cls)

@abc.abstractmethod
def __build__(self, *args, **kwargs):
  """Produces the built output for this instance.

  Abstract hook: this base implementation always raises
  ``NotImplementedError``; see subclasses for the actual behavior.
  """
  raise NotImplementedError

def __flatten__(self) -> Tuple[Tuple[Any, ...], BuildableTraverserMetadata]:
arguments = ordered_arguments(self)
keys = tuple(arguments.keys())
values = tuple(arguments.values())
argument_tags = {
name: frozenset(tags)
for name, tags in self.__argument_tags__.items()
if tags # Don't include empty sets.
}
argument_history = {
name: tuple(entries)
for name, entries in self.__argument_history__.items()
}
metadata = BuildableTraverserMetadata(
fn_or_cls=self.__fn_or_cls__,
argument_names=keys,
argument_tags=argument_tags,
argument_history=argument_history,
)
return values, metadata
return _buildable_flatten(self, include_defaults=False)

@classmethod
def __unflatten__(
Expand All @@ -230,8 +265,8 @@ def __unflatten__(
object.__setattr__(rebuilt, '__arguments__', metadata.arguments(values))
return rebuilt

def __path_elements__(self):
return tuple(daglish.Attr(name) for name in ordered_arguments(self).keys())
def __path_elements__(self) -> Tuple[daglish.Attr]:
return _buildable_path_elements(self, include_defaults=False)

def __getattr__(self, name: str):
"""Get parameter with given ``name``."""
Expand Down Expand Up @@ -446,22 +481,65 @@ class being configured, and then checks for equality in the configured
Returns:
``True`` if ``self`` equals ``other``, ``False`` if not.
"""
if type(self) is not type(other):
return False
if self.__fn_or_cls__ != other.__fn_or_cls__:
return False
assert self._has_var_keyword == other._has_var_keyword, (
'Internal invariant violated: has_var_keyword should be the same if '
"__fn_or_cls__'s are the same."
)

missing = object()
for key in set(self.__arguments__) | set(other.__arguments__):
v1 = getattr(self, key, missing)
v2 = getattr(other, key, missing)
if (v1 is not missing or v2 is not missing) and v1 != v2:
def compare_buildable(x, y, check_dag=False):
assert isinstance(x, Buildable)
if type(x) is not type(y):
return False
return True
if x.__fn_or_cls__ != y.__fn_or_cls__:
return False
assert x._has_var_keyword == y._has_var_keyword, ( # pylint: disable=protected-access
'Internal invariant violated: has_var_keyword should be the same if '
"__fn_or_cls__'s are the same."
)

missing = object()
for key in set(x.__arguments__) | set(y.__arguments__):
v1 = getattr(x, key, missing)
v2 = getattr(y, key, missing)
assert not (v1 is missing and v2 is missing)
if v1 is missing or v2 is missing:
return False
if isinstance(v1, Buildable) and isinstance(v2, Buildable):
if not compare_buildable(v1, v2, check_dag=False):
return False
if v1 != v2:
return False

# Compare the DAG structure.
# The DAG structure comparison must traverse the whole DAG and sort the
# result by path, which is expensive. Thus, we compare values first so
# that most unequal cases will not reach the expensive DAG compare step.
if check_dag:
x_elements = list(
daglish.iterate(
x,
memoized=True,
memoize_internables=True,
registry=_defaults_aware_traverser_registry,
)
)
y_elements = list(
daglish.iterate(
y,
memoized=True,
memoize_internables=True,
registry=_defaults_aware_traverser_registry,
)
)
x_elements = sorted(x_elements, key=lambda x: x[1])
y_elements = sorted(y_elements, key=lambda x: x[1])

if len(x_elements) != len(y_elements):
return False
for v1, v2 in zip(x_elements, y_elements):
_, x_path = v1
_, y_path = v2
if x_path != y_path:
return False

return True

return compare_buildable(self, other, check_dag=True)

def __getstate__(self):
"""Gets pickle serialization state, removing some fields.
Expand Down
18 changes: 18 additions & 0 deletions fiddle/_src/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,24 @@ def fn_with_non_comparable_default(
self.assertIsInstance(value1, ClassWithDisabledEquality)
self.assertIsInstance(value2, ClassWithDisabledEquality)

def test_config_dag_structure_comparison(self):
  # Two configs built from identical arguments but held as distinct objects.
  first = fdl.Config(SampleClass, 1, 2)
  second = fdl.Config(SampleClass, 1, 2)

  with self.subTest('python_list'):
    # Plain lists are expected to compare equal even though one of them
    # repeats the same object and the other does not.
    repeated = [first, first]
    separate = [first, second]
    self.assertEqual(repeated, separate)

  with self.subTest('node_sharing_detection'):
    # Inside a Config, a shared child versus two separate children is
    # expected to make the parents unequal.
    unshared = fdl.Config(SampleClass, first, second)
    shared = fdl.Config(SampleClass, first, first)
    self.assertNotEqual(unshared, shared)

  with self.subTest('node_sharing_difference'):
    # Both parents share a child, but at different argument positions.
    lhs = fdl.Config(SampleClass, first, second, second)
    rhs = fdl.Config(SampleClass, first, first, second)
    self.assertNotEqual(lhs, rhs)

def test_unsetting_argument(self):
fn_config = fdl.Config(basic_fn)
fn_config.arg1 = 3
Expand Down
41 changes: 39 additions & 2 deletions fiddle/_src/daglish.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@ def code(self) -> str:
def follow(self, container) -> Any:
"""Returns the element of `container` specified by this path element."""

def __lt__(self, other: PathElement) -> bool:
  """Orders path elements of different types by their type's string form.

  Args:
    other: The element to compare against.

  Returns:
    Whether ``str(type(self))`` sorts before ``str(type(other))`` when the two
    types differ.

  Raises:
    NotImplementedError: If both operands have exactly the same type; that
      comparison is left to the subclass.
  """
  own_type, other_type = type(self), type(other)
  if own_type is other_type:
    raise NotImplementedError(
        "__lt__ relation should be handled by subclasses of PathElement."
    )
  return str(own_type) < str(other_type)


@dataclasses.dataclass(frozen=True)
class Index(PathElement):
Expand All @@ -50,6 +59,12 @@ def code(self) -> str:
def follow(self, container: Union[List[Any], Tuple[Any, ...]]) -> Any:
  """Returns the item of ``container`` at position ``self.index``."""
  position = self.index
  return container[position]

def __lt__(self, other: PathElement) -> bool:
  """Compares by ``index`` for same-typed operands, else defers to the base."""
  if type(self) is not type(other):
    # Different element types: let the parent class decide the ordering.
    return super().__lt__(other)
  return self.index < other.index


@dataclasses.dataclass(frozen=True)
class Key(PathElement):
Expand All @@ -63,6 +78,12 @@ def code(self) -> str:
def follow(self, container: Dict[Any, Any]) -> Any:
  """Returns the value stored in ``container`` under ``self.key``."""
  lookup_key = self.key
  return container[lookup_key]

def __lt__(self, other: PathElement) -> bool:
  """Compares by ``key`` for same-typed operands, else defers to the base."""
  if type(self) is not type(other):
    # Different element types: let the parent class decide the ordering.
    return super().__lt__(other)
  return self.key < other.key


@dataclasses.dataclass(frozen=True)
class Attr(PathElement):
Expand All @@ -76,6 +97,12 @@ def code(self) -> str:
def follow(self, container: Any) -> Any:
  """Returns the attribute of ``container`` named by ``self.name``."""
  attribute_name = self.name
  return getattr(container, attribute_name)

def __lt__(self, other: PathElement) -> bool:
  """Compares by ``name`` for same-typed operands, else defers to the base."""
  if type(self) is not type(other):
    # Different element types: let the parent class decide the ordering.
    return super().__lt__(other)
  return self.name < other.name


class BuildableAttr(Attr):
"""An attribute of a Buildable."""
Expand Down Expand Up @@ -727,6 +754,7 @@ def traverse(value, state: State):
def iterate(
value: Any,
memoized: bool = True,
memoize_internables: bool = True,
registry: NodeTraverserRegistry = _default_traverser_registry,
) -> Iterable[Tuple[Any, Path]]:
"""Iterates through values in a DAG.
Expand All @@ -743,6 +771,8 @@ def iterate(
memoized: Whether to yield shared nodes only once. Defaults to True. With
this setting, you will only see one path (which is somewhat arbitrary) to
shared nodes.
memoize_internables: Whether to memoize Python internable values. Check the
docstring of MemoizedTraversal for details.
registry: Override to the NodeTraverserRegistry; this is a low-level setting
for traversing into custom data types.

Expand All @@ -756,6 +786,13 @@ def _traverse(node, state: State):
for sub_result in state.flattened_map_children(node).values:
yield from sub_result

traversal_cls = MemoizedTraversal if memoized else BasicTraversal
traversal: Traversal = traversal_cls(_traverse, value, registry=registry)
if memoized:
traversal = MemoizedTraversal(
_traverse,
value,
registry=registry,
memoize_internables=memoize_internables,
)
else:
traversal = BasicTraversal(_traverse, value, registry=registry)
return _traverse(value, traversal.initial_state())