From 618bc20c7955a9e2642423b219cad6fcbacf79ab Mon Sep 17 00:00:00 2001
From: Zhufeng Pan
Date: Tue, 27 Jun 2023 11:14:28 -0700
Subject: [PATCH] For testing if `memoize_internables` is a must
Enable `fiddle.Buildable.__eq__` to detect differences in DAG structure
(i.e. which sub-values are shared nodes).
An extra DAG-structure check is added for the top-level Buildable.
PiperOrigin-RevId: 543799794
---
fiddle/_src/config.py | 150 ++++++++++++++++++++++++++++---------
fiddle/_src/config_test.py | 18 +++++
fiddle/_src/daglish.py | 41 +++++++++-
3 files changed, 171 insertions(+), 38 deletions(-)
diff --git a/fiddle/_src/config.py b/fiddle/_src/config.py
index 032a9f78..37ee10a8 100644
--- a/fiddle/_src/config.py
+++ b/fiddle/_src/config.py
@@ -64,6 +64,58 @@ def __copy__(self):
_UNSET_SENTINEL = object()
+# Traverser registry whose Buildable traversers also visit arguments that are
+# still at their default values. Buildable subclasses are added by
+# `_register_buildable_defaults_aware_traversers`; other types presumably use
+# the registry's fallback (`use_fallback=True`).
+_defaults_aware_traverser_registry = daglish.NodeTraverserRegistry(
+    use_fallback=True
+)
+
+
+def _buildable_flatten(
+    buildable: Buildable, include_defaults: bool = False
+) -> Tuple[Tuple[Any, ...], BuildableTraverserMetadata]:
+  """Flattens `buildable` into its argument values plus traversal metadata.
+
+  Shared by `Buildable.__flatten__` and the defaults-aware traverser.
+
+  Args:
+    buildable: The Buildable to flatten.
+    include_defaults: Forwarded to `ordered_arguments` (presumably, when True,
+      arguments still at their default value are flattened as well).
+
+  Returns:
+    A `(values, metadata)` pair where `values[i]` is the value of the argument
+    named `metadata.argument_names[i]`.
+  """
+  arguments = ordered_arguments(buildable, include_defaults=include_defaults)
+  keys = tuple(arguments.keys())
+  values = tuple(arguments.values())
+  argument_tags = {
+      name: frozenset(tags)
+      for name, tags in buildable.__argument_tags__.items()
+      if tags  # Don't include empty sets.
+  }
+  argument_history = {
+      name: tuple(entries)
+      for name, entries in buildable.__argument_history__.items()
+  }
+  metadata = BuildableTraverserMetadata(
+      fn_or_cls=buildable.__fn_or_cls__,
+      argument_names=keys,
+      argument_tags=argument_tags,
+      argument_history=argument_history,
+  )
+  return values, metadata
+
+
+def _buildable_path_elements(
+    buildable: Buildable, include_defaults: bool = False
+) -> Tuple[daglish.Attr, ...]:
+  """Returns one `daglish.Attr` per argument, ordered like `_buildable_flatten`.
+
+  Args:
+    buildable: The Buildable whose argument paths are produced.
+    include_defaults: Forwarded to `ordered_arguments`. Must match the value
+      used for the paired flatten call so paths line up with values.
+  """
+  arguments = ordered_arguments(buildable, include_defaults=include_defaults)
+  return tuple(daglish.Attr(name) for name in arguments.keys())
+
+
+def _register_buildable_defaults_aware_traversers(cls: Type[Buildable]):
+  """Registers defaults-aware traversal routines for a Buildable subclass.
+
+  Flatten and path-element functions both use `include_defaults=True`, so
+  values and paths stay aligned.
+
+  NOTE(review): `cls.__unflatten__` rebuilds `__arguments__` from every
+  flattened value, so unflattening through this registry would set default
+  values explicitly. That is harmless for read-only traversal (e.g.
+  `daglish.iterate`); confirm before using this registry for transforms.
+  """
+  _defaults_aware_traverser_registry.register_node_traverser(
+      cls,
+      flatten_fn=functools.partial(_buildable_flatten, include_defaults=True),
+      unflatten_fn=cls.__unflatten__,
+      path_elements_fn=functools.partial(
+          _buildable_path_elements, include_defaults=True
+      ),
+  )
+
+
class BuildableTraverserMetadata(NamedTuple):
"""Metadata for a Buildable.
@@ -192,6 +244,7 @@ def __init_subclass__(cls):
unflatten_fn=cls.__unflatten__,
path_elements_fn=lambda x: x.__path_elements__(),
)
+ _register_buildable_defaults_aware_traversers(cls)  # Also register the include-defaults traversers used by Buildable.__eq__'s DAG check.
@abc.abstractmethod
def __build__(self, *args, **kwargs):
@@ -199,25 +252,7 @@ def __build__(self, *args, **kwargs):
raise NotImplementedError()
def __flatten__(self) -> Tuple[Tuple[Any, ...], BuildableTraverserMetadata]:
- arguments = ordered_arguments(self)
- keys = tuple(arguments.keys())
- values = tuple(arguments.values())
- argument_tags = {
- name: frozenset(tags)
- for name, tags in self.__argument_tags__.items()
- if tags # Don't include empty sets.
- }
- argument_history = {
- name: tuple(entries)
- for name, entries in self.__argument_history__.items()
- }
- metadata = BuildableTraverserMetadata(
- fn_or_cls=self.__fn_or_cls__,
- argument_names=keys,
- argument_tags=argument_tags,
- argument_history=argument_history,
- )
- return values, metadata
+ return _buildable_flatten(self, include_defaults=False)  # Defaults excluded; the include-defaults variant is registered in _defaults_aware_traverser_registry.
@classmethod
def __unflatten__(
@@ -230,8 +265,8 @@ def __unflatten__(
object.__setattr__(rebuilt, '__arguments__', metadata.arguments(values))
return rebuilt
- def __path_elements__(self):
- return tuple(daglish.Attr(name) for name in ordered_arguments(self).keys())
+  def __path_elements__(self) -> Tuple[daglish.Attr, ...]:
+    """Returns path elements matching `__flatten__` (defaults not included)."""
+    return _buildable_path_elements(self, include_defaults=False)
def __getattr__(self, name: str):
"""Get parameter with given ``name``."""
@@ -446,22 +481,65 @@ class being configured, and then checks for equality in the configured
Returns:
``True`` if ``self`` equals ``other``, ``False`` if not.
"""
- if type(self) is not type(other):
- return False
- if self.__fn_or_cls__ != other.__fn_or_cls__:
- return False
- assert self._has_var_keyword == other._has_var_keyword, (
- 'Internal invariant violated: has_var_keyword should be the same if '
- "__fn_or_cls__'s are the same."
- )
-
- missing = object()
- for key in set(self.__arguments__) | set(other.__arguments__):
- v1 = getattr(self, key, missing)
- v2 = getattr(other, key, missing)
- if (v1 is not missing or v2 is not missing) and v1 != v2:
- return False
- return True
+    if type(self) is not type(other):
+      return False
+    if self.__fn_or_cls__ != other.__fn_or_cls__:
+      return False
+    assert self._has_var_keyword == other._has_var_keyword, (  # pylint: disable=protected-access
+        'Internal invariant violated: has_var_keyword should be the same if '
+        "__fn_or_cls__'s are the same."
+    )
+
+    missing = object()
+    for key in set(self.__arguments__) | set(other.__arguments__):
+      v1 = getattr(self, key, missing)
+      v2 = getattr(other, key, missing)
+      assert not (v1 is missing and v2 is missing)
+      if v1 is missing or v2 is missing:
+        return False
+      # For nested Buildables, `!=` re-enters `__eq__`, which already compares
+      # their arguments and their own DAG structure. An additional
+      # argument-only recursion before it would repeat that work at every
+      # nesting level, i.e. exponentially in the nesting depth.
+      if v1 != v2:
+        return False
+
+    # Compare the DAG structure, i.e. which values are the same shared node.
+    # This traverses the whole DAG, so it only runs once the cheaper value
+    # comparison above has passed; most unequal pairs never reach it.
+    def dag_paths(buildable):
+      """Returns the set of paths at which iteration first reaches each node."""
+      return {
+          path
+          for _, path in daglish.iterate(
+              buildable,
+              memoized=True,
+              # Internable values (see daglish.MemoizedTraversal) are not
+              # memoized: whether CPython happens to reuse one object for
+              # equal immutable values must not change the result.
+              memoize_internables=False,
+              registry=_defaults_aware_traverser_registry,
+          )
+      }
+
+    # Each path appears at most once per traversal, so comparing sets is
+    # equivalent to comparing sorted path lists. It also avoids relying on a
+    # PathElement ordering, which may be undefined for some element or key
+    # types.
+    # NOTE(review): only the first path to a shared node is recorded, so this
+    # relies on both sides being traversed in the same argument order.
+    return dag_paths(self) == dag_paths(other)
def __getstate__(self):
"""Gets pickle serialization state, removing some fields.
diff --git a/fiddle/_src/config_test.py b/fiddle/_src/config_test.py
index 827e220a..6dd610c1 100644
--- a/fiddle/_src/config_test.py
+++ b/fiddle/_src/config_test.py
@@ -517,6 +517,24 @@ def fn_with_non_comparable_default(
self.assertIsInstance(value1, ClassWithDisabledEquality)
self.assertIsInstance(value2, ClassWithDisabledEquality)
+ def test_config_dag_structure_comparison(self):
+ a = fdl.Config(SampleClass, 1, 2)
+ b = fdl.Config(SampleClass, 1, 2)
+ with self.subTest('python_list'):
+ x = [a, a]
+ y = [a, b]
+ self.assertEqual(x, y)  # List == compares element-wise, so sharing inside a plain list is not checked.
+
+ with self.subTest('node_sharing_detection'):
+ x = fdl.Config(SampleClass, a, b)
+ y = fdl.Config(SampleClass, a, a)
+ self.assertNotEqual(x, y)  # Equal values, but only y passes one shared node for both arguments.
+
+ with self.subTest('node_sharing_difference'):
+ x = fdl.Config(SampleClass, a, b, b)
+ y = fdl.Config(SampleClass, a, a, b)
+ self.assertNotEqual(x, y)  # x shares b across its last two arguments; y shares a across its first two.
+
def test_unsetting_argument(self):
fn_config = fdl.Config(basic_fn)
fn_config.arg1 = 3
diff --git a/fiddle/_src/daglish.py b/fiddle/_src/daglish.py
index 7d086ce7..4cad7aca 100644
--- a/fiddle/_src/daglish.py
+++ b/fiddle/_src/daglish.py
@@ -37,6 +37,15 @@ def code(self) -> str:
def follow(self, container) -> Any:
"""Returns the element of `container` specified by this path element."""
+  def __lt__(self, other: PathElement) -> bool:
+    """Orders path elements so that paths can be sorted deterministically.
+
+    Elements of different PathElement types are ordered by the string form
+    of their types; subclasses override this to order elements of their own
+    type.
+    """
+    if not isinstance(other, PathElement):
+      return NotImplemented
+    if type(self) is not type(other):
+      return str(type(self)) < str(type(other))
+    # No ordering is defined for this subclass: return NotImplemented, per
+    # the rich-comparison protocol, so `<` raises TypeError.
+    return NotImplemented
+
@dataclasses.dataclass(frozen=True)
class Index(PathElement):
@@ -50,6 +59,12 @@ def code(self) -> str:
def follow(self, container: Union[List[Any], Tuple[Any, ...]]) -> Any:
return container[self.index]
+  def __lt__(self, other: PathElement) -> bool:
+    if type(self) is type(other):
+      return self.index < other.index
+    return super().__lt__(other)
+
@dataclasses.dataclass(frozen=True)
class Key(PathElement):
@@ -63,6 +78,12 @@ def code(self) -> str:
def follow(self, container: Dict[Any, Any]) -> Any:
return container[self.key]
+  def __lt__(self, other: PathElement) -> bool:
+    if type(self) is not type(other):
+      return super().__lt__(other)
+    try:
+      return self.key < other.key
+    except TypeError:
+      # Dict keys need not be mutually orderable (e.g. `1` and `'a'`); fall
+      # back to a deterministic order by key type name, then repr.
+      return (type(self.key).__name__, repr(self.key)) < (
+          type(other.key).__name__,
+          repr(other.key),
+      )
+
@dataclasses.dataclass(frozen=True)
class Attr(PathElement):
@@ -76,6 +97,12 @@ def code(self) -> str:
def follow(self, container: Any) -> Any:
return getattr(container, self.name)
+  def __lt__(self, other: PathElement) -> bool:
+    if type(self) is type(other):
+      return self.name < other.name
+    return super().__lt__(other)
+
class BuildableAttr(Attr):
"""An attribute of a Buildable."""
@@ -727,6 +754,7 @@ def traverse(value, state: State):
def iterate(
value: Any,
memoized: bool = True,
+    memoize_internables: bool = True,
registry: NodeTraverserRegistry = _default_traverser_registry,
) -> Iterable[Tuple[Any, Path]]:
"""Iterates through values in a DAG.
@@ -743,6 +771,8 @@ def iterate(
memoized: Whether to yield shared nodes only once. Defaults to True. With
this setting, you will only see one path (which is somewhat arbitrary) to
shared nodes.
+    memoize_internables: Whether internable values (see the MemoizedTraversal
+      docstring) are also yielded only once when shared. Only used when
+      `memoized` is True. Defaults to True.
registry: Override to the NodeTraverserRegistry; this is a low-level setting
for traversing into custom data types.
@@ -756,6 +786,13 @@ def _traverse(node, state: State):
for sub_result in state.flattened_map_children(node).values:
yield from sub_result
- traversal_cls = MemoizedTraversal if memoized else BasicTraversal
- traversal: Traversal = traversal_cls(_traverse, value, registry=registry)
+  # Declared up front so type checkers accept either concrete traversal.
+  traversal: Traversal
+  if memoized:
+    traversal = MemoizedTraversal(
+        _traverse,
+        value,
+        registry=registry,
+        memoize_internables=memoize_internables,
+    )
+  else:
+    # BasicTraversal does not memoize, so `memoize_internables` is unused.
+    traversal = BasicTraversal(_traverse, value, registry=registry)
return _traverse(value, traversal.initial_state())