Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[run]
source = depyf
# omit patched file from pytorch
omit = depyf/explain/patched*

[report]
include = depyf/*
2 changes: 2 additions & 0 deletions .github/workflows/test_decompile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ jobs:
run: |
pytest --cov=depyf tests/test.py
coverage run --append python_coverage.py
coverage run --append tests/test_code_owner.py
coverage run --append tests/test_ensure.py
python tests/assert.py

- name: Upload results to Codecov
Expand Down
11 changes: 0 additions & 11 deletions depyf/code_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,17 +420,6 @@ def visit_FunctionDef(self, node):
# return self.generic_visit(node)


def structure_hash(source_code: str) -> str:
    """Return an MD5 hex digest describing the structure of ``source_code``.

    The source is parsed into an AST and rewritten by ``IdentifierReplacer``,
    so that differences in function names do not change the digest. This
    matters because PyTorch generates function names dynamically. The result
    is turned back into source with ``astor.to_source`` and hashed.
    """
    normalized_tree = IdentifierReplacer().visit(ast.parse(source_code))
    normalized_source = astor.to_source(normalized_tree)
    return hashlib.md5(normalized_source.encode()).hexdigest()


def fix_irregular_code(
old_bytecode: CodeType,
src_code: str,
Expand Down
3 changes: 2 additions & 1 deletion depyf/decompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,7 +1133,8 @@ def cleanup_instructions(code, instructions: List[Instruction]):

def __init__(self, code: Union[CodeType, Callable]):
if callable(code):
code = code.__code__
from depyf.utils import get_code_owner
code = get_code_owner(code).__code__
self.code = code
instructions = list(convert_instruction(_)
for _ in dis.get_instructions(code))
Expand Down
10 changes: 0 additions & 10 deletions depyf/explain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,6 @@ def _extract_artifacts(original_code: CodeType, module):
result = DynamoOptimizationResult(original_code, None, module)
return result


def _collect_compiled_subgraphs(result: DynamoOptimizationResult):
    """Map compiled-subgraph proxy names to compiled subgraphs for ``result``.

    Every entry of ``result.compiled_code_entries`` contributes
    ``entry.compiled_subgraph_proxy.name -> entry.compiled_subgraph``. The
    function then recurses into each value of every entry's
    ``referenced_global_functions``. Names found during the recursion
    overwrite same-named entries collected at this level.

    NOTE(review): the recursive call passes values of
    ``referenced_global_functions`` as ``result``, so they are presumably
    DynamoOptimizationResult-like objects. There is no guard against
    reference cycles; confirm against callers.
    """
    collected = {}
    for entry in result.compiled_code_entries:
        collected[entry.compiled_subgraph_proxy.name] = entry.compiled_subgraph
    for entry in result.compiled_code_entries:
        for referenced in entry.referenced_global_functions.values():
            collected.update(_collect_compiled_subgraphs(referenced))
    return collected

def dump_src(original_code: CodeType, module):
from depyf.explain.global_variables import data
assert data["is_inside_prepare_debug"], "`dump_src` must be used inside `depyf.prepare_debug`."
Expand Down
14 changes: 5 additions & 9 deletions depyf/explain/enable_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,14 @@ def __call__(self, code, new_code):
import dill
# code object, especially `new_code` constructed by Dynamo, may not be able to be dumped using `marshal`.
# see https://github.com/pytorch/pytorch/issues/116013 for more details.
try:
with contextlib.suppress(Exception):
dill.dump(code, open(filename + ".original_bytecode", "wb"))
except:
pass
try:

with contextlib.suppress(Exception):
dill.dump(new_code, open(filename + ".transformed_bytecode", "wb"))
except:
pass
try:

with contextlib.suppress(Exception):
dill.dump(decompiled_and_compiled_back_code, open(filename + ".decompiled_and_compiled_back_bytecode", "wb"))
except:
pass

# this fix is used for PyTorch prior to PR https://github.com/pytorch/pytorch/pull/114487
from torch._dynamo.utils import orig_code_map
Expand Down
2 changes: 1 addition & 1 deletion depyf/explain/patched___call__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
def patched___call__(self, code, check_fn):
from depyf.explain.global_variables import data
from depyf.explain.utils import get_code_owner
from depyf.utils import get_code_owner
import torch
unpatched___call__ = data["unpatched___call__"]
optimized_functions = data["optimized_functions"]
Expand Down
3 changes: 2 additions & 1 deletion depyf/explain/patched_lazy_format_graph_code.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
def patched_lazy_format_graph_code(name, gm, maybe_id=None, **kwargs):
from depyf.explain.utils import get_current_compiled_fn_name, get_code_owner, write_code_to_file_template
from depyf.explain.utils import get_current_compiled_fn_name, write_code_to_file_template
from depyf.utils import get_code_owner
func_name = get_current_compiled_fn_name()
file_name = name if name != func_name else "Captured Graph"
file_name = func_name + " " + file_name
Expand Down
47 changes: 1 addition & 46 deletions depyf/explain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,6 @@
from dataclasses import dataclass
import contextlib

import depyf
from depyf.decompiler import DecompilationError
from depyf.utils import get_function_signature


def decompile_ensure(fn, overwite_fn_name=None):
    """Decompile ``fn``, falling back to a placeholder stub on failure.

    Args:
        fn: object passed straight to ``depyf.Decompiler``.
        overwite_fn_name: optional replacement name for the emitted function.
            The misspelling is kept for backward compatibility with callers.

    Returns:
        The decompiled source code. If ``depyf.Decompiler`` raises
        ``DecompilationError``, the result is the signature produced by
        ``get_function_signature`` followed by the body
        ``'Failed to decompile.'``.
    """
    try:
        decompiled_source_code = depyf.Decompiler(
            fn).decompile(overwite_fn_name=overwite_fn_name)
    except DecompilationError:
        # The error itself is not needed; emit a stub so callers always get source.
        header = get_function_signature(fn, overwite_fn_name=overwite_fn_name)
        decompiled_source_code = header + " 'Failed to decompile.'\n"
    return decompiled_source_code


class CodeProxy:
instances: Dict[str, "CodeProxy"] = {}
used_instances: Set[str] = set()
Expand All @@ -49,6 +34,7 @@ def consume_new_name(name: str):

@staticmethod
def decompile_with_name(code: CodeType, name: str, skip_decompile=False):
from depyf.utils import decompile_ensure
if hasattr(code, "__code__"):
code = code.__code__
if code.co_name.startswith("transformed_code_") or code.co_name.startswith("__transformed_code_"):
Expand Down Expand Up @@ -320,37 +306,6 @@ def write_code_to_file_template(src, path_template):
return new_filepath


def get_code_owner(fn):
    """Return the object that owns the code object reachable from ``fn``.

    A callable ``fn`` may expose a ``__code__`` attribute without being the
    owner of that code object. Only the owner can have its code replaced.
    An example::

        class A:
            def func(self):
                return 1
        a = A()

    ``a.func.__code__`` cannot be reassigned through the bound method, while
    ``A.func.__code__`` is writable. The same function object is reachable
    as ``a.func.__func__``.

    The following wrappers are peeled off repeatedly until none applies:

    * objects with ``__func__`` (bound methods, classmethod/staticmethod);
    * objects with ``__wrapped__`` (``functools.lru_cache``, decorators using
      ``functools.wraps``);
    * ``functools.partial`` objects (via ``.func``);
    * callable instances whose ``__call__`` is a bound method.

    Sometimes the chain leads back to an object that was already visited,
    for example a ``__wrapped__`` attribute pointing at itself. In that case
    unwrapping stops at the last new object instead of looping forever.

    Args:
        fn: the callable to unwrap.

    Returns:
        The innermost object reached by the rules above.
    """
    import functools

    # Map id -> object for everything visited. Holding the objects keeps
    # their ids from being reused during the walk (as inspect.unwrap does).
    visited = {id(fn): fn}
    while True:
        if hasattr(fn, "__func__"):
            # bound method, classmethod or staticmethod
            inner = fn.__func__
        elif hasattr(fn, "__wrapped__"):
            # lru_cache or other decorators built with functools.wraps
            inner = fn.__wrapped__
        elif isinstance(fn, functools.partial):
            # partial function
            inner = fn.func
        elif hasattr(fn, "__call__") and hasattr(fn.__call__, "__func__"):
            # callable object: use the function behind its __call__ method
            inner = fn.__call__.__func__
        else:
            break
        if id(inner) in visited:
            # Cyclic wrapper chain; without this check the loop never ends.
            break
        visited[id(inner)] = inner
        fn = inner
    return fn


def get_current_compiled_fn_name():
import torch
from torch._dynamo.bytecode_transformation import _unique_id_counter
Expand Down
45 changes: 45 additions & 0 deletions depyf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,48 @@ def safe_create_directory(path):
except OSError as e:
if not os.path.isdir(path):
raise



def get_code_owner(fn):
    """Return the object that owns the code object reachable from ``fn``.

    A callable ``fn`` may expose a ``__code__`` attribute without being the
    owner of that code object. Only the owner can have its code replaced.
    An example::

        class A:
            def func(self):
                return 1
        a = A()

    ``a.func.__code__`` cannot be reassigned through the bound method, while
    ``A.func.__code__`` is writable. The same function object is reachable
    as ``a.func.__func__``.

    The following wrappers are peeled off repeatedly until none applies:

    * objects with ``__func__`` (bound methods, classmethod/staticmethod);
    * objects with ``__wrapped__`` (``functools.lru_cache``, decorators using
      ``functools.wraps``);
    * ``functools.partial`` objects (via ``.func``);
    * callable instances whose ``__call__`` is a bound method.

    Sometimes the chain leads back to an object that was already visited,
    for example a ``__wrapped__`` attribute pointing at itself. In that case
    unwrapping stops at the last new object instead of looping forever.

    Args:
        fn: the callable to unwrap.

    Returns:
        The innermost object reached by the rules above.
    """
    import functools

    # Map id -> object for everything visited. Holding the objects keeps
    # their ids from being reused during the walk (as inspect.unwrap does).
    visited = {id(fn): fn}
    while True:
        if hasattr(fn, "__func__"):
            # bound method, classmethod or staticmethod
            inner = fn.__func__
        elif hasattr(fn, "__wrapped__"):
            # lru_cache or other decorators built with functools.wraps
            inner = fn.__wrapped__
        elif isinstance(fn, functools.partial):
            # partial function
            inner = fn.func
        elif hasattr(fn, "__call__") and hasattr(fn.__call__, "__func__"):
            # callable object: use the function behind its __call__ method
            inner = fn.__call__.__func__
        else:
            break
        if id(inner) in visited:
            # Cyclic wrapper chain; without this check the loop never ends.
            break
        visited[id(inner)] = inner
        fn = inner
    return fn



def decompile_ensure(fn: CodeType, overwite_fn_name=None):
    """Decompile ``fn``, falling back to a placeholder stub on failure.

    Args:
        fn: object passed straight to ``depyf.Decompiler``.
        overwite_fn_name: optional replacement name for the emitted function.
            The misspelling is kept for backward compatibility with callers.

    Returns:
        The decompiled source code. If ``depyf.Decompiler`` raises
        ``DecompilationError``, the result is the signature produced by
        ``get_function_signature`` followed by the body
        ``'Failed to decompile.'``.
    """
    # Imported lazily. Presumably this avoids a circular import between
    # depyf.utils and depyf / depyf.decompiler; confirm before hoisting.
    import depyf
    from depyf.decompiler import DecompilationError
    try:
        decompiled_source_code = depyf.Decompiler(
            fn).decompile(overwite_fn_name=overwite_fn_name)
    except DecompilationError:
        # The error itself is not needed; emit a stub so callers always get source.
        header = get_function_signature(fn, overwite_fn_name=overwite_fn_name)
        decompiled_source_code = header + " 'Failed to decompile.'\n"
    return decompiled_source_code
16 changes: 16 additions & 0 deletions tests/test_code_owner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from functools import lru_cache, partial

import depyf


def f(a, b):
    return a + b


class A:
    def __call__(self, a, b):
        return a + b


# Each callable reaches its code only indirectly: a partial via `.func`, an
# lru_cache wrapper via `__wrapped__`, and an instance via `__call__`.
# Presumably depyf.decompile routes these through get_code_owner; confirm.
for wrapped_callable in (partial(f, 1), lru_cache(None)(f), A()):
    print(depyf.decompile(wrapped_callable))
11 changes: 11 additions & 0 deletions tests/test_ensure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from depyf.utils import decompile_ensure

import asyncio  # NOTE(review): not referenced anywhere in this script

# Fixture: a `return` inside `finally` overrides the result of the `try`
# body. This unusual control flow is presumably chosen because depyf cannot
# decompile it, so decompile_ensure's fallback stub path gets exercised
# (confirm). Keep the body unchanged; its bytecode is what is under test.
def f(a, b):
    try:
        return a + b
    finally:
        return a - b

# Prints either the decompiled source or a signature stub ending in
# 'Failed to decompile.'.
print(decompile_ensure(f.__code__))
2 changes: 1 addition & 1 deletion tests/test_pytorch/test_simple_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ def fn():
return x.grad

import depyf
with depyf.prepare_debug("./simple_output"):
with depyf.prepare_debug("./simple_output", log_bytecode=True, clean_wild_fx_code=False):
fn()