From c95c0b66b6335ba25361b66062e656260121f4e7 Mon Sep 17 00:00:00 2001
From: Fiddle-Config Team
Date: Thu, 20 Jul 2023 19:49:09 -0700
Subject: [PATCH] Generate experiment diffs from baselines.
PiperOrigin-RevId: 549813869
---
fiddle/_src/codegen/codegen_diff.py | 60 ++++++++++++++++++++++++++---
1 file changed, 54 insertions(+), 6 deletions(-)
diff --git a/fiddle/_src/codegen/codegen_diff.py b/fiddle/_src/codegen/codegen_diff.py
index bf43ddc7..05922706 100644
--- a/fiddle/_src/codegen/codegen_diff.py
+++ b/fiddle/_src/codegen/codegen_diff.py
@@ -16,10 +16,11 @@
"""Library for converting generating fiddlers from diffs."""
import collections
+import dataclasses
import functools
import re
import types
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple
+from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple
from fiddle import daglish
from fiddle import diffing
@@ -31,12 +32,49 @@
import libcst as cst
+@dataclasses.dataclass(frozen=True)
+class ObjectToName:
+ prefix: str
+ path: daglish.Path
+
+ def __hash__(self):
+ return id(self)
+
+
+def assign_explicit_names(all_to_name: List[ObjectToName]) -> List[str]:
+ """Returns explicit names: each prefix plus the name of its full path."""
+ return [
+ to_name.prefix + _path_to_name(to_name.path) for to_name in all_to_name
+ ]
+
+
+def assign_short_names(all_to_name: List[ObjectToName]) -> List[str]:
+ """Returns short names: prefix plus last path element (last two on clash)."""
+ name_to_paths = {}
+ for to_name in all_to_name:
+ sub_path = to_name.path[-1:]
+ name_to_paths.setdefault(
+ to_name.prefix + _path_to_name(sub_path), []
+ ).append(to_name)
+
+ result_as_dict = {}
+ for name, group in name_to_paths.items():
+ if len(group) == 1:
+ result_as_dict[group[0]] = name
+ else:
+ for to_name in group:
+ sub_path = to_name.path[-2:]
+ result_as_dict[to_name] = to_name.prefix + _path_to_name(sub_path)
+ return [result_as_dict[to_name] for to_name in all_to_name]
+
+
def fiddler_from_diff(
diff: diffing.Diff,
old: Any = None,
func_name: str = 'fiddler',
param_name: str = 'cfg',
import_manager: Optional[import_manager_lib.ImportManager] = None,
+ variable_naming: Literal['explicit', 'short'] = 'explicit',
):
"""Returns the CST for a fiddler function that applies the changes in `diff`.
@@ -72,6 +110,8 @@ def fiddler_from_diff(
import_manager: Existing import manager. Usually set to None, but if you are
integrating this with other code generation tasks, it can be nice to
share.
+ variable_naming: Whether to create intermediate variables with long,
+ explicit names, or just capture the last elements of a path.
Returns:
An `cst.Module` object. You can convert this to a string using
@@ -97,18 +137,26 @@ def fiddler_from_diff(
# ancestors) will be replaced by a change in the diff. If we don't have an
# `old` structure, then we pessimistically assume that we need to create
# variables for all used paths.
- moved_value_names = {}
+ moved_values_to_name = []
if old is not None:
modified_paths = set([change.target for change in diff.changes])
_add_path_aliases(modified_paths, old)
for path in sorted(used_paths, key=daglish.path_str):
if any(path[:i] in modified_paths for i in range(len(path) + 1)):
- moved_value_names[path] = namespace.get_new_name(
- _path_to_name(path), f'moved_{param_name}_')
+ moved_values_to_name.append(ObjectToName(f'moved_{param_name}_', path))
else:
for path in sorted(used_paths, key=daglish.path_str):
- moved_value_names[path] = namespace.get_new_name(
- _path_to_name(path), f'original_{param_name}_')
+ moved_values_to_name.append(ObjectToName(f'original_{param_name}_', path))
+
+ if variable_naming == 'explicit':
+ initial_names = assign_explicit_names(moved_values_to_name)
+ else:
+ initial_names = assign_short_names(moved_values_to_name)
+
+ moved_value_names = {
+ to_name.path: namespace.get_new_name(name, prefix='')
+ for to_name, name in zip(moved_values_to_name, initial_names)
+ }
# Add variables for new shared values added by the diff.
new_shared_value_names = [