Have you ever felt overwhelmed by the complexities of `torch.compile`? Diving into its workings can feel like black magic, full of bytecode and Python internals that hinder many users from understanding and adapting to `torch.compile`.
If you face this problem too, you might be interested in `depyf`. As the logo suggests, `depyf` is a software tool that leverages advanced Python features (the Python snake symbol) to open up the internal details (the internal gears symbol) of PyTorch's compiler `torch.compile` (the PyTorch logo), so that users can understand it, adapt to it, and tune their code (the debugger symbol) to get the maximum performance benefit out of it.
The development of `depyf` closely tracks the fast-moving internals of `torch.compile`. Therefore, please use this project along with PyTorch nightly. Visit the PyTorch website for how to install a nightly version of PyTorch. We recommend updating your PyTorch nightly installation every week or so.
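For example, at the time of writing, a CPU-only nightly build can be installed with `pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu`; check the PyTorch website for the exact command for your platform and CUDA version.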
## Why `depyf`?

If you want to understand the bytecode generated by `torch.compile`, then `depyf` might be the only choice for you. Below we tested several existing decompilers: they struggle to decompile even simple Python bytecode across Python versions, and have poor support for the bytecode PyTorch generates.
| Decompiler | Python 3.8 | Python 3.9 | Python 3.10 | Python 3.11 | PyTorch |
|---|---|---|---|---|---|
| `decompyle3` | 90.6% (77/85) | ✗ | ✗ | ✗ | ✗ |
| `uncompyle6` | 91.8% (78/85) | ✗ | ✗ | ✗ | ✗ |
| `pycdc` | 74.1% (63/85) | 74.1% (63/85) | 74.1% (63/85) | 67.1% (57/85) | 19.3% (27/140) |
| `depyf` | 100% (85/85) | 100% (85/85) | 100% (85/85) | 100% (85/85) | 100% (140/140) |
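Beyond `torch.compile`, you can also try `depyf` as a plain decompiler. Below is a minimal sketch, assuming the top-level `depyf.decompile` helper that accepts a regular Python function; double-check the exact API against the `depyf` documentation for your installed version:

```python
import depyf

def add_one(x):
    # A trivial function to round-trip through bytecode.
    return x + 1

# Assumed API: depyf.decompile takes a function (or code object) and
# returns equivalent Python source recovered from its bytecode.
print(depyf.decompile(add_one))
```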
## Installation

- Stable release: `pip install depyf`
- Nightly version (recommended): `pip install git+https://github.com/thuml/depyf.git`
## Usage

The main usage is quite simple: just wrap your code within a context manager:
```diff
import torch
from torch import _dynamo as torchdynamo
from typing import List

@torch.compile
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

def main():
    for _ in range(100):
        toy_example(torch.randn(10), torch.randn(10))

if __name__ == "__main__":
-     main()
+     import depyf
+     with depyf.prepare_debug("./dump_src_dir"):
+         main()
```
Then you can see all the details of `torch.compile` inside the directory `./dump_src_dir`. The details are organized into the following files:
- `full_code_for_xxx.py` for the full code of each function using `torch.compile`.
- `__transformed_code_for_xxx.py` for the Python code associated with each graph.
- `__transformed_code_for_xxx.py.xxx_bytecode` for the Python bytecode: a dumped code object that can be loaded via `dill.load(open("/path/to/file", "rb"))` (see the sketch after this list). Note that loading might import some modules such as `transformers`; make sure you have these modules installed.
- `__compiled_fn_xxx.py` for each computation graph and its optimization:
  - `Captured Graph`: a plain forward computation graph.
  - `Joint Graph`: joint forward-backward graph from `AOTAutograd`.
  - `Forward Graph`: forward graph from `AOTAutograd`.
  - `Backward Graph`: backward graph from `AOTAutograd`.
  - `kernel xxx`: compiled CPU/GPU kernel wrapper from Inductor.
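As an illustration, here is a minimal sketch of inspecting one of the dumped code objects with `dill` and the standard-library `dis` module; the file name below is hypothetical, so substitute an actual `*_bytecode` file from your own `./dump_src_dir`:

```python
import dill
import dis

# Path is illustrative: pick a real *_bytecode file from ./dump_src_dir.
path = "./dump_src_dir/__transformed_code_for_toy_example.py.toy_example_bytecode"

# The file stores a pickled code object; open it in binary *read* mode.
with open(path, "rb") as f:
    code = dill.load(f)

# Disassemble the code object to inspect the bytecode torch.compile produced.
dis.dis(code)
```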
We collected all these compilation artifacts when testing `torch.compile` over 100 deep learning models. You can take a look at them to learn how the PyTorch compiler works.
If you want to use a debugger to step through the above code, just add another context manager (and launch the script through a debugger):
```diff
import torch
from torch import _dynamo as torchdynamo
from typing import List

@torch.compile
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

def main():
    for _ in range(100):
        toy_example(torch.randn(10), torch.randn(10))

if __name__ == "__main__":
    import depyf
    with depyf.prepare_debug("./dump_src_dir"):
        main()
+     with depyf.debug():
+         main()
```
Calling `depyf.debug()` will pause the program for you to set breakpoints, and then you can use debuggers to hit breakpoints in the files under the `./dump_src_dir` directory you specified above.
## Contact

If you have any questions about `depyf`, feel free to open an issue to reach out! Any discussion, issue report, or PR is welcome. Or contact [email protected] if you have any other questions.