From 618bc20c7955a9e2642423b219cad6fcbacf79ab Mon Sep 17 00:00:00 2001
From: Zhufeng Pan
Date: Tue, 27 Jun 2023 11:14:28 -0700
Subject: [PATCH] For testing if `memoize_internables` is a must

Enable `fiddle.Buildable.__eq__` to detect differences in DAG structure
(i.e. which sub-configs are shared).

An extra DAG structure check is added for the top-level Buildable. The
check does not memoize internable values (ints, strs, ...), so Buildables
whose arguments are equal but distinct immutable objects still compare
equal. Paths are compared as sets, so no ordering on path elements is
required.

PiperOrigin-RevId: 543799794
---
 fiddle/_src/config.py      | 150 ++++++++++++++++++++++++++++---------
 fiddle/_src/config_test.py |  18 +++++
 fiddle/_src/daglish.py     |  41 +++++++++-
 3 files changed, 171 insertions(+), 38 deletions(-)

diff --git a/fiddle/_src/config.py b/fiddle/_src/config.py
index 032a9f78..37ee10a8 100644
--- a/fiddle/_src/config.py
+++ b/fiddle/_src/config.py
@@ -64,6 +64,58 @@ def __copy__(self):
 _UNSET_SENTINEL = object()
 
 
+_defaults_aware_traverser_registry = daglish.NodeTraverserRegistry(
+    use_fallback=True
+)
+
+
+def _buildable_flatten(
+    buildable: Buildable, include_defaults: bool = False
+) -> Tuple[Tuple[Any, ...], BuildableTraverserMetadata]:
+  """Implement Buildable.__flatten__ method."""
+  arguments = ordered_arguments(buildable, include_defaults=include_defaults)
+  keys = tuple(arguments.keys())
+  values = tuple(arguments.values())
+  argument_tags = {
+      name: frozenset(tags)
+      for name, tags in buildable.__argument_tags__.items()
+      if tags  # Don't include empty sets.
+  }
+  argument_history = {
+      name: tuple(entries)
+      for name, entries in buildable.__argument_history__.items()
+  }
+  metadata = BuildableTraverserMetadata(
+      fn_or_cls=buildable.__fn_or_cls__,
+      argument_names=keys,
+      argument_tags=argument_tags,
+      argument_history=argument_history,
+  )
+  return values, metadata
+
+
+def _buildable_path_elements(
+    buildable: Buildable, include_defaults: bool = True
+) -> Tuple[daglish.Attr, ...]:
+  """Implement Buildable.__path_elements__ method."""
+  return tuple(
+      daglish.Attr(name)
+      for name in ordered_arguments(
+          buildable, include_defaults=include_defaults
+      ).keys()
+  )
+
+
+def _register_buildable_defaults_aware_traversers(cls: Type[Buildable]):
+  """Registers defaults-aware traversal routines for Buildable subclasses."""
+  _defaults_aware_traverser_registry.register_node_traverser(
+      cls,
+      flatten_fn=functools.partial(_buildable_flatten, include_defaults=True),
+      unflatten_fn=cls.__unflatten__,
+      path_elements_fn=_buildable_path_elements,
+  )
+
+
 class BuildableTraverserMetadata(NamedTuple):
   """Metadata for a Buildable.
 
@@ -192,6 +244,7 @@ def __init_subclass__(cls):
         unflatten_fn=cls.__unflatten__,
         path_elements_fn=lambda x: x.__path_elements__(),
     )
+    _register_buildable_defaults_aware_traversers(cls)
 
   @abc.abstractmethod
   def __build__(self, *args, **kwargs):
@@ -199,25 +252,7 @@ def __build__(self, *args, **kwargs):
     raise NotImplementedError()
 
   def __flatten__(self) -> Tuple[Tuple[Any, ...], BuildableTraverserMetadata]:
-    arguments = ordered_arguments(self)
-    keys = tuple(arguments.keys())
-    values = tuple(arguments.values())
-    argument_tags = {
-        name: frozenset(tags)
-        for name, tags in self.__argument_tags__.items()
-        if tags  # Don't include empty sets.
-    }
-    argument_history = {
-        name: tuple(entries)
-        for name, entries in self.__argument_history__.items()
-    }
-    metadata = BuildableTraverserMetadata(
-        fn_or_cls=self.__fn_or_cls__,
-        argument_names=keys,
-        argument_tags=argument_tags,
-        argument_history=argument_history,
-    )
-    return values, metadata
+    return _buildable_flatten(self, include_defaults=False)
 
   @classmethod
   def __unflatten__(
@@ -230,8 +265,8 @@ def __unflatten__(
     object.__setattr__(rebuilt, '__arguments__', metadata.arguments(values))
     return rebuilt
 
-  def __path_elements__(self):
-    return tuple(daglish.Attr(name) for name in ordered_arguments(self).keys())
+  def __path_elements__(self) -> Tuple[daglish.Attr, ...]:
+    return _buildable_path_elements(self, include_defaults=False)
 
   def __getattr__(self, name: str):
     """Get parameter with given ``name``."""
@@ -446,22 +481,65 @@ class being configured, and then checks for equality in the configured
     Returns:
       ``True`` if ``self`` equals ``other``, ``False`` if not.
     """
-    if type(self) is not type(other):
-      return False
-    if self.__fn_or_cls__ != other.__fn_or_cls__:
-      return False
-    assert self._has_var_keyword == other._has_var_keyword, (
-        'Internal invariant violated: has_var_keyword should be the same if '
-        "__fn_or_cls__'s are the same."
-    )
-
-    missing = object()
-    for key in set(self.__arguments__) | set(other.__arguments__):
-      v1 = getattr(self, key, missing)
-      v2 = getattr(other, key, missing)
-      if (v1 is not missing or v2 is not missing) and v1 != v2:
+    def compare_buildable(x, y, check_dag=False):
+      assert isinstance(x, Buildable)
+      if type(x) is not type(y):
         return False
-    return True
+      if x.__fn_or_cls__ != y.__fn_or_cls__:
+        return False
+      assert x._has_var_keyword == y._has_var_keyword, (  # pylint: disable=protected-access
+          'Internal invariant violated: has_var_keyword should be the same if '
+          "__fn_or_cls__'s are the same."
+      )
+
+      missing = object()
+      for key in set(x.__arguments__) | set(y.__arguments__):
+        v1 = getattr(x, key, missing)
+        v2 = getattr(y, key, missing)
+        assert not (v1 is missing and v2 is missing)
+        if v1 is missing or v2 is missing:
+          return False
+        if isinstance(v1, Buildable) and isinstance(v2, Buildable):
+          if not compare_buildable(v1, v2, check_dag=False):
+            return False
+        elif v1 != v2:  # Avoids re-running __eq__ (and its DAG check).
+          return False
+
+      # Compare the DAG structure (which nodes are shared). Internables such
+      # as ints/strs are not memoized, so equal but distinct objects are not a
+      # structural difference. This full traversal is expensive, so it only
+      # runs after the cheaper value comparison above has passed.
+      if check_dag:
+        x_elements = list(
+            daglish.iterate(
+                x,
+                memoized=True,
+                memoize_internables=False,
+                registry=_defaults_aware_traverser_registry,
+            )
+        )
+        y_elements = list(
+            daglish.iterate(
+                y,
+                memoized=True,
+                memoize_internables=False,
+                registry=_defaults_aware_traverser_registry,
+            )
+        )
+        # Memoized traversal yields each node once, so every path is unique;
+        # compare paths as sets, which needs no ordering on path elements
+        # (dict keys of mixed types, e.g. 1 and 'a', are not orderable).
+        x_paths = {path for _, path in x_elements}
+        y_paths = {path for _, path in y_elements}
+
+        if len(x_elements) != len(y_elements):
+          return False
+        if x_paths != y_paths:
+          return False
+
+      return True
+
+    return compare_buildable(self, other, check_dag=True)
 
   def __getstate__(self):
     """Gets pickle serialization state, removing some fields.
diff --git a/fiddle/_src/config_test.py b/fiddle/_src/config_test.py
index 827e220a..6dd610c1 100644
--- a/fiddle/_src/config_test.py
+++ b/fiddle/_src/config_test.py
@@ -517,6 +517,24 @@ def fn_with_non_comparable_default(
     self.assertIsInstance(value1, ClassWithDisabledEquality)
     self.assertIsInstance(value2, ClassWithDisabledEquality)
 
+  def test_config_dag_structure_comparison(self):
+    a = fdl.Config(SampleClass, 1, 2)
+    b = fdl.Config(SampleClass, 1, 2)
+    with self.subTest('python_list'):
+      x = [a, a]
+      y = [a, b]
+      self.assertEqual(x, y)
+
+    with self.subTest('node_sharing_detection'):
+      x = fdl.Config(SampleClass, a, b)
+      y = fdl.Config(SampleClass, a, a)
+      self.assertNotEqual(x, y)
+
+    with self.subTest('node_sharing_difference'):
+      x = fdl.Config(SampleClass, a, b, b)
+      y = fdl.Config(SampleClass, a, a, b)
+      self.assertNotEqual(x, y)
+
   def test_unsetting_argument(self):
     fn_config = fdl.Config(basic_fn)
     fn_config.arg1 = 3
diff --git a/fiddle/_src/daglish.py b/fiddle/_src/daglish.py
index 7d086ce7..4cad7aca 100644
--- a/fiddle/_src/daglish.py
+++ b/fiddle/_src/daglish.py
@@ -37,6 +37,15 @@ def code(self) -> str:
   def follow(self, container) -> Any:
     """Returns the element of `container` specified by this path element."""
 
+  def __lt__(self, other: PathElement) -> bool:
+    """Define the less than relation for PathElement."""
+    if type(self) is not type(other):
+      return str(type(self)) < str(type(other))
+    else:
+      raise NotImplementedError(
+          "__lt__ relation should be handled by subclasses of PathElement."
+      )
+
 
 @dataclasses.dataclass(frozen=True)
 class Index(PathElement):
@@ -50,6 +59,12 @@ def code(self) -> str:
   def follow(self, container: Union[List[Any], Tuple[Any, ...]]) -> Any:
     return container[self.index]
 
+  def __lt__(self, other: PathElement) -> bool:
+    if type(self) is type(other):
+      return self.index < other.index
+    else:
+      return super().__lt__(other)
+
 
 @dataclasses.dataclass(frozen=True)
 class Key(PathElement):
@@ -63,6 +78,12 @@ def code(self) -> str:
   def follow(self, container: Dict[Any, Any]) -> Any:
     return container[self.key]
 
+  def __lt__(self, other: PathElement) -> bool:
+    if type(self) is type(other):
+      return self.key < other.key
+    else:
+      return super().__lt__(other)
+
 
 @dataclasses.dataclass(frozen=True)
 class Attr(PathElement):
@@ -76,6 +97,12 @@ def code(self) -> str:
   def follow(self, container: Any) -> Any:
     return getattr(container, self.name)
 
+  def __lt__(self, other: PathElement) -> bool:
+    if type(self) is type(other):
+      return self.name < other.name
+    else:
+      return super().__lt__(other)
+
 
 class BuildableAttr(Attr):
   """An attribute of a Buildable."""
@@ -727,6 +754,7 @@ def traverse(value, state: State):
 def iterate(
     value: Any,
     memoized: bool = True,
     registry: NodeTraverserRegistry = _default_traverser_registry,
+    memoize_internables: bool = True,
 ) -> Iterable[Tuple[Any, Path]]:
   """Iterates through values in a DAG.
@@ -743,6 +771,8 @@ def iterate(
     memoized: Whether to yield shared nodes only once. Defaults to True.
       With this setting, you will only see one path (which is somewhat
       arbitrary) to shared nodes.
     registry: Override to the NodeTraverserRegistry; this is a low-level
       setting for traversing into custom data types.
+    memoize_internables: Whether to memoize Python internable values; only
+      used when `memoized` is True. See MemoizedTraversal for details.
 
@@ -756,6 +786,13 @@ def _traverse(node, state: State):
       for sub_result in state.flattened_map_children(node).values:
         yield from sub_result
 
-  traversal_cls = MemoizedTraversal if memoized else BasicTraversal
-  traversal: Traversal = traversal_cls(_traverse, value, registry=registry)
+  if memoized:
+    traversal = MemoizedTraversal(
+        _traverse,
+        value,
+        registry=registry,
+        memoize_internables=memoize_internables,
+    )
+  else:
+    traversal = BasicTraversal(_traverse, value, registry=registry)
   return _traverse(value, traversal.initial_state())