PyTorch with a custom C++ function: JIT compilation seems to skip the C++ call

I will try to explain my problem with as much context as possible, because the system I am using is too complex to reduce to a minimal reproducible example.

I am using Python and PyTorch. My program is a physics simulator: https://github.com/tum-pbs/PhiFlow. The simulation is computed frame by frame, and each frame calls a custom C++ function (a CG solver). I have bound this function (together with a cuSPARSE SpMM helper) to Python with pybind11, and I have also registered both as custom TorchScript operators (https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html#registering-the-custom-operator-with-torchscript):

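    #include <torch/extension.h>  // PyTorch C++ API and pybind11 (PYBIND11_MODULE)
    #include <torch/library.h>    // TORCH_LIBRARY

    // conjugate_gradient and cusparse_SpMM are defined elsewhere in the extension.

    // pybind11 bindings: exposed as attributes of the compiled extension module.
    PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
      m.def("conjugate_gradient", &conjugate_gradient, "Conjugate gradient function");
      m.def("cusparse_SpMM", &cusparse_SpMM, "Sparse(CSR) times dense matrix multiplication on CUSPARSE");
    }

    // Dispatcher registration: callable as torch.ops.torch_cuda.<name>(...).
    TORCH_LIBRARY(torch_cuda, m) {
      m.def("conjugate_gradient", &conjugate_gradient);
      m.def("cusparse_SpMM", &cusparse_SpMM);
    }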

The binding works fine, and the program behaves as expected.

The problem starts when I add JIT compilation with the PyTorch JIT. To make sure that the C++ function is actually being called, I added a print statement inside it:

    std::cout << "From C++ function" << std::endl;

After every simulation step, a line like the following is printed; it summarizes the step's statistics for later profiling:

    3f2f8164-222e-4af9-9c7d-94e0aa41ae73,0,28.859,142.0,1e-05,2000,C++_opt,1,64
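Because every line carries the step index, the statistics can also be checked automatically for a solver that silently stops doing work. This is a minimal sketch that assumes (the post does not label the columns) that the first field is a run id, the second the step index and the third the step's wall time in seconds; `parse_stats` and `collapsed_steps` are hypothetical helpers, and the 100x threshold is arbitrary.

```python
from typing import List, Tuple


def parse_stats(line: str) -> Tuple[str, int, float]:
    """Parse one stats line into (run_id, step, step_time).

    ASSUMPTION: field 0 is a run id, field 1 the step index and field 2 the
    per-step time in seconds; the remaining fields are ignored.
    """
    fields = line.strip().split(",")
    return fields[0], int(fields[1]), float(fields[2])


def collapsed_steps(lines: List[str], ratio: float = 100.0) -> List[int]:
    """Return the steps whose time is more than `ratio` times shorter than the first line's."""
    parsed = [parse_stats(line) for line in lines if line.strip()]
    baseline = parsed[0][2]
    return [step for _, step, seconds in parsed if seconds * ratio < baseline]


sample = [
    "3f2f8164-222e-4af9-9c7d-94e0aa41ae73,0,28.859,142.0,1e-05,2000,C++_opt,1,64",
    "3f2f8164-222e-4af9-9c7d-94e0aa41ae73,1,27.114,142.0,1e-05,2000,C++_opt,1,64",
    "3f2f8164-222e-4af9-9c7d-94e0aa41ae73,2,0.059,142.0,1e-05,2000,C++_opt,1,64",
]
print(collapsed_steps(sample))  # [2]
```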

The output of my simulation is shown below. I expect the "From C++ function" message to appear exactly once per simulation step. After the first simulation step, however, the C++ function prints nothing, so I assume it is no longer being called. The pixels showing the evolution of the 2D simulation do not change either, which means that no real computation is taking place. Running the debugger with a breakpoint on the Python call to the C++ function shows the same behaviour: the debugger stops at that call four times, and then all the remaining steps are printed without the breakpoint being hit again.

    From C++ function
    From C++ function
    From C++ function
    From C++ function
    /home/lanver/PhiFlow/venv/lib/python3.8/site-packages/torch/jit/_trace.py:810: TracerWarning:

    Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
    Tensor-likes are not close!

    Mismatched elements: 4096 / 4096 (100.0%)
    Greatest absolute difference: 0.06343267313840288 at index (32, 3) (up to 1e-07 allowed)
    Greatest relative difference: 1.0 at index (0, 0) (up to 1e-05 allowed)

    3f2f8164-222e-4af9-9c7d-94e0aa41ae73,0,28.859,142.0,1e-05,2000,C++_opt,1,64
    3f2f8164-222e-4af9-9c7d-94e0aa41ae73,1,27.114,142.0,1e-05,2000,C++_opt,1,64
    3f2f8164-222e-4af9-9c7d-94e0aa41ae73,2,0.059,142.0,1e-05,2000,C++_opt,1,64
    3f2f8164-222e-4af9-9c7d-94e0aa41ae73,3,0.059,142.0,1e-05,2000,C++_opt,1,64
    3f2f8164-222e-4af9-9c7d-94e0aa41ae73,4,0.06,142.0,1e-05,2000,C++_opt,1,64
    3f2f8164-222e-4af9-9c7d-94e0aa41ae73,5,0.062,142.0,1e-05,2000,C++_opt,1,64
    3f2f8164-222e-4af9-9c7d-94e0aa41ae73,6,0.059,142.0,1e-05,2000,C++_opt,1,64
    3f2f8164-222e-4af9-9c7d-94e0aa41ae73,7,0.059,142.0,1e-05,2000,C++_opt,1,64
    3f2f8164-222e-4af9-9c7d-94e0aa41ae73,8,0.061,142.0,1e-05,2000,C++_opt,1,64
    3f2f8164-222e-4af9-9c7d-94e0aa41ae73,9,0.064,142.0,1e-05,2000,C++_opt,1,64
    3f2f8164-222e-4af9-9c7d-94e0aa41ae73,10,0.06,142.0,1e-05,2000,C++_opt,1,64
    3f2f8164-222e-4af9-9c7d-94e0aa41ae73,11,0.062,142.0,1e-05,2000,C++_opt,1,64
    3f2f8164-222e-4af9-9c7d-94e0aa41ae73,12,0.059,142.0,1e-05,2000,C++_opt,1,64
    3f2f8164-222e-4af9-9c7d-94e0aa41ae73,13,0.063,142.0,1e-05,2000,C++_opt,1,64
    3f2f8164-222e-4af9-9c7d-94e0aa41ae73,14,0.058,142.0,1e-05,2000,C++_opt,1,64
    3f2f8164-222e-4af9-9c7d-94e0aa41ae73,15,0.062,142.0,1e-05,2000,C++_opt,1,64
    3f2f8164-222e-4af9-9c7d-94e0aa41ae73,16,0.059,142.0,1e-05,2000,C++_opt,1,64
    3f2f8164-222e-4af9-9c7d-94e0aa41ae73,17,0.064,142.0,1e-05,2000,C++_opt,1,64
    3f2f8164-222e-4af9-9c7d-94e0aa41ae73,18,0.063,142.0,1e-05,2000,C++_opt,1,64
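One possible explanation for this output, offered as a hypothesis rather than a confirmed diagnosis: the warning comes from `torch/jit/_trace.py`, so the JIT in use is tracing (`torch.jit.trace`). Tracing runs the Python function on example inputs and records only the tensor operations that pass through PyTorch's dispatcher. Work that a pybind11-bound C++ function does directly on raw device memory (such as cuSPARSE calls) is not recorded, so replaying the trace never enters that C++ code again and its effect is frozen or missing. With `check_trace=True` (the default), `torch.jit.trace` also re-runs the function and compares the traced output with eager execution, which produces exactly this kind of TracerWarning and may account for some of the extra prints. The pure-Python toy below models that mechanism; it is not PyTorch's implementation, and `Traced`, `Graph`, `trace`, `opaque_solver` and `step` are all hypothetical names.

```python
import operator

# A toy model of trace-based compilation. It is NOT how PyTorch is implemented;
# it only illustrates why a call the tracer cannot see ends up frozen.


class Traced:
    """A concrete value plus the index of the graph node that produced it."""

    def __init__(self, graph, idx, value):
        self.graph, self.idx, self.value = graph, idx, value

    def __add__(self, other):
        # Recorded: the other operand becomes a node reference or a constant.
        if isinstance(other, Traced):
            arg, val = ("node", other.idx), other.value
        else:
            arg, val = ("const", other), other
        return self.graph.add_node(operator.add, [("node", self.idx), arg],
                                   self.value + val)


class Graph:
    def __init__(self):
        self.nodes = []      # (fn, args); fn is None for the input placeholder
        self.output = None

    def add_node(self, fn, args, value):
        self.nodes.append((fn, args))
        return Traced(self, len(self.nodes) - 1, value)

    def replay(self, x):
        """Re-run only the recorded operations on a new input."""
        values = []
        for fn, args in self.nodes:
            if fn is None:
                values.append(x)
            else:
                values.append(fn(*[values[v] if kind == "node" else v
                                   for kind, v in args]))
        return values[self.output]


def trace(fn, example):
    graph = Graph()
    out = fn(graph.add_node(None, [], example))  # run once, recording as we go
    graph.output = out.idx
    return graph


calls = []  # stands in for the std::cout side effect


def opaque_solver(x):
    """Hypothetical stand-in for the pybind11-bound C++ solver."""
    calls.append(x)
    raw = x.value if isinstance(x, Traced) else x  # reads raw data, bypassing the recorder
    return raw * 0.5                               # plain number: its provenance is lost


def step(x):
    return x + opaque_solver(x)  # the addition is recorded, the solver call is not


graph = trace(step, 4.0)
print(graph.replay(10.0))  # 12.0: reuses the 2.0 frozen while tracing
print(step(10.0))          # 15.0: eager code calls the solver again
print(len(calls))          # 2: one call during tracing, one eager call
```

If this is the cause and the solver is currently called through the pybind11 module, calling it through its dispatcher registration instead, `torch.ops.torch_cuda.conjugate_gradient(...)`, should let the tracer record it as an operator node; that is what the `TORCH_LIBRARY` registration is for.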


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
