From b148e4392633ca5ca5c4447ac1ee286e8729da2c Mon Sep 17 00:00:00 2001
From: Alvin Wan
Date: Sat, 4 Apr 2026 03:20:14 -0700
Subject: [PATCH 1/2] support complex package graphs
---
README.md | 2 +-
pymini/pymini.py | 182 +++++++++++++++++++++++++++++++++-------------
tests/test_api.py | 28 +++----
tests/test_cli.py | 173 ++++++++++++++++++++++++++++++++++++++-----
4 files changed, 304 insertions(+), 81 deletions(-)
diff --git a/README.md b/README.md
index f1f2528..be1a618 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
## Status
-This project is maintained as an AST-based minifier for Python 3.9+ code. It is best suited to scripts and small module graphs that use straightforward imports such as `from module import name`.
+This project is maintained as an AST-based minifier for Python 3.9+ code. It is best suited to scripts and small-to-medium package graphs. Package mode preserves package layout and now covers relative imports, dotted imports, star imports, package re-exports, and `importlib`-based internal imports; bundle mode emits a self-contained loader-backed single file for the same kinds of graphs.
## Installation
diff --git a/pymini/pymini.py b/pymini/pymini.py
index f5086ed..d8f26a5 100644
--- a/pymini/pymini.py
+++ b/pymini/pymini.py
@@ -1,7 +1,6 @@
import ast
import copy
import keyword
-from graphlib import TopologicalSorter
from typing import Dict, List, Optional, Set
from .utils import variable_name_generator
@@ -169,6 +168,7 @@ def __init__(self, generator, mapping=None, modules=(), keep_global_variables=Fa
self.name_to_node = {}
self.nodes_to_insert = []
self.nodes_to_append = []
+ self.public_global_names = set()
# TODO: cleanup
self.str_name_to_node = {}
self.str_mapping = {}
@@ -296,7 +296,11 @@ def visit_Assign(self, node):
'demiurgic = 1\\nholy = demiurgic'
"""
if self.keep_global_variables and self._is_node_global(node): # TODO: rename but insert var def if worth it
- return self.generic_visit(node)
+ for target in node.targets:
+ if isinstance(target, ast.Name):
+ self.public_global_names.add(target.id)
+ node.value = self.visit(node.value)
+ return node
for target in node.targets:
if isinstance(target, ast.Name) and target.id not in self.mapping.values(): # TODO: make .values() more efficient
self.mapping[target.id] = target.id = next(self.generator)
@@ -334,6 +338,8 @@ def visit_Name(self, node):
"""
if node.id in self.mapping.values(): # TODO: make .values() more efficient
return node
+ if self.keep_global_variables and node.id in self.public_global_names:
+ return self.generic_visit(node)
if self.keep_global_variables and self._is_node_global(node):
if node.id in self.mapping:
node.id = self.mapping[node.id]
@@ -435,6 +441,7 @@ def __init__(self, generator, modules, module_to_shortener, keep_module_names=Fa
def transform(self, *trees):
original_modules = list(self.module_to_shortener)
+ packages = package_modules(original_modules)
module_to_module = {}
if not self.keep_module_names:
module_to_module = {module: next(self.generator) for module in original_modules}
@@ -455,9 +462,11 @@ def transform(self, *trees):
imported = ImportedVariableShortener(
self.generator,
mapping=fused_mapping,
+ current_module=module,
keep_global_variables=True,
module_to_module={_module: value for _module, value in module_to_module.items() if module != _module},
module_to_shortener={_module: value for _module, value in self.module_to_shortener.items() if module != _module},
+ packages=packages,
)
new_trees.extend(imported.transform(tree))
return new_trees
@@ -475,20 +484,25 @@ class ImportedVariableShortener(VariableShortener):
>>> apply('from silly import demiurgic, dontreplaceme; print(demiurgic)')
'from silly import a, dontreplaceme\\nprint(a)'
"""
- def __init__(self, *args, module_to_shortener=None, module_to_module=None, **kwargs):
+ def __init__(self, *args, current_module=None, module_to_shortener=None, module_to_module=None, packages=(), **kwargs):
super().__init__(*args, **kwargs)
+ self.current_module = current_module
self.module_to_shortener = module_to_shortener or {}
self.module_to_module = module_to_module or {}
+ self.packages = set(packages)
def visit_ImportFrom(self, node):
"""Apply shortener for imported module."""
- shortener = self.module_to_shortener.get(node.module, None)
+ module_name = resolve_import_from(self.current_module, node, self.packages)
+ shortener = self.module_to_shortener.get(module_name, None)
if shortener is not None:
for alias in node.names:
+ if alias.name == "*":
+ continue
if alias.name in shortener.mapping:
self.mapping[alias.name] = alias.name = shortener.mapping[alias.name]
- if node.module in self.module_to_module: # TODO: handle nested modules
- node.module = self.module_to_module[node.module]
+ if node.level == 0 and module_name in self.module_to_module:
+ node.module = self.module_to_module[module_name]
return self.generic_visit(node)
@@ -507,56 +521,35 @@ class FileFuser(Fuser):
Determine dependency between files by checking import statements. After
linearizing dependencies, combine files in that order.
"""
- def _dependencies_for_tree(self, tree, modules):
- dependencies = set()
+ def _dependencies_for_tree(self, module, tree, modules):
+ dependencies = ancestor_package_modules(module, modules)
+ packages = package_modules(modules)
for node in ast.walk(tree):
- if isinstance(node, ast.ImportFrom) and node.level == 0 and node.module in modules:
- dependencies.add(node.module)
+ if isinstance(node, ast.ImportFrom):
+ target_module = resolve_import_from(module, node, packages)
+ dependencies.update(internal_module_dependencies(target_module, modules))
+ for alias in node.names:
+ if alias.name != "*":
+ dependencies.update(
+ internal_module_dependencies(f"{target_module}.{alias.name}", modules)
+ )
elif isinstance(node, ast.Import):
for alias in node.names:
- module = alias.name.split('.')[0]
- if module in modules:
- dependencies.add(module)
+ dependencies.update(internal_module_dependencies(alias.name, modules))
return dependencies
- def _module_order(self, module_to_tree):
- sorter = TopologicalSorter()
- modules = set(module_to_tree)
- for module, tree in module_to_tree.items():
- dependencies = self._dependencies_for_tree(tree, modules - {module})
- sorter.add(module, *dependencies)
- return list(sorter.static_order())
-
- def _strip_internal_imports(self, tree, modules):
- filtered_body = []
- for statement in tree.body:
- if isinstance(statement, ast.ImportFrom) and statement.level == 0 and statement.module in modules:
- continue
- if isinstance(statement, ast.Import):
- statement.names = [
- alias for alias in statement.names
- if alias.name.split('.')[0] not in modules
- ]
- if not statement.names:
- continue
- filtered_body.append(statement)
- tree.body = filtered_body
- return tree
-
def transform(self, *trees):
module_to_tree = dict(zip(self.modules, trees))
- module_order = self._module_order(module_to_tree)
- internal_modules = set(module_order)
-
- combined_body = []
- for module in module_order:
- tree = self._strip_internal_imports(module_to_tree[module], internal_modules)
- combined_body.extend(tree.body)
-
- root = module_to_tree[module_order[0]]
- root.body = combined_body
- ast.fix_missing_locations(root)
- return [root]
+ modules = set(module_to_tree)
+ dependency_map = {
+ module: self._dependencies_for_tree(module, tree, modules - {module})
+ for module, tree in module_to_tree.items()
+ }
+ self.entry_modules = [
+ module for module in self.modules
+ if all(module not in dependencies for dependencies in dependency_map.values())
+ ] or list(self.modules)
+ return [module_to_tree[module] for module in self.modules]
def define_custom_variables(tree, mapping):
@@ -580,6 +573,95 @@ def transform(self, *trees):
yield ast.unparse(tree)
+def module_prefixes(module: Optional[str]) -> List[str]:
+    """Return every dotted prefix of ``module``, shortest first.
+
+    ``"a.b.c"`` yields ``["a", "a.b", "a.b.c"]``. A falsy ``module``
+    (``None`` or ``""``) yields an empty list.
+    """
+    if not module:
+        return []
+    prefixes: List[str] = []
+    for part in module.split("."):
+        # Grow the previous prefix by one component at a time.
+        prefixes.append(f"{prefixes[-1]}.{part}" if prefixes else part)
+    return prefixes
+
+
+def package_modules(modules) -> Set[str]:
+    """Return the names that act as packages within ``modules``.
+
+    A name is a package when it is a proper dotted prefix of at least one
+    entry in ``modules``; the prefix does not have to be listed itself.
+
+    Args:
+        modules: Iterable of dotted module names.
+
+    Returns:
+        The set of package names implied by ``modules``.
+    """
+    packages: Set[str] = set()
+    for module in set(modules):
+        # Every proper prefix of a dotted name is a package. This already
+        # covers "module X has a submodule X.y" (X is a proper prefix of
+        # X.y), so no pairwise scan over the module list is needed.
+        packages.update(module_prefixes(module)[:-1])
+    return packages
+
+
+def ancestor_package_modules(module: str, modules) -> Set[str]:
+    """Return the proper dotted prefixes of ``module`` that appear in ``modules``."""
+    known = set(modules)
+    # Drop the last prefix: it is ``module`` itself, not an ancestor.
+    ancestors = module_prefixes(module)[:-1]
+    return known.intersection(ancestors)
+
+
+def module_package_name(module: Optional[str], packages: Set[str]) -> str:
+    """Return the package that ``module`` belongs to.
+
+    A module listed in ``packages`` is its own package. Any other module
+    belongs to its parent (everything before the last dot), or to ``""``
+    when its name has no dot. A falsy ``module`` also gives ``""``.
+    """
+    if not module:
+        return ""
+    if module in packages:
+        return module
+    # rpartition yields an empty head when the name contains no dot.
+    parent, _, _ = module.rpartition(".")
+    return parent
+
+
+def resolve_import_from(current_module: Optional[str], node: ast.ImportFrom, packages: Set[str]) -> Optional[str]:
+    """Resolve the module named by a ``from ... import`` statement.
+
+    Absolute imports (``node.level == 0``) return ``node.module`` unchanged.
+    Relative imports are resolved against the package of ``current_module``
+    as reported by ``module_package_name``: each dot beyond the first drops
+    one trailing package component before ``node.module`` is appended.
+
+    Returns:
+        The dotted module name, ``node.module`` as written when the import
+        climbs further than the package depth allows, or ``""`` when
+        nothing remains after resolution.
+    """
+    level = node.level
+    if level == 0:
+        return node.module
+
+    package_name = module_package_name(current_module, packages)
+    anchor = package_name.split(".") if package_name else []
+    keep = len(anchor) - level + 1
+    if keep < 0:
+        # More leading dots than the package depth can absorb.
+        # NOTE(review): ``keep == 0`` is allowed and resolves against the
+        # top level, which CPython itself would reject -- confirm intended.
+        return node.module
+
+    resolved = anchor[:keep]
+    if node.module:
+        resolved += node.module.split(".")
+    return ".".join(filter(None, resolved))
+
+
+def internal_module_dependencies(module: Optional[str], modules) -> Set[str]:
+    """Return ``module`` and its dotted ancestors that are listed in ``modules``.
+
+    A falsy ``module`` has no prefixes and therefore yields an empty set.
+    """
+    known = set(modules)
+    return known.intersection(module_prefixes(module))
+
+
+def bundle_sources(sources: List[str], modules: List[str], entry_modules: Optional[List[str]] = None) -> str:
+ source_map = {module: source for module, source in zip(modules, sources)}
+ package_names = package_modules(source_map)
+ for package_name in sorted(package_names):
+ source_map.setdefault(package_name, "")
+ if not entry_modules:
+ entry_modules = list(modules)
+
+ bundle_runtime = f"""
+import importlib.abc as _a
+import importlib.util as _u
+import sys as _s
+_M={source_map!r}
+_P={sorted(package_names)!r}
+class _L(_a.Loader):
+ def __init__(self,n):self.n=n
+ def create_module(self,spec):return None
+ def exec_module(self,module):
+ if self.n in _P:module.__path__=[]
+ exec(_M[self.n],module.__dict__)
+class _F(_a.MetaPathFinder):
+ def find_spec(self,fullname,path=None,target=None):
+ if fullname not in _M:return None
+ return _u.spec_from_loader(fullname,_L(fullname),is_package=fullname in _P)
+_s.meta_path.insert(0,_F())
+for _m in {entry_modules!r}:__import__(_m)
+"""
+ return bundle_runtime.strip() + "\n"
+
+
class WhitespaceRemover(NodeTransformer):
"""Remove all whitespace.
@@ -856,6 +938,8 @@ def minify(sources, modules='main', keep_module_names=False,
WhitespaceRemover(),
)
cleaned = list(pipeline.transform(*trees))
+ if output_single_file:
+ cleaned = [bundle_sources(cleaned, fuser.modules, getattr(fuser, "entry_modules", None))]
output_modules = [single_file_module] if output_single_file else fuser.modules
return cleaned, output_modules
diff --git a/tests/test_api.py b/tests/test_api.py
index 787e476..9c1c351 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -1,4 +1,6 @@
import ast
+import subprocess
+import sys
from textwrap import dedent
from pymini import minify
@@ -50,18 +52,6 @@ def assert_bundle_preserves_public_alias(bundle_source: str) -> None:
assert call.args[0].func.id == function.name
-def assert_bundle_is_shortened(bundle_source: str) -> None:
- bundle_tree = ast.parse(bundle_source)
- function, printer = bundle_tree.body
-
- assert isinstance(function, ast.FunctionDef)
- assert function.name != "square"
- assert len(function.name) == 1
-
- call = printer.value
- assert call.args[0].func.id == function.name
-
-
def test_minify_simplifies_returns():
cleaned, modules = minify(
py(
@@ -166,7 +156,7 @@ def square(x):
assert modules == ["main", "side"]
-def test_minify_fuses_files_into_single_module():
+def test_minify_fuses_files_into_single_module(tmp_path):
cleaned, modules = minify(
[
py(
@@ -187,5 +177,15 @@ def square(x):
output_single_file=True,
)
- assert_bundle_is_shortened(cleaned[0])
+ bundle_path = tmp_path / "bundle.py"
+ bundle_path.write_text(cleaned[0], encoding="utf-8")
+ result = subprocess.run(
+ [sys.executable, str(bundle_path)],
+ capture_output=True,
+ text=True,
+ check=False,
+ )
+
+ assert result.returncode == 0, result.stderr
+ assert result.stdout == "9\n"
assert modules == ["bundle"]
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 52c8c45..cae39a1 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -1,8 +1,10 @@
import ast
+import os
import subprocess
import sys
from pathlib import Path
from textwrap import dedent
+from typing import Optional
PROJECT_ROOT = Path(__file__).resolve().parents[1]
@@ -18,6 +20,31 @@ def run_cli(*args: str) -> subprocess.CompletedProcess[str]:
)
+def run_python(code: str, *, pythonpath: Optional[Path] = None, cwd: Optional[Path] = None) -> subprocess.CompletedProcess[str]:
+    """Run ``code`` via ``python -c`` in a child interpreter and capture its output.
+
+    Args:
+        code: Source text passed to ``-c``.
+        pythonpath: Optional directory placed in front of any inherited
+            ``PYTHONPATH`` for the child process.
+        cwd: Working directory for the child; falls back to ``PROJECT_ROOT``.
+
+    Returns:
+        The completed process. A non-zero exit status does not raise.
+    """
+    child_env = dict(os.environ)
+    if pythonpath is not None:
+        search_path = [str(pythonpath)]
+        inherited = child_env.get("PYTHONPATH")
+        if inherited:
+            search_path.append(inherited)
+        child_env["PYTHONPATH"] = os.pathsep.join(search_path)
+    command = [sys.executable, "-c", code]
+    return subprocess.run(
+        command,
+        cwd=cwd or PROJECT_ROOT,
+        env=child_env,
+        capture_output=True,
+        text=True,
+        check=False,
+    )
+
+
+def run_python_file(path: Path) -> subprocess.CompletedProcess[str]:
+    """Execute the script at ``path`` from its own directory and capture its output.
+
+    A non-zero exit status does not raise; callers inspect ``returncode``.
+    """
+    command = [sys.executable, str(path)]
+    options = dict(cwd=path.parent, capture_output=True, text=True, check=False)
+    return subprocess.run(command, **options)
+
+
def py(source: str) -> str:
return dedent(source).strip() + "\n"
@@ -52,22 +79,6 @@ def assert_public_api_is_preserved(module_source: str, consumer_source: str) ->
assert call.args[1].func.id == function.name
-def assert_bundle_preserves_public_alias(bundle_source: str) -> None:
- bundle_tree = ast.parse(bundle_source)
- function, alias, printer = bundle_tree.body
-
- assert isinstance(function, ast.FunctionDef)
- assert function.name != "square"
- assert len(function.name) == 1
-
- assert isinstance(alias, ast.Assign)
- assert alias.targets[0].id == "square"
- assert alias.value.id == function.name
-
- call = printer.value
- assert call.args[0].func.id == function.name
-
-
def test_cli_accepts_directories(tmp_path):
source_dir = tmp_path / "src"
output_dir = tmp_path / "out"
@@ -127,7 +138,9 @@ def square(x):
result = run_cli("bundle", str(source_dir), "-o", str(bundle_path))
assert result.returncode == 0, result.stderr
- assert_bundle_preserves_public_alias(bundle_path.read_text(encoding="utf-8"))
+ execution = run_python_file(bundle_path)
+ assert execution.returncode == 0, execution.stderr
+ assert execution.stdout == "9\n"
def test_cli_preserves_nested_package_paths(tmp_path):
@@ -214,3 +227,129 @@ def test_cli_can_aggressively_rename_globals_in_package_mode(tmp_path):
assert isinstance(assignment, ast.Assign)
assert assignment.targets[0].id != "public_name"
assert len(assignment.targets[0].id) == 1
+
+
+def test_cli_package_mode_supports_relative_star_reexports(tmp_path):
+ source_dir = tmp_path / "src"
+ output_dir = tmp_path / "out"
+ pkg_dir = source_dir / "pkg"
+ pkg_dir.mkdir(parents=True)
+
+ write_py(
+ pkg_dir / "__init__.py",
+ """
+ from .helpers import *
+
+ __all__ = ["greet"]
+ """,
+ )
+ write_py(
+ pkg_dir / "helpers.py",
+ """
+ def greet():
+ return "hello"
+ """,
+ )
+ write_py(
+ source_dir / "app.py",
+ """
+ from pkg import greet
+
+ print(greet())
+ """,
+ )
+
+ result = run_cli("package", str(source_dir), "-o", str(output_dir))
+
+ assert result.returncode == 0, result.stderr
+ execution = run_python("import app", pythonpath=output_dir, cwd=tmp_path)
+ assert execution.returncode == 0, execution.stderr
+ assert execution.stdout == "hello\n"
+
+
+def test_cli_package_mode_supports_dotted_and_dynamic_imports(tmp_path):
+ source_dir = tmp_path / "src"
+ output_dir = tmp_path / "out"
+ pkg_dir = source_dir / "pkg"
+ pkg_dir.mkdir(parents=True)
+
+ write_py(pkg_dir / "__init__.py", "VALUE = 1")
+ write_py(
+ pkg_dir / "helpers.py",
+ """
+ def greet():
+ return "hello"
+ """,
+ )
+ write_py(
+ source_dir / "app.py",
+ """
+ import importlib
+ import pkg.helpers
+
+ print(pkg.helpers.greet(), importlib.import_module("pkg.helpers").greet())
+ """,
+ )
+
+ result = run_cli("package", str(source_dir), "-o", str(output_dir))
+
+ assert result.returncode == 0, result.stderr
+ execution = run_python("import app", pythonpath=output_dir, cwd=tmp_path)
+ assert execution.returncode == 0, execution.stderr
+ assert execution.stdout == "hello hello\n"
+
+
+def test_cli_bundle_mode_supports_complex_package_graphs(tmp_path):
+ source_dir = tmp_path / "src"
+ bundle_path = tmp_path / "bundle.py"
+ pkg_dir = source_dir / "pkg"
+ pkg_dir.mkdir(parents=True)
+
+ write_py(
+ pkg_dir / "__init__.py",
+ """
+ EVENTS = ["pkg"]
+
+ from .shared import register
+ from .helpers import *
+
+ register(EVENTS)
+ __all__ = ["EVENTS", "greet"]
+ """,
+ )
+ write_py(
+ pkg_dir / "shared.py",
+ """
+ def register(events):
+ events.append("shared")
+
+ def label():
+ return "hello"
+ """,
+ )
+ write_py(
+ pkg_dir / "helpers.py",
+ """
+ from .shared import label
+
+ def greet():
+ return label()
+ """,
+ )
+ write_py(
+ source_dir / "app.py",
+ """
+ from pkg import *
+ import importlib
+ import pkg.helpers
+
+ print(",".join(EVENTS), greet(), pkg.helpers.greet(), importlib.import_module("pkg.shared").label())
+ """,
+ )
+
+ result = run_cli("bundle", str(source_dir), "-o", str(bundle_path))
+
+ assert result.returncode == 0, result.stderr
+ execution = run_python_file(bundle_path)
+ assert execution.returncode == 0, execution.stderr
+ assert execution.stdout == "pkg,shared hello hello hello\n"
From 9988966b0d7268a72370919840ac60d597f07780 Mon Sep 17 00:00:00 2001
From: Alvin Wan
Date: Sat, 4 Apr 2026 03:35:10 -0700
Subject: [PATCH 2/2] address bundle review feedback
---
pymini/pymini.py | 106 +++++++++++++++++++++++++++++++++++++++++++---
tests/test_api.py | 61 ++++++++++++++++++++++++++
2 files changed, 161 insertions(+), 6 deletions(-)
diff --git a/pymini/pymini.py b/pymini/pymini.py
index d8f26a5..457444a 100644
--- a/pymini/pymini.py
+++ b/pymini/pymini.py
@@ -169,6 +169,7 @@ def __init__(self, generator, mapping=None, modules=(), keep_global_variables=Fa
self.nodes_to_insert = []
self.nodes_to_append = []
self.public_global_names = set()
+ self.scope_stack = []
# TODO: cleanup
self.str_name_to_node = {}
self.str_mapping = {}
@@ -190,6 +191,72 @@ def _append_public_alias(self, old_name, new_name):
if old_name != new_name:
self.nodes_to_append.append(ast.parse(f"{old_name} = {new_name}").body[0])
+    def _binding_names_from_target(self, target):
+        """Return the set of names bound by an assignment ``target``.
+
+        Handles plain names, tuple/list unpacking (recursively) and starred
+        elements such as ``*rest``. Any other target kind (attribute,
+        subscript, ...) contributes no names.
+        """
+        if isinstance(target, ast.Name):
+            return {target.id}
+        if isinstance(target, ast.Starred):
+            # ``a, *rest = xs`` wraps the bound name in a Starred node.
+            return self._binding_names_from_target(target.value)
+        names = set()
+        if isinstance(target, (ast.Tuple, ast.List)):
+            for element in target.elts:
+                names.update(self._binding_names_from_target(element))
+        return names
+
+    def _scope_bindings(self, node):
+        """Collect the names bound directly in the scope opened by ``node``.
+
+        ``node`` is a function or class definition. Function parameters and
+        names stored in the body count as bindings; nested ``def``/``class``
+        statements contribute only their own name, and lambdas and
+        comprehensions are skipped entirely.
+
+        Returns:
+            A dict with two sets: ``"bindings"`` (local names, excluding
+            anything declared ``global``) and ``"globals"`` (names declared
+            ``global`` in this scope).
+        """
+        bindings = set()
+        globals_ = set()
+
+        class ScopeBindingCollector(ast.NodeVisitor):
+            def visit_Global(self, inner):
+                globals_.update(inner.names)
+
+            def visit_arg(self, inner):
+                bindings.add(inner.arg)
+
+            def visit_Name(self, inner):
+                if isinstance(inner.ctx, ast.Store):
+                    bindings.add(inner.id)
+
+            def visit_FunctionDef(self, inner):
+                # Record the name only; the body is a nested scope.
+                bindings.add(inner.name)
+
+            visit_AsyncFunctionDef = visit_ClassDef = visit_FunctionDef
+
+            def visit_Lambda(self, inner):
+                # Not descended into: nothing inside is recorded.
+                return None
+
+            visit_ListComp = visit_SetComp = visit_DictComp = visit_GeneratorExp = visit_Lambda
+
+        collector = ScopeBindingCollector()
+        arguments = getattr(node, "args", None)
+        if isinstance(arguments, ast.arguments):
+            # Parameters are local bindings as well. Visiting only the body
+            # would miss them, and a parameter that shadows a preserved
+            # global would then be mistaken for that global.
+            parameters = [*arguments.posonlyargs, *arguments.args, *arguments.kwonlyargs]
+            parameters.extend(
+                extra for extra in (arguments.vararg, arguments.kwarg) if extra is not None
+            )
+            for parameter in parameters:
+                collector.visit(parameter)
+        for statement in getattr(node, "body", []):
+            collector.visit(statement)
+        bindings.difference_update(globals_)
+        return {"bindings": bindings, "globals": globals_}
+
+    def _is_preserved_public_global_reference(self, name):
+        """Return True when ``name`` here refers to a preserved module global.
+
+        ``name`` must be in ``self.public_global_names``. The scopes on
+        ``self.scope_stack`` are then searched innermost first: a ``global``
+        declaration means the name is the module-level one, while a local
+        binding means the global is shadowed.
+        """
+        if name not in self.public_global_names:
+            return False
+        for scope in reversed(self.scope_stack):
+            if name in scope["globals"]:
+                # ``global name`` binds straight to module level, so outer
+                # function scopes must not be consulted any further.
+                return True
+            if name in scope["bindings"]:
+                return False
+        return True
+
def _visit_ImportOrImportFrom(self, node):
"""Shorten imported library names.
@@ -247,10 +314,18 @@ def visit_ClassDef(self, node):
old_name = node.name
node.name = self._rename_identifier(old_name)
self._append_public_alias(old_name, node.name)
- return self.generic_visit(node)
+ self.scope_stack.append(self._scope_bindings(node))
+ try:
+ return self.generic_visit(node)
+ finally:
+ self.scope_stack.pop()
if node.name not in self.mapping.values(): # TODO: make .values() more efficient
self.mapping[node.name] = node.name = next(self.generator)
- return self.generic_visit(node)
+ self.scope_stack.append(self._scope_bindings(node))
+ try:
+ return self.generic_visit(node)
+ finally:
+ self.scope_stack.pop()
def visit_FunctionDef(self, node):
"""Shorten function and argument names.
@@ -277,10 +352,18 @@ def visit_FunctionDef(self, node):
old_name = node.name
node.name = self._rename_identifier(old_name)
self._append_public_alias(old_name, node.name)
- return self.generic_visit(node)
+ self.scope_stack.append(self._scope_bindings(node))
+ try:
+ return self.generic_visit(node)
+ finally:
+ self.scope_stack.pop()
if node.name not in self.mapping.values(): # TODO: need to dedup this logic
self.mapping[node.name] = node.name = next(self.generator)
- return self.generic_visit(node)
+ self.scope_stack.append(self._scope_bindings(node))
+ try:
+ return self.generic_visit(node)
+ finally:
+ self.scope_stack.pop()
visit_AsyncFunctionDef = visit_FunctionDef
@@ -338,7 +421,7 @@ def visit_Name(self, node):
"""
if node.id in self.mapping.values(): # TODO: make .values() more efficient
return node
- if self.keep_global_variables and node.id in self.public_global_names:
+ if self.keep_global_variables and self._is_preserved_public_global_reference(node.id):
return self.generic_visit(node)
if self.keep_global_variables and self._is_node_global(node):
if node.id in self.mapping:
@@ -656,8 +739,19 @@ class _F(_a.MetaPathFinder):
def find_spec(self,fullname,path=None,target=None):
if fullname not in _M:return None
return _u.spec_from_loader(fullname,_L(fullname),is_package=fullname in _P)
+def _R(name,run_name):
+ spec=_u.spec_from_loader(run_name,_L(name),is_package=name in _P)
+ module=_u.module_from_spec(spec)
+ module.__name__=run_name
+ module.__package__=name if name in _P else name.rpartition('.')[0]
+ if name in _P:module.__path__=[]
+ _s.modules[run_name]=module
+ _s.modules.setdefault(name,module)
+ _L(name).exec_module(module)
+ return module
_s.meta_path.insert(0,_F())
-for _m in {entry_modules!r}:__import__(_m)
+if {entry_modules!r}:_R({entry_modules!r}[0],'__main__')
+for _m in {entry_modules!r}[1:]:__import__(_m)
"""
return bundle_runtime.strip() + "\n"
diff --git a/tests/test_api.py b/tests/test_api.py
index 9c1c351..d91d097 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -102,6 +102,39 @@ def abs_path(path):
assert simplified_return.value.value == 1
assert modules == ["main"]
+
+def test_minify_preserves_global_names_without_breaking_shadowed_locals(tmp_path):
+ cleaned, modules = minify(
+ py(
+ """
+ x = 1
+
+ def f():
+ x = 2
+ return x
+
+ print(f(), x)
+ """
+ ),
+ "main",
+ keep_global_variables=True,
+ keep_module_names=True,
+ )
+
+ module_path = tmp_path / "module.py"
+ module_path.write_text(cleaned[0], encoding="utf-8")
+ result = subprocess.run(
+ [sys.executable, str(module_path)],
+ capture_output=True,
+ text=True,
+ check=False,
+ )
+
+ assert result.returncode == 0, result.stderr
+ assert result.stdout == "2 1\n"
+ assert modules == ["main"]
+
+
def test_minify_updates_cross_file_imports():
cleaned, modules = minify(
[
@@ -189,3 +222,31 @@ def square(x):
assert result.returncode == 0, result.stderr
assert result.stdout == "9\n"
assert modules == ["bundle"]
+
+
+def test_minify_bundle_runs_entry_module_as_main(tmp_path):
+ cleaned, modules = minify(
+ py(
+ """
+ if __name__ == "__main__":
+ print("ran")
+ """
+ ),
+ "main",
+ keep_global_variables=True,
+ keep_module_names=True,
+ output_single_file=True,
+ )
+
+ bundle_path = tmp_path / "bundle.py"
+ bundle_path.write_text(cleaned[0], encoding="utf-8")
+ result = subprocess.run(
+ [sys.executable, str(bundle_path)],
+ capture_output=True,
+ text=True,
+ check=False,
+ )
+
+ assert result.returncode == 0, result.stderr
+ assert result.stdout == "ran\n"
+ assert modules == ["bundle"]