Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 33 additions & 20 deletions pymini/pymini.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
import copy
import keyword
from graphlib import TopologicalSorter
from typing import Dict, List, Optional, Set
Expand Down Expand Up @@ -41,39 +42,51 @@ class ReturnSimplifier(NodeTransformer):

return (some code)

NOTE: unused_assignments (the ids of the Assign nodes to drop) must be
modified in-place, since the same set object is passed to
RemoveUnusedVariables at initialization. Can't return a new set.
"""
def __init__(self, *args, **kwargs):
    """Create the simplifier with an empty, shareable set of marked assignments."""
    super().__init__(*args, **kwargs)
    # Ids of Assign nodes that RemoveUnusedVariables should drop. This exact
    # set object is handed to RemoveUnusedVariables, so only mutate it.
    self.unused_assignments = set()
    # Strong references to the marked Assign nodes. id() values are only
    # unique among objects alive at the same time; keeping the nodes alive
    # guarantees no unrelated node can later reuse a recorded id.
    self._marked_assignments = []
    # One entry per enclosing function being visited: the names whose
    # assignment must be kept (see _names_unsafe_to_inline).
    self._unsafe_names_stack = []

@staticmethod
def _referenced_names(node: ast.AST) -> Set[str]:
    """Return the ids of every ``ast.Name`` found anywhere under ``node``."""
    return {child.id for child in ast.walk(node) if isinstance(child, ast.Name)}

def _names_unsafe_to_inline(self, function: ast.AST) -> Set[str]:
    """Return names in ``function`` whose assignment is observable after ``return``.

    Dropping ``x = value`` in front of ``return x`` is only safe when ``x`` is
    a plain local that nothing can read once the function returns. This
    conservatively collects every name that is:

    * declared ``global`` / ``nonlocal`` (the assignment escapes the frame),
    * mentioned inside a nested function, lambda, or class (a closure created
      earlier could read the variable after the return), or
    * mentioned inside a ``finally`` block (it runs after the ``return``).
    """
    nested_scopes = (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda, ast.ClassDef)
    unsafe = set()
    for node in ast.walk(function):
        if node is function:
            continue
        if isinstance(node, (ast.Global, ast.Nonlocal)):
            unsafe.update(node.names)
        elif isinstance(node, nested_scopes):
            unsafe.update(self._referenced_names(node))
        elif getattr(node, "finalbody", None):
            # ast.Try (and ast.TryStar on 3.11+) are the only nodes with finalbody.
            for statement in node.finalbody:
                unsafe.update(self._referenced_names(statement))
    return unsafe

def visit_FunctionDef(self, node):
    """Simplify a function body using that function's own unsafe-name set."""
    self._unsafe_names_stack.append(self._names_unsafe_to_inline(node))
    try:
        return self.generic_visit(node)
    finally:
        self._unsafe_names_stack.pop()

visit_AsyncFunctionDef = visit_FunctionDef

def _can_simplify_return(self, previous: ast.stmt, current: ast.stmt) -> bool:
    """Return True if ``previous; current`` is ``name = value; return name``
    and dropping the assignment to ``name`` cannot be observed."""
    if not (
        isinstance(previous, ast.Assign)
        and len(previous.targets) == 1
        and isinstance(previous.targets[0], ast.Name)
        and isinstance(current, ast.Return)
        and isinstance(current.value, ast.Name)
        and current.value.id == previous.targets[0].id
    ):
        return False
    unsafe = self._unsafe_names_stack[-1] if self._unsafe_names_stack else set()
    return current.value.id not in unsafe

def _simplify_body(self, body: List[ast.stmt]) -> List[ast.stmt]:
    """Fold ``name = value; return name`` pairs in ``body`` into ``return value``.

    The Return node is rewritten in place. The now-dead Assign is only
    recorded in ``unused_assignments``; RemoveUnusedVariables deletes it later.
    """
    for previous, current in zip(body, body[1:]):
        if self._can_simplify_return(previous, current):
            self.unused_assignments.add(id(previous))
            self._marked_assignments.append(previous)
            current.value = copy.deepcopy(previous.value)
    return body

def generic_visit(self, node):
    """Visit children first, then simplify every statement list on ``node``."""
    node = super().generic_visit(node)
    for field, value in ast.iter_fields(node):
        # Only lists made purely of statements are bodies (body/orelse/finalbody).
        if isinstance(value, list) and value and all(isinstance(item, ast.stmt) for item in value):
            setattr(node, field, self._simplify_body(value))
    return node


class RemoveUnusedVariables(NodeTransformer):
    """Drop the assignments that ReturnSimplifier folded into a ``return``.

    Only ``ast.Assign`` nodes whose ``id()`` is contained in
    ``unused_assignments`` are removed; every other node is visited normally.

    NOTE: the set is stored by reference, not copied, because the
    ReturnSimplifier that owns it fills it in-place after this transformer
    has already been constructed.
    """

    def __init__(self, unused_assignments: Set[int]):
        super().__init__()
        # Shared with ReturnSimplifier; intentionally not copied.
        self.unused_assignments = unused_assignments

    def visit_Assign(self, node: ast.Assign) -> Optional[ast.Assign]:
        # Under the ast.NodeTransformer contract, returning None deletes the node.
        should_drop = id(node) in self.unused_assignments
        return None if should_drop else self.generic_visit(node)

Expand Down Expand Up @@ -812,7 +825,7 @@ def minify(sources, modules='main', keep_module_names=False,

# simplify
simplifier := ReturnSimplifier(),
RemoveUnusedVariables(simplifier.unused_names),
RemoveUnusedVariables(simplifier.unused_assignments),

# minify
ParentSetter(),
Expand Down
32 changes: 32 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,38 @@ def f():
assert modules == ["main"]


def test_minify_does_not_crash_when_returning_parameter_names():
    """Returning a bare parameter must not crash the return simplifier.

    ``return path`` has no assignment in front of it, so it must survive
    (under the parameter's possibly-renamed name), while the trailing
    ``value = 1`` / ``return value`` pair is still folded into ``return 1``.
    """
    source = py(
        """
def abs_path(path):
    if path:
        return path

    value = 1
    return value
"""
    )
    cleaned, modules = minify(
        source,
        "main",
        keep_global_variables=True,
        keep_module_names=True,
    )

    tree = ast.parse(cleaned[0])
    function = next(node for node in tree.body if isinstance(node, ast.FunctionDef))
    guard, folded_return = function.body[0], function.body[1]

    assert isinstance(guard, ast.If)
    early_return = guard.body[0]
    assert isinstance(early_return, ast.Return)
    assert isinstance(early_return.value, ast.Name)
    assert early_return.value.id == function.args.args[0].arg

    assert isinstance(folded_return, ast.Return)
    assert isinstance(folded_return.value, ast.Constant)
    assert folded_return.value.value == 1
    assert modules == ["main"]

def test_minify_updates_cross_file_imports():
cleaned, modules = minify(
[
Expand Down
Loading