From b148e4392633ca5ca5c4447ac1ee286e8729da2c Mon Sep 17 00:00:00 2001 From: Alvin Wan Date: Sat, 4 Apr 2026 03:20:14 -0700 Subject: [PATCH 1/2] support complex package graphs --- README.md | 2 +- pymini/pymini.py | 182 +++++++++++++++++++++++++++++++++------------- tests/test_api.py | 28 +++---- tests/test_cli.py | 173 ++++++++++++++++++++++++++++++++++++++----- 4 files changed, 304 insertions(+), 81 deletions(-) diff --git a/README.md b/README.md index f1f2528..be1a618 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ ## Status -This project is maintained as an AST-based minifier for Python 3.9+ code. It is best suited to scripts and small module graphs that use straightforward imports such as `from module import name`. +This project is maintained as an AST-based minifier for Python 3.9+ code. It is best suited to scripts and small-to-medium package graphs. Package mode preserves package layout and now covers relative imports, dotted imports, star imports, package re-exports, and `importlib`-based internal imports; bundle mode emits a self-contained loader-backed single file for the same kinds of graphs. 
## Installation diff --git a/pymini/pymini.py b/pymini/pymini.py index f5086ed..d8f26a5 100644 --- a/pymini/pymini.py +++ b/pymini/pymini.py @@ -1,7 +1,6 @@ import ast import copy import keyword -from graphlib import TopologicalSorter from typing import Dict, List, Optional, Set from .utils import variable_name_generator @@ -169,6 +168,7 @@ def __init__(self, generator, mapping=None, modules=(), keep_global_variables=Fa self.name_to_node = {} self.nodes_to_insert = [] self.nodes_to_append = [] + self.public_global_names = set() # TODO: cleanup self.str_name_to_node = {} self.str_mapping = {} @@ -296,7 +296,11 @@ def visit_Assign(self, node): 'demiurgic = 1\\nholy = demiurgic' """ if self.keep_global_variables and self._is_node_global(node): # TODO: rename but insert var def if worth it - return self.generic_visit(node) + for target in node.targets: + if isinstance(target, ast.Name): + self.public_global_names.add(target.id) + node.value = self.visit(node.value) + return node for target in node.targets: if isinstance(target, ast.Name) and target.id not in self.mapping.values(): # TODO: make .values() more efficient self.mapping[target.id] = target.id = next(self.generator) @@ -334,6 +338,8 @@ def visit_Name(self, node): """ if node.id in self.mapping.values(): # TODO: make .values() more efficient return node + if self.keep_global_variables and node.id in self.public_global_names: + return self.generic_visit(node) if self.keep_global_variables and self._is_node_global(node): if node.id in self.mapping: node.id = self.mapping[node.id] @@ -435,6 +441,7 @@ def __init__(self, generator, modules, module_to_shortener, keep_module_names=Fa def transform(self, *trees): original_modules = list(self.module_to_shortener) + packages = package_modules(original_modules) module_to_module = {} if not self.keep_module_names: module_to_module = {module: next(self.generator) for module in original_modules} @@ -455,9 +462,11 @@ def transform(self, *trees): imported = 
ImportedVariableShortener( self.generator, mapping=fused_mapping, + current_module=module, keep_global_variables=True, module_to_module={_module: value for _module, value in module_to_module.items() if module != _module}, module_to_shortener={_module: value for _module, value in self.module_to_shortener.items() if module != _module}, + packages=packages, ) new_trees.extend(imported.transform(tree)) return new_trees @@ -475,20 +484,25 @@ class ImportedVariableShortener(VariableShortener): >>> apply('from silly import demiurgic, dontreplaceme; print(demiurgic)') 'from silly import a, dontreplaceme\\nprint(a)' """ - def __init__(self, *args, module_to_shortener=None, module_to_module=None, **kwargs): + def __init__(self, *args, current_module=None, module_to_shortener=None, module_to_module=None, packages=(), **kwargs): super().__init__(*args, **kwargs) + self.current_module = current_module self.module_to_shortener = module_to_shortener or {} self.module_to_module = module_to_module or {} + self.packages = set(packages) def visit_ImportFrom(self, node): """Apply shortener for imported module.""" - shortener = self.module_to_shortener.get(node.module, None) + module_name = resolve_import_from(self.current_module, node, self.packages) + shortener = self.module_to_shortener.get(module_name, None) if shortener is not None: for alias in node.names: + if alias.name == "*": + continue if alias.name in shortener.mapping: self.mapping[alias.name] = alias.name = shortener.mapping[alias.name] - if node.module in self.module_to_module: # TODO: handle nested modules - node.module = self.module_to_module[node.module] + if node.level == 0 and module_name in self.module_to_module: + node.module = self.module_to_module[module_name] return self.generic_visit(node) @@ -507,56 +521,35 @@ class FileFuser(Fuser): Determine dependency between files by checking import statements. After linearizing dependencies, combine files in that order. 
""" - def _dependencies_for_tree(self, tree, modules): - dependencies = set() + def _dependencies_for_tree(self, module, tree, modules): + dependencies = ancestor_package_modules(module, modules) + packages = package_modules(modules) for node in ast.walk(tree): - if isinstance(node, ast.ImportFrom) and node.level == 0 and node.module in modules: - dependencies.add(node.module) + if isinstance(node, ast.ImportFrom): + target_module = resolve_import_from(module, node, packages) + dependencies.update(internal_module_dependencies(target_module, modules)) + for alias in node.names: + if alias.name != "*": + dependencies.update( + internal_module_dependencies(f"{target_module}.{alias.name}", modules) + ) elif isinstance(node, ast.Import): for alias in node.names: - module = alias.name.split('.')[0] - if module in modules: - dependencies.add(module) + dependencies.update(internal_module_dependencies(alias.name, modules)) return dependencies - def _module_order(self, module_to_tree): - sorter = TopologicalSorter() - modules = set(module_to_tree) - for module, tree in module_to_tree.items(): - dependencies = self._dependencies_for_tree(tree, modules - {module}) - sorter.add(module, *dependencies) - return list(sorter.static_order()) - - def _strip_internal_imports(self, tree, modules): - filtered_body = [] - for statement in tree.body: - if isinstance(statement, ast.ImportFrom) and statement.level == 0 and statement.module in modules: - continue - if isinstance(statement, ast.Import): - statement.names = [ - alias for alias in statement.names - if alias.name.split('.')[0] not in modules - ] - if not statement.names: - continue - filtered_body.append(statement) - tree.body = filtered_body - return tree - def transform(self, *trees): module_to_tree = dict(zip(self.modules, trees)) - module_order = self._module_order(module_to_tree) - internal_modules = set(module_order) - - combined_body = [] - for module in module_order: - tree = 
self._strip_internal_imports(module_to_tree[module], internal_modules) - combined_body.extend(tree.body) - - root = module_to_tree[module_order[0]] - root.body = combined_body - ast.fix_missing_locations(root) - return [root] + modules = set(module_to_tree) + dependency_map = { + module: self._dependencies_for_tree(module, tree, modules - {module}) + for module, tree in module_to_tree.items() + } + self.entry_modules = [ + module for module in self.modules + if all(module not in dependencies for dependencies in dependency_map.values()) + ] or list(self.modules) + return [module_to_tree[module] for module in self.modules] def define_custom_variables(tree, mapping): @@ -580,6 +573,95 @@ def transform(self, *trees): yield ast.unparse(tree) +def module_prefixes(module: Optional[str]) -> List[str]: + if not module: + return [] + parts = module.split(".") + return [".".join(parts[:i]) for i in range(1, len(parts) + 1)] + + +def package_modules(modules) -> Set[str]: + module_names = set(modules) + packages = set() + for module in module_names: + prefixes = module_prefixes(module) + packages.update(prefixes[:-1]) + if any(other.startswith(f"{module}.") for other in module_names): + packages.add(module) + return packages + + +def ancestor_package_modules(module: str, modules) -> Set[str]: + module_names = set(modules) + return { + prefix + for prefix in module_prefixes(module)[:-1] + if prefix in module_names + } + + +def module_package_name(module: Optional[str], packages: Set[str]) -> str: + if not module: + return "" + if module in packages: + return module + return module.rsplit(".", 1)[0] if "." 
in module else "" + + +def resolve_import_from(current_module: Optional[str], node: ast.ImportFrom, packages: Set[str]) -> Optional[str]: + if node.level == 0: + return node.module + + package_name = module_package_name(current_module, packages) + package_parts = package_name.split(".") if package_name else [] + if node.level > len(package_parts) + 1: + return node.module + + base_parts = package_parts[:len(package_parts) - node.level + 1] + if node.module: + base_parts.extend(node.module.split(".")) + return ".".join(part for part in base_parts if part) + + +def internal_module_dependencies(module: Optional[str], modules) -> Set[str]: + module_names = set(modules) + return { + prefix + for prefix in module_prefixes(module) + if prefix in module_names + } + + +def bundle_sources(sources: List[str], modules: List[str], entry_modules: Optional[List[str]] = None) -> str: + source_map = {module: source for module, source in zip(modules, sources)} + package_names = package_modules(source_map) + for package_name in sorted(package_names): + source_map.setdefault(package_name, "") + if not entry_modules: + entry_modules = list(modules) + + bundle_runtime = f""" +import importlib.abc as _a +import importlib.util as _u +import sys as _s +_M={source_map!r} +_P={sorted(package_names)!r} +class _L(_a.Loader): + def __init__(self,n):self.n=n + def create_module(self,spec):return None + def exec_module(self,module): + if self.n in _P:module.__path__=[] + exec(_M[self.n],module.__dict__) +class _F(_a.MetaPathFinder): + def find_spec(self,fullname,path=None,target=None): + if fullname not in _M:return None + return _u.spec_from_loader(fullname,_L(fullname),is_package=fullname in _P) +_s.meta_path.insert(0,_F()) +for _m in {entry_modules!r}:__import__(_m) +""" + return bundle_runtime.strip() + "\n" + + class WhitespaceRemover(NodeTransformer): """Remove all whitespace. 
@@ -856,6 +938,8 @@ def minify(sources, modules='main', keep_module_names=False, WhitespaceRemover(), ) cleaned = list(pipeline.transform(*trees)) + if output_single_file: + cleaned = [bundle_sources(cleaned, fuser.modules, getattr(fuser, "entry_modules", None))] output_modules = [single_file_module] if output_single_file else fuser.modules return cleaned, output_modules diff --git a/tests/test_api.py b/tests/test_api.py index 787e476..9c1c351 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,4 +1,6 @@ import ast +import subprocess +import sys from textwrap import dedent from pymini import minify @@ -50,18 +52,6 @@ def assert_bundle_preserves_public_alias(bundle_source: str) -> None: assert call.args[0].func.id == function.name -def assert_bundle_is_shortened(bundle_source: str) -> None: - bundle_tree = ast.parse(bundle_source) - function, printer = bundle_tree.body - - assert isinstance(function, ast.FunctionDef) - assert function.name != "square" - assert len(function.name) == 1 - - call = printer.value - assert call.args[0].func.id == function.name - - def test_minify_simplifies_returns(): cleaned, modules = minify( py( @@ -166,7 +156,7 @@ def square(x): assert modules == ["main", "side"] -def test_minify_fuses_files_into_single_module(): +def test_minify_fuses_files_into_single_module(tmp_path): cleaned, modules = minify( [ py( @@ -187,5 +177,15 @@ def square(x): output_single_file=True, ) - assert_bundle_is_shortened(cleaned[0]) + bundle_path = tmp_path / "bundle.py" + bundle_path.write_text(cleaned[0], encoding="utf-8") + result = subprocess.run( + [sys.executable, str(bundle_path)], + capture_output=True, + text=True, + check=False, + ) + + assert result.returncode == 0, result.stderr + assert result.stdout == "9\n" assert modules == ["bundle"] diff --git a/tests/test_cli.py b/tests/test_cli.py index 52c8c45..cae39a1 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,8 +1,10 @@ import ast +import os import subprocess import sys from 
pathlib import Path from textwrap import dedent +from typing import Optional PROJECT_ROOT = Path(__file__).resolve().parents[1] @@ -18,6 +20,31 @@ def run_cli(*args: str) -> subprocess.CompletedProcess[str]: ) +def run_python(code: str, *, pythonpath: Optional[Path] = None, cwd: Optional[Path] = None) -> subprocess.CompletedProcess[str]: + env = os.environ.copy() + if pythonpath is not None: + existing = env.get("PYTHONPATH") + env["PYTHONPATH"] = str(pythonpath) if not existing else f"{pythonpath}{os.pathsep}{existing}" + return subprocess.run( + [sys.executable, "-c", code], + cwd=cwd or PROJECT_ROOT, + env=env, + capture_output=True, + text=True, + check=False, + ) + + +def run_python_file(path: Path) -> subprocess.CompletedProcess[str]: + return subprocess.run( + [sys.executable, str(path)], + cwd=path.parent, + capture_output=True, + text=True, + check=False, + ) + + def py(source: str) -> str: return dedent(source).strip() + "\n" @@ -52,22 +79,6 @@ def assert_public_api_is_preserved(module_source: str, consumer_source: str) -> assert call.args[1].func.id == function.name -def assert_bundle_preserves_public_alias(bundle_source: str) -> None: - bundle_tree = ast.parse(bundle_source) - function, alias, printer = bundle_tree.body - - assert isinstance(function, ast.FunctionDef) - assert function.name != "square" - assert len(function.name) == 1 - - assert isinstance(alias, ast.Assign) - assert alias.targets[0].id == "square" - assert alias.value.id == function.name - - call = printer.value - assert call.args[0].func.id == function.name - - def test_cli_accepts_directories(tmp_path): source_dir = tmp_path / "src" output_dir = tmp_path / "out" @@ -127,7 +138,9 @@ def square(x): result = run_cli("bundle", str(source_dir), "-o", str(bundle_path)) assert result.returncode == 0, result.stderr - assert_bundle_preserves_public_alias(bundle_path.read_text(encoding="utf-8")) + execution = run_python_file(bundle_path) + assert execution.returncode == 0, execution.stderr + 
assert execution.stdout == "9\n" def test_cli_preserves_nested_package_paths(tmp_path): @@ -214,3 +227,129 @@ def test_cli_can_aggressively_rename_globals_in_package_mode(tmp_path): assert isinstance(assignment, ast.Assign) assert assignment.targets[0].id != "public_name" assert len(assignment.targets[0].id) == 1 + + +def test_cli_package_mode_supports_relative_star_reexports(tmp_path): + source_dir = tmp_path / "src" + output_dir = tmp_path / "out" + pkg_dir = source_dir / "pkg" + pkg_dir.mkdir(parents=True) + + write_py( + pkg_dir / "__init__.py", + """ + from .helpers import * + + __all__ = ["greet"] + """, + ) + write_py( + pkg_dir / "helpers.py", + """ + def greet(): + return "hello" + """, + ) + write_py( + source_dir / "app.py", + """ + from pkg import greet + + print(greet()) + """, + ) + + result = run_cli("package", str(source_dir), "-o", str(output_dir)) + + assert result.returncode == 0, result.stderr + execution = run_python("import app", pythonpath=output_dir, cwd=tmp_path) + assert execution.returncode == 0, execution.stderr + assert execution.stdout == "hello\n" + + +def test_cli_package_mode_supports_dotted_and_dynamic_imports(tmp_path): + source_dir = tmp_path / "src" + output_dir = tmp_path / "out" + pkg_dir = source_dir / "pkg" + pkg_dir.mkdir(parents=True) + + write_py(pkg_dir / "__init__.py", "VALUE = 1") + write_py( + pkg_dir / "helpers.py", + """ + def greet(): + return "hello" + """, + ) + write_py( + source_dir / "app.py", + """ + import importlib + import pkg.helpers + + print(pkg.helpers.greet(), importlib.import_module("pkg.helpers").greet()) + """, + ) + + result = run_cli("package", str(source_dir), "-o", str(output_dir)) + + assert result.returncode == 0, result.stderr + execution = run_python("import app", pythonpath=output_dir, cwd=tmp_path) + assert execution.returncode == 0, execution.stderr + assert execution.stdout == "hello hello\n" + + +def test_cli_bundle_mode_supports_complex_package_graphs(tmp_path): + source_dir = 
tmp_path / "src" + bundle_path = tmp_path / "bundle.py" + pkg_dir = source_dir / "pkg" + pkg_dir.mkdir(parents=True) + + write_py( + pkg_dir / "__init__.py", + """ + EVENTS = ["pkg"] + + from .shared import register + from .helpers import * + + register(EVENTS) + __all__ = ["EVENTS", "greet"] + """, + ) + write_py( + pkg_dir / "shared.py", + """ + def register(events): + events.append("shared") + + def label(): + return "hello" + """, + ) + write_py( + pkg_dir / "helpers.py", + """ + from .shared import label + + def greet(): + return label() + """, + ) + write_py( + source_dir / "app.py", + """ + from pkg import * + import importlib + import pkg.helpers + + print(",".join(EVENTS), greet(), pkg.helpers.greet(), importlib.import_module("pkg.shared").label()) + """, + ) + + result = run_cli("bundle", str(source_dir), "-o", str(bundle_path)) + + assert result.returncode == 0, result.stderr + execution = run_python_file(bundle_path) + assert execution.returncode == 0, execution.stderr + assert execution.stdout == "pkg,shared hello hello hello\n" From 9988966b0d7268a72370919840ac60d597f07780 Mon Sep 17 00:00:00 2001 From: Alvin Wan Date: Sat, 4 Apr 2026 03:35:10 -0700 Subject: [PATCH 2/2] address bundle review feedback --- pymini/pymini.py | 106 +++++++++++++++++++++++++++++++++++++++++++--- tests/test_api.py | 61 ++++++++++++++++++++++++++ 2 files changed, 161 insertions(+), 6 deletions(-) diff --git a/pymini/pymini.py b/pymini/pymini.py index d8f26a5..457444a 100644 --- a/pymini/pymini.py +++ b/pymini/pymini.py @@ -169,6 +169,7 @@ def __init__(self, generator, mapping=None, modules=(), keep_global_variables=Fa self.nodes_to_insert = [] self.nodes_to_append = [] self.public_global_names = set() + self.scope_stack = [] # TODO: cleanup self.str_name_to_node = {} self.str_mapping = {} @@ -190,6 +191,72 @@ def _append_public_alias(self, old_name, new_name): if old_name != new_name: self.nodes_to_append.append(ast.parse(f"{old_name} = {new_name}").body[0]) + def 
_binding_names_from_target(self, target): + names = set() + if isinstance(target, ast.Name): + names.add(target.id) + elif isinstance(target, (ast.Tuple, ast.List)): + for element in target.elts: + names.update(self._binding_names_from_target(element)) + return names + + def _scope_bindings(self, node): + bindings = set() + globals_ = set() + + class ScopeBindingCollector(ast.NodeVisitor): + def visit_Global(self, inner): + globals_.update(inner.names) + + def visit_arg(self, inner): + bindings.add(inner.arg) + + def visit_Name(self, inner): + if isinstance(inner.ctx, ast.Store): + bindings.add(inner.id) + + def visit_FunctionDef(self, inner): + bindings.add(inner.name) + + visit_AsyncFunctionDef = visit_FunctionDef + + def visit_ClassDef(self, inner): + bindings.add(inner.name) + + def visit_Lambda(self, inner): + return None + + def visit_ListComp(self, inner): + return None + + def visit_SetComp(self, inner): + return None + + def visit_DictComp(self, inner): + return None + + def visit_GeneratorExp(self, inner): + return None + + collector = ScopeBindingCollector() + # Parameters live on node.args, not node.body, so record them explicitly. + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + collector.visit(node.args) + for statement in getattr(node, "body", []): + collector.visit(statement) + bindings.difference_update(globals_) + return {"bindings": bindings, "globals": globals_, "is_class": isinstance(node, ast.ClassDef)} + + def _is_preserved_public_global_reference(self, name): + if name not in self.public_global_names: + return False + for depth, scope in enumerate(reversed(self.scope_stack)): + if depth and scope["is_class"]: + continue # class bodies are invisible to the scopes nested inside them + if name in scope["globals"] or name in scope["bindings"]: + return name in scope["globals"] # `global` wins; a local binding shadows + return True + def _visit_ImportOrImportFrom(self, node): """Shorten imported library names.
@@ -247,10 +314,18 @@ def visit_ClassDef(self, node): old_name = node.name node.name = self._rename_identifier(old_name) self._append_public_alias(old_name, node.name) - return self.generic_visit(node) + self.scope_stack.append(self._scope_bindings(node)) + try: + return self.generic_visit(node) + finally: + self.scope_stack.pop() if node.name not in self.mapping.values(): # TODO: make .values() more efficient self.mapping[node.name] = node.name = next(self.generator) - return self.generic_visit(node) + self.scope_stack.append(self._scope_bindings(node)) + try: + return self.generic_visit(node) + finally: + self.scope_stack.pop() def visit_FunctionDef(self, node): """Shorten function and argument names. @@ -277,10 +352,18 @@ def visit_FunctionDef(self, node): old_name = node.name node.name = self._rename_identifier(old_name) self._append_public_alias(old_name, node.name) - return self.generic_visit(node) + self.scope_stack.append(self._scope_bindings(node)) + try: + return self.generic_visit(node) + finally: + self.scope_stack.pop() if node.name not in self.mapping.values(): # TODO: need to dedup this logic self.mapping[node.name] = node.name = next(self.generator) - return self.generic_visit(node) + self.scope_stack.append(self._scope_bindings(node)) + try: + return self.generic_visit(node) + finally: + self.scope_stack.pop() visit_AsyncFunctionDef = visit_FunctionDef @@ -338,7 +421,7 @@ def visit_Name(self, node): """ if node.id in self.mapping.values(): # TODO: make .values() more efficient return node - if self.keep_global_variables and node.id in self.public_global_names: + if self.keep_global_variables and self._is_preserved_public_global_reference(node.id): return self.generic_visit(node) if self.keep_global_variables and self._is_node_global(node): if node.id in self.mapping: @@ -656,8 +739,19 @@ class _F(_a.MetaPathFinder): def find_spec(self,fullname,path=None,target=None): if fullname not in _M:return None return 
_u.spec_from_loader(fullname,_L(fullname),is_package=fullname in _P) +def _R(name,run_name): + spec=_u.spec_from_loader(run_name,_L(name),is_package=name in _P) + module=_u.module_from_spec(spec) + module.__name__=run_name + module.__package__=name if name in _P else name.rpartition('.')[0] + if name in _P:module.__path__=[] + _s.modules[run_name]=module + _s.modules.setdefault(name,module) + _L(name).exec_module(module) + return module _s.meta_path.insert(0,_F()) -for _m in {entry_modules!r}:__import__(_m) +if {entry_modules!r}:_R({entry_modules!r}[0],'__main__') +for _m in {entry_modules!r}[1:]:__import__(_m) """ return bundle_runtime.strip() + "\n" diff --git a/tests/test_api.py b/tests/test_api.py index 9c1c351..d91d097 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -102,6 +102,39 @@ def abs_path(path): assert simplified_return.value.value == 1 assert modules == ["main"] + +def test_minify_preserves_global_names_without_breaking_shadowed_locals(tmp_path): + cleaned, modules = minify( + py( + """ + x = 1 + + def f(): + x = 2 + return x + + print(f(), x) + """ + ), + "main", + keep_global_variables=True, + keep_module_names=True, + ) + + module_path = tmp_path / "module.py" + module_path.write_text(cleaned[0], encoding="utf-8") + result = subprocess.run( + [sys.executable, str(module_path)], + capture_output=True, + text=True, + check=False, + ) + + assert result.returncode == 0, result.stderr + assert result.stdout == "2 1\n" + assert modules == ["main"] + + def test_minify_updates_cross_file_imports(): cleaned, modules = minify( [ @@ -189,3 +222,31 @@ def square(x): assert result.returncode == 0, result.stderr assert result.stdout == "9\n" assert modules == ["bundle"] + + +def test_minify_bundle_runs_entry_module_as_main(tmp_path): + cleaned, modules = minify( + py( + """ + if __name__ == "__main__": + print("ran") + """ + ), + "main", + keep_global_variables=True, + keep_module_names=True, + output_single_file=True, + ) + + bundle_path = tmp_path / 
"bundle.py" + bundle_path.write_text(cleaned[0], encoding="utf-8") + result = subprocess.run( + [sys.executable, str(bundle_path)], + capture_output=True, + text=True, + check=False, + ) + + assert result.returncode == 0, result.stderr + assert result.stdout == "ran\n" + assert modules == ["bundle"]