Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 53 additions & 10 deletions pymini/pymini.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__(self):
self.reserved_names = set()
self.bindings = set()
self.loads = set()
self.external_bindings = set()

def visit_Name(self, node):
self.reserved_names.add(node.id)
Expand All @@ -120,6 +121,7 @@ def visit_arg(self, node):

def visit_Global(self, node):
    """Record the names declared by a ``global`` statement.

    Every declared name is added both to ``self.reserved_names`` and to
    ``self.external_bindings``. Nothing is returned.
    """
    declared = node.names
    for bucket in (self.reserved_names, self.external_bindings):
        bucket.update(declared)

visit_Nonlocal = visit_Global

Expand Down Expand Up @@ -245,7 +247,7 @@ def square(x):
0
"""
def visit_Expr(self, node):
if isinstance(node.value, ast.Constant):
if isinstance(node.value, ast.Constant) and isinstance(node.value.value, str):
if len(node.parent.body) == 1: # if body is just the comment
return ast.parse('0').body[0] # replace comment with 0
return None # otherwise, remove comment
Expand Down Expand Up @@ -641,12 +643,16 @@ def _is_in_function_signature(self, node):
def _scope_bindings(self, node):
bindings = set()
globals_ = set()
nonlocals_ = set()
args = set()

class ScopeBindingCollector(ast.NodeVisitor):
def visit_Global(self, inner):
globals_.update(inner.names)

def visit_Nonlocal(self, inner):
nonlocals_.update(inner.names)

def visit_arg(self, inner):
args.add(inner.arg)
bindings.add(inner.arg)
Expand Down Expand Up @@ -687,7 +693,7 @@ def visit_GeneratorExp(self, inner):
collector.visit(statement)
continue
collector.visit(statement)
bindings.difference_update(globals_)
bindings.difference_update(globals_ | nonlocals_)
return {"bindings": bindings, "globals": globals_, "args": args}

def _is_preserved_public_global_reference(self, name):
Expand Down Expand Up @@ -730,7 +736,8 @@ def _local_scope_state(self, node):
for statement in getattr(node, "body", []):
collector.visit(statement)
reserved_names = set(collector.reserved_names)
for name in collector.loads - collector.bindings:
local_bindings = collector.bindings - collector.external_bindings
for name in collector.reserved_names - local_bindings:
visible_name = self._lookup_visible_identifier(name)
if visible_name is not None:
reserved_names.add(visible_name)
Expand Down Expand Up @@ -873,7 +880,9 @@ def visit_ClassDef(self, node):
self._generated_assignment(f"{node.name}.__qualname__ = {old_name!r}"),
]
return node
if node.name not in self.mapping_values:
if self.local_rename_scopes and not self._is_node_global(node):
node.name = self._rename_local_identifier(node.name)
elif node.name not in self.mapping_values:
Comment on lines +883 to +885
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Prevent nested class renames from shadowing global class metadata

Renaming non-global ClassDefs via _rename_local_identifier can reuse a renamed top-level class name, but class method/attribute metadata is keyed only by class identifier (class_method_argument_infos / class_member_mappings). When a nested class later overwrites that key, method-call rewriting for the top-level class can use the wrong signature and emit invalid calls (e.g., unconverted keywords), causing runtime errors under rename_arguments=True.

Useful? React with 👍 / 👎.

node.name = self._rename_identifier(node.name)
if parent_class_context is not None and old_name != node.name:
parent_class_context["member_mapping"][old_name] = node.name
Expand Down Expand Up @@ -1010,7 +1019,9 @@ def visit_FunctionDef(self, node):
self._pop_instance_scope()
self.local_rename_scopes.pop()
self.scope_stack.pop()
if node.name not in self.mapping_values:
if self.local_rename_scopes and not self._is_node_global(node):
node.name = self._rename_local_identifier(node.name)
elif node.name not in self.mapping_values:
Comment on lines +1022 to +1024
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep nested function names unique across module

Using _rename_local_identifier for non-global FunctionDefs allows a nested function to reuse a top-level renamed name (for example both becoming a). Because call-signature metadata is stored in self.callable_argument_infos by identifier, the later nested definition overwrites the top-level entry, so subsequent keyword calls to the top-level function are no longer rewritten correctly and can fail at runtime with TypeError when rename_arguments=True.

Useful? React with 👍 / 👎.

node.name = self._rename_identifier(node.name)
self.scope_stack.append(self._scope_bindings(node))
self.local_rename_scopes.append(self._local_scope_state(node))
Expand Down Expand Up @@ -1138,6 +1149,17 @@ def visit_ExceptHandler(self, node):
node.body = [self.visit(statement) for statement in node.body]
return node

def visit_Global(self, node):
    """Rewrite the names of a ``global`` statement through ``self.mapping``.

    A name with no entry in ``self.mapping`` is kept as it is. The node is
    updated in place and returned.
    """
    renamed = []
    for original in node.names:
        renamed.append(self.mapping.get(original, original))
    node.names = renamed
    return node

def visit_Nonlocal(self, node):
    """Rewrite the names of a ``nonlocal`` statement.

    Each name is replaced by the result of
    ``self._lookup_local_identifier(name)``; when that result is falsy the
    original name is kept. The node is updated in place and returned.
    """
    resolved = []
    for original in node.names:
        replacement = self._lookup_local_identifier(original)
        resolved.append(replacement or original)
    node.names = resolved
    return node

def visit_Call(self, node):
"""Apply renamed function names."""
node = self.generic_visit(node)
Expand Down Expand Up @@ -1347,7 +1369,7 @@ def renamed_module_name(module):
# imported references using the exporter module's mapping.
fused_mapping = {
value: value
for value in self.module_to_shortener[module].mapping.values()
for value in _reserved_names_in_node(tree)
}

imported = ImportedVariableShortener(
Expand Down Expand Up @@ -1513,6 +1535,18 @@ def _scope_mapping(self, node):
if count > 1 and len(repr(value)) > 4 and self._is_profitable(value, count, scope_type)
}

def _assignment_insertion_index(self, body):
insert_at = 0
if body and ast.get_docstring(ast.Module(body=body, type_ignores=[])) is not None:
insert_at = 1
while (
insert_at < len(body)
and isinstance(body[insert_at], ast.ImportFrom)
and body[insert_at].module == "__future__"
):
insert_at += 1
return insert_at

def _prepend_assignments(self, body, mapping):
assignments = []
for value, name in mapping.items():
Expand All @@ -1522,7 +1556,7 @@ def _prepend_assignments(self, body, mapping):
)
assignment._pymini_generated = True
assignments.append(assignment)
insert_at = 1 if body and ast.get_docstring(ast.Module(body=body, type_ignores=[])) is not None else 0
insert_at = self._assignment_insertion_index(body)
return body[:insert_at] + assignments + body[insert_at:]

def _append_cleanup(self, body, mapping):
Expand Down Expand Up @@ -2233,15 +2267,24 @@ def remove_extraneous_whitespace(self, source: str) -> str:
"""
import tokenize
from io import StringIO
keywords = set(keyword.kwlist)
keywords.update(getattr(keyword, "softkwlist", ()))
lines = []
for line in source.splitlines():
generated = list(tokenize.generate_tokens(StringIO(line).readline))
if any(
tokenize.tok_name.get(token.type, "").startswith("FSTRING_")
for token in generated
):
lines.append(line)
continue
tokens = []
last_token = None
for token in tokenize.generate_tokens(StringIO(line).readline):
for token in generated:
token = token.string
if token in keyword.kwlist and tokens and not any(last_token.endswith(c) for c in ':;= '):
if token in keywords and tokens and not any(last_token.endswith(c) for c in ':;= '):
tokens.append(token)
elif tokens and (last_token not in keyword.kwlist or token in ':;='):
elif tokens and (last_token not in keywords or token in ':;='):
tokens[-1] += token
else:
tokens.append(token)
Expand Down
Loading
Loading