From b66d9d74ea7806cdf58c6afba8130351cd6b75a7 Mon Sep 17 00:00:00 2001 From: Fiddle-Config Team Date: Thu, 23 Apr 2026 11:26:40 -0700 Subject: [PATCH] Flip default value of allow_imports to False in Fiddle's absl_flags. This makes the flag secure by default, requiring explicit enablement to allow arbitrary imports. PiperOrigin-RevId: 904552214 --- fiddle/_src/absl_flags/flags.py | 4 +- fiddle/_src/absl_flags/sweep_flag.py | 286 ---------------------- fiddle/_src/absl_flags/sweep_flag_test.py | 248 ------------------- fiddle/_src/absl_flags/utils.py | 2 +- setup.py | 7 +- 5 files changed, 6 insertions(+), 541 deletions(-) delete mode 100644 fiddle/_src/absl_flags/sweep_flag.py delete mode 100644 fiddle/_src/absl_flags/sweep_flag_test.py diff --git a/fiddle/_src/absl_flags/flags.py b/fiddle/_src/absl_flags/flags.py index 0e956b27..af5cb864 100644 --- a/fiddle/_src/absl_flags/flags.py +++ b/fiddle/_src/absl_flags/flags.py @@ -118,7 +118,7 @@ def __init__( self, *args, default_module: Optional[types.ModuleType] = None, - allow_imports: bool = True, + allow_imports: bool = False, pyref_policy: Optional[serialization.PyrefPolicy] = None, **kwargs, ): @@ -289,7 +289,7 @@ def DEFINE_fiddle_config( # pylint: disable=invalid-name pyref_policy: Optional[serialization.PyrefPolicy] = None, flag_values: flags.FlagValues = flags.FLAGS, required: bool = False, - allow_imports: bool = True, + allow_imports: bool = False, ) -> flags.FlagHolder[Any]: r"""Declare and define a fiddle command line flag object. diff --git a/fiddle/_src/absl_flags/sweep_flag.py b/fiddle/_src/absl_flags/sweep_flag.py deleted file mode 100644 index 57532a9a..00000000 --- a/fiddle/_src/absl_flags/sweep_flag.py +++ /dev/null @@ -1,286 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Fiddle-Config Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""A flag for specifying a 'sweep' of one or more fiddle configs to launch. - -This flag supports the config:, set:, and fiddler: commands of the -DEFINE_fiddle_config flag. (See documentation for DEFINE_fiddle_config for -more details on those commands.) - -Here we support an additional sweep: command, which allows you to specify -multiple configs by sweeping over any combination of: -* Arguments to the config function specified by config: -* Overrides to the resulting config. - -A sweep: command should specify a function call returning a list of -dictionaries, where each dictionary represents a single item in the sweep. -The entries in the dictionary are the overrides to apply, where keys can be of -the form: -* kwarg:foo -- to specify or override a keyword argument to the config function -* arg:0 -- to specify or override a positional argument to the config function -* path.to.field -- to specify an override to a field in the resulting config - returned by the config function. These paths follow the same format as is - accepted by set: commands and can take quite general forms like - foo.bar['baz'][0].boz. - Like set: commands, these overrides are applied via mutating the config. If a - single object is referenced in multiple places in your config, mutating it - in one place will affect everywhere it is referenced. If you make separate - copies of the same object in your config, you will need to mutate them all - separately. 
- If you prefer not to worry about this, consider using an argument to your - config function instead, to set the same value in multiple places. - -Multiple sweep: commands can be specified, which will result in taking the -product of the separate sweeps. - -An example showing off the functionality, which hopefully documents most of it: - -def model_config(hidden_size: int = 64, dropout_rate: float = 0.1): - return fdl.Config( - MyModel, - hidden_size=hidden_size, - dropout_rate=dropout_rate, - layer_norm=True, - log_level=logging.INFO, - submodule=fdl.Config( - Submodule, - hidden_size=hidden_size, - dropout_rate=dropout_rate, - num_layers=10, - log_level=logging.INFO, - ) - ) - -def hidden_size_sweep(): - return [ - # Overrides kwarg to the config: function: - {"kwarg:hidden_size": n} for n in [128, 256, 512] - ] - -def num_layers_sweep(max_layers: int = 5): - return [ - {"submodule.num_layers": n} for n in range(2, max_layers + 1) - ] - -def add_debug_logging(config): - config.log_level = logging.DEBUG - config.submodule.log_level = logging.DEBUG - -my_binary \ - --config 'config:model_config(dropout_rate=0.2)' \ - --config 'sweep:hidden_size_sweep' \ - --config 'sweep:num_layers_sweep(max_layers=10)' \ - --config 'set:layer_norm=True' \ - --config 'fiddler:add_debug_logging' \ - -This will apply the product of the two sweeps, varying both the arguments to -`model_config` and overriding the returned configs, as specified by the sweeps. -It then applies the set: and fiddler: commands (see fiddle.absl_flags) to all -the resulting configs. 
-""" - -import dataclasses -import itertools -import re -import types -from typing import Any, Mapping, Optional, Sequence, Tuple - -from absl import flags -from fiddle._src import config -from fiddle._src.absl_flags import utils -from fiddle._src.experimental import auto_config - - -_COMMAND_RE = re.compile(r"^(config|fiddler|set|sweep):(.+)$") -_KWARG_SWEEP_PREFIX = "kwarg:" -_ARG_SWEEP_PREFIX = "arg:" - - -@dataclasses.dataclass -class SweepItem: - """A config, together with metadata. - - Attributes: - config: A fdl.Buildable - overrides_applied: A dictionary of overrides applied -- either to arguments - of the config function or the config itself. Useful as metadata to - distinguish this from other items in the sweep. - """ - - config: config.Buildable - overrides_applied: Mapping[str, Any] - - -class DEFINE_fiddle_sweep: # pylint: disable=invalid-name - """Defines a flag for a sweep of one or more fiddle configs. - - Its .value property returns a list of SweepItem objects, each a config paired - with the overrides applied to it. - - (While it has a .value like the FlagHolders returned by flags.DEFINE_*, it's - not a 'real' FlagHolder, just a wrapper around a DEFINE_multi_string flag. - This means the value appearing under flags.FLAGS will just be the strings from - the multi_string flag.) - """ - - # Note flags.FiddleFlag goes through some contortions to subclass - # flags.MultiFlag while ensuring lazy parsing, which the superclass is not - # really built for. This approach of wrapping a plain multi-string flag is - # simpler, but has the (mild) downside stated above. 
- - def __init__( - self, - name: str, - required: bool = False, - help: str = "Multi-flag for a fiddle config sweep.", # pylint: disable=redefined-builtin - default_module: Optional[types.ModuleType] = None, - allow_imports: bool = True, - ): - self.name = name - self._allow_imports = allow_imports - self._default_module = default_module - self._multi_flag = flags.DEFINE_multi_string( - name=name, - default=None, - help=help, - required=required, - ) - self._value = None - - @property - def value(self) -> Sequence[SweepItem]: - if self._value is None: - self._value = self._parse(self._multi_flag.value) - return self._value - - def _parse_call_expression( - self, expression: str, mode: utils.ImportDottedNameDebugContext - ): - """Parses a call expression as supported by fiddle.absl_flags.""" - call_expr = utils.CallExpression.parse(expression) - base_name = call_expr.func_name - base_fn = utils.resolve_function_reference( - base_name, - mode, - self._default_module, - self._allow_imports, - failure_msg_prefix="Could not resolve reference from fiddle sweep_flag", - ) - if auto_config.is_auto_config(base_fn): - base_fn = base_fn.as_buildable - return base_fn, call_expr.args, call_expr.kwargs - - def _parse_command(self, item: str): - match = _COMMAND_RE.fullmatch(item) - if not match: - raise ValueError( - f"All flag values to {self.name} must begin with 'config:', " - "'set:', 'fiddler:' or 'sweep:'." - ) - command, expression = match.groups() - return command, expression - - def _parse(self, commands: Sequence[str]) -> Sequence[SweepItem]: - """Parse a sequence of commands describing a config sweep.""" - - command_type, expression = self._parse_command(commands[0]) - if command_type != "config": - raise ValueError( - 'First command must be a "config:" command, specifying a call to a ' - "function returning a config. 
Got: " - + commands[0] - ) - - config_fn, args, kwargs = self._parse_call_expression( - expression, mode=utils.ImportDottedNameDebugContext.BASE_CONFIG - ) - - sweeps = [] - non_sweep_commands = [] - for command in commands[1:]: - command_type, expr = self._parse_command(command) - if command_type == "sweep": - sweep_fn, sweep_args, sweep_kwargs = self._parse_call_expression( - expr, mode=utils.ImportDottedNameDebugContext.SWEEP - ) - sweep = sweep_fn(*sweep_args, **sweep_kwargs) - sweeps.append(sweep) - else: - non_sweep_commands.append((command_type, expr)) - - # Take product of all sweeps: - sweep = [_merge_dicts(dicts) for dicts in itertools.product(*sweeps)] - - configs = [ - _apply_sweep_overrides(config_fn, args, kwargs, overrides) - for overrides in sweep - ] - configs = self._apply_sets_and_fiddlers(configs, non_sweep_commands) - return [ - SweepItem(config=config, overrides_applied=overrides) - for config, overrides in zip(configs, sweep) - ] - - def _apply_sets_and_fiddlers( - self, - configs: Sequence[config.Buildable], - commands: Sequence[Tuple[str, str]], - ) -> Sequence[config.Buildable]: - """Apply set: and fiddler: commands to multiple configs. May mutate.""" - for command_type, expr in commands: - if command_type == "set": - for cfg in configs: - utils.set_value(cfg, expr) - elif command_type == "fiddler": - fiddler, args, kwargs = self._parse_call_expression( - expr, mode=utils.ImportDottedNameDebugContext.FIDDLER - ) - new_configs = [] - for cfg in configs: - new_cfg = fiddler(cfg, *args, **kwargs) - if new_cfg is None: - new_cfg = cfg # Fiddler mutated it. - new_configs.append(new_cfg) - configs = new_configs - else: - raise ValueError( - f"Unexpected command type {command_type} in {self.name}." 
- ) - return configs - - -def _apply_sweep_overrides(config_fn, args, kwargs, overrides): - """Apply overrides to the args/kwargs and the result of a config function.""" - kwargs = dict(kwargs) - args = list(args) - config_overrides = {} - for key, value in overrides.items(): - if key.startswith(_KWARG_SWEEP_PREFIX): - key = key[len(_KWARG_SWEEP_PREFIX) :] - kwargs[key] = value - elif key.startswith(_ARG_SWEEP_PREFIX): - index = int(key[len(_ARG_SWEEP_PREFIX) :]) - # Pad long enough to accept the new arg: - args.extend([None] * (index + 1 - len(args))) - args[index] = value - else: - config_overrides[key] = value - - cfg = config_fn(*args, **kwargs) - return utils.with_overrides(cfg, config_overrides) - - -def _merge_dicts(dicts: Sequence[Mapping[Any, Any]]) -> Mapping[Any, Any]: - return {k: v for d in dicts for k, v in d.items()} # pylint: disable=g-complex-comprehension diff --git a/fiddle/_src/absl_flags/sweep_flag_test.py b/fiddle/_src/absl_flags/sweep_flag_test.py deleted file mode 100644 index 235f38c5..00000000 --- a/fiddle/_src/absl_flags/sweep_flag_test.py +++ /dev/null @@ -1,248 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Fiddle-Config Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import copy -import dataclasses -import sys - -from absl.testing import absltest -from fiddle._src import config -from fiddle._src.absl_flags import sweep_flag -from fiddle._src.experimental import auto_config - - -# We'll just test the _parse method, the rest is relatively trivial wrapper. -FLAG = sweep_flag.DEFINE_fiddle_sweep( - "_dummy", default_module=sys.modules[__name__], allow_imports=False -) - - -@dataclasses.dataclass -class Foo: - x: str - y: int - z: int = 0 - - -def get_config(arg: str = "arg", kwarg: int = 1): - return config.Config(Foo, x=arg, y=kwarg) - - -@auto_config.auto_config -def get_config_auto(kwarg: int = 1): - return Foo(x="foo", y=kwarg) - - -def fiddler(cfg): - cfg.y *= 2 - - -def arg_kwarg_sweep(): - return [{"arg:0": "foo"}, {"kwarg:kwarg": 3}] - - -def override_sweep(y=3): - return [{"x": "foo"}, {"y": y}] - - -def x_sweep(): - return [{"x": "foo"}, {"x": "bar"}] - - -def y_sweep(): - return [{"y": 2}, {"y": 3}] - - -def kwarg_sweep(): - return [{"kwarg:kwarg": 2}, {"kwarg:kwarg": 3}] - - -class FiddleSweepFlagTest(absltest.TestCase): - - def assert_single_config_sweep(self, sweep, expected): - self.assertEqual( - sweep, [sweep_flag.SweepItem(config=expected, overrides_applied={})] - ) - - def test_plain_config(self): - sweep = FLAG._parse(["config:get_config"]) - expected = get_config() - self.assert_single_config_sweep(sweep, expected) - - def test_config_with_args(self): - sweep = FLAG._parse(["config:get_config('foo', kwarg=2)"]) - expected = get_config("foo", kwarg=2) - self.assert_single_config_sweep(sweep, expected) - - def test_auto_config_with_args(self): - sweep = FLAG._parse(["config:get_config_auto(kwarg=2)"]) - expected = get_config_auto.as_buildable(kwarg=2) - self.assert_single_config_sweep(sweep, expected) - - def test_set(self): - sweep = FLAG._parse(["config:get_config", "set:x='bar'"]) - expected = get_config() - expected.x = "bar" - self.assert_single_config_sweep(sweep, expected) - - def test_fiddler(self): - sweep 
= FLAG._parse(["config:get_config", "fiddler:fiddler"]) - expected = get_config() - fiddler(expected) - self.assert_single_config_sweep(sweep, expected) - - def test_arg_kwarg_sweep(self): - sweep = FLAG._parse(["config:get_config", "sweep:arg_kwarg_sweep"]) - expected = [ - sweep_flag.SweepItem( - config=get_config("foo"), - overrides_applied={"arg:0": "foo"}, - ), - sweep_flag.SweepItem( - config=get_config(kwarg=3), - overrides_applied={"kwarg:kwarg": 3}, - ), - ] - self.assertEqual(sweep, expected) - - def test_arg_kwarg_sweep_overriding_existing(self): - sweep = FLAG._parse( - ["config:get_config('bar', kwarg=2)", "sweep:arg_kwarg_sweep"] - ) - expected = [ - sweep_flag.SweepItem( - config=get_config("foo", kwarg=2), - overrides_applied={"arg:0": "foo"}, - ), - sweep_flag.SweepItem( - config=get_config("bar", kwarg=3), - overrides_applied={"kwarg:kwarg": 3}, - ), - ] - self.assertEqual(sweep, expected) - - def test_override_sweep_and_sweep_arg(self): - sweep = FLAG._parse(["config:get_config", "sweep:override_sweep(y=4)"]) - parent_config = get_config() - config1 = copy.copy(parent_config) - config1.x = "foo" - config2 = copy.copy(parent_config) - config2.y = 4 - expected = [ - sweep_flag.SweepItem( - config=config1, - overrides_applied={"x": "foo"}, - ), - sweep_flag.SweepItem( - config=config2, - overrides_applied={"y": 4}, - ), - ] - self.assertEqual(sweep, expected) - - def test_product_sweep(self): - sweep = FLAG._parse(["config:get_config", "sweep:x_sweep", "sweep:y_sweep"]) - expected = [ - sweep_flag.SweepItem( - config=config.Config(Foo, x="foo", y=2), - overrides_applied={"x": "foo", "y": 2}, - ), - sweep_flag.SweepItem( - config=config.Config(Foo, x="foo", y=3), - overrides_applied={"x": "foo", "y": 3}, - ), - sweep_flag.SweepItem( - config=config.Config(Foo, x="bar", y=2), - overrides_applied={"x": "bar", "y": 2}, - ), - sweep_flag.SweepItem( - config=config.Config(Foo, x="bar", y=3), - overrides_applied={"x": "bar", "y": 3}, - ), - ] - 
self.assertEqual(sweep, expected) - - def test_set_and_override_sweep(self): - sweep = FLAG._parse([ - "config:get_config", - "set:y=5", - "sweep:x_sweep", - ]) - parent_config = get_config() - parent_config.y = 5 - expected = [ - sweep_flag.SweepItem( - config=config.Config(Foo, x="foo", y=5), - overrides_applied={"x": "foo"}, - ), - sweep_flag.SweepItem( - config=config.Config(Foo, x="bar", y=5), - overrides_applied={"x": "bar"}, - ), - ] - self.assertEqual(sweep, expected) - - def test_kwarg_sweep_and_set(self): - sweep = FLAG._parse([ - "config:get_config", - "sweep:kwarg_sweep", - "set:x='foo'", - ]) - config1 = get_config(kwarg=2) - config1.x = "foo" - config2 = get_config(kwarg=3) - config2.x = "foo" - expected = [ - sweep_flag.SweepItem( - config=config1, - overrides_applied={"kwarg:kwarg": 2}, - ), - sweep_flag.SweepItem( - config=config2, - overrides_applied={"kwarg:kwarg": 3}, - ), - ] - self.assertEqual(sweep, expected) - - def test_shebang(self): - sweep = FLAG._parse([ - "config:get_config", - "sweep:x_sweep", - "sweep:kwarg_sweep", - "set:z=10", - ]) - expected = [ - sweep_flag.SweepItem( - config=config.Config(Foo, x="foo", y=2, z=10), - overrides_applied={"x": "foo", "kwarg:kwarg": 2}, - ), - sweep_flag.SweepItem( - config=config.Config(Foo, x="foo", y=3, z=10), - overrides_applied={"x": "foo", "kwarg:kwarg": 3}, - ), - sweep_flag.SweepItem( - config=config.Config(Foo, x="bar", y=2, z=10), - overrides_applied={"x": "bar", "kwarg:kwarg": 2}, - ), - sweep_flag.SweepItem( - config=config.Config(Foo, x="bar", y=3, z=10), - overrides_applied={"x": "bar", "kwarg:kwarg": 3}, - ), - ] - self.assertEqual(sweep, expected) - - -if __name__ == "__main__": - absltest.main() diff --git a/fiddle/_src/absl_flags/utils.py b/fiddle/_src/absl_flags/utils.py index 1a41782b..00a224b4 100644 --- a/fiddle/_src/absl_flags/utils.py +++ b/fiddle/_src/absl_flags/utils.py @@ -289,7 +289,7 @@ def resolve_function_reference( def init_config_from_expression( expression: str, 
module: Optional[types.ModuleType] = None, - allow_imports: bool = True, + allow_imports: bool = False, ) -> config.Buildable: """Initializes a `fdl.Buildable` from a function call expression. diff --git a/setup.py b/setup.py index be49b022..0a7ce580 100644 --- a/setup.py +++ b/setup.py @@ -21,8 +21,7 @@ # pyformat: disable import sys -from setuptools import find_packages -from setuptools import setup +import setuptools _dct = {} @@ -42,11 +41,11 @@ """ # pylint: disable=g-long-ternary -setup( +setuptools.setup( name='fiddle', version=__version__, include_package_data=True, - packages=find_packages(exclude=['docs']), # Required + packages=setuptools.find_packages(exclude=['docs']), # Required package_data={'testdata': ['testdata/*.fiddle']}, python_requires='>=3.8', install_requires=[