How to implement FISTA/ISTA or ADMM in PyTorch

I want to solve a lasso-type problem (the specific problem I am working on is phase retrieval). The loss function has the form $\|A(x) - b\|_2^2 + \|x\|_1$, where $A(x)$ is a non-convex, non-linear operator. I want to use PyTorch's automatic differentiation to solve this. However, if I simply minimize the full loss with PyTorch, the result is not good. I am not sure why; maybe it is because PyTorch effectively takes a subgradient step on the $\ell_1$ term, which is not as good as a proximal gradient method?

If I want to implement what FISTA, ISTA, or ADMM does, how should I do it in PyTorch? My idea is this: for the gradient update step, I set the loss to just $\|A(x) - b\|_2^2$ and call loss.backward(), so PyTorch computes the gradient update automatically. Then, inside a with torch.no_grad(): block, I apply the proximal step, for which I use the torch.nn.Softshrink() function. Is this correct, or is there something I should watch out for? I cannot find many people doing it this way.
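To make the question concrete, here is a minimal ISTA-style sketch of what I have in mind. The forward model A here is only a toy stand-in ($A(x) = |Fx|^2$, a simple phase-retrieval-style measurement I made up for illustration), and the values of step and lam are arbitrary placeholders I would tune for the real problem:

```python
import torch

# Toy stand-in for the non-linear forward model; replace with the real A.
# Here A(x) = |Fx|^2, a simple phase-retrieval-style measurement.
def A(x):
    return torch.abs(torch.fft.fft(x)) ** 2

torch.manual_seed(0)
n = 128
x_true = torch.zeros(n)
x_true[torch.randperm(n)[:5]] = torch.randn(5)  # sparse ground truth
b = A(x_true)                                   # noiseless measurements

x = torch.zeros(n, requires_grad=True)
step = 1e-4   # gradient step size (placeholder; needs tuning or backtracking)
lam = 0.1     # l1 regularization weight (placeholder)

# The prox of step*lam*||.||_1 is soft-thresholding at step*lam,
# so the Softshrink threshold should be step*lam, not lam alone.
prox = torch.nn.Softshrink(lambd=step * lam)

for it in range(1000):
    loss = torch.sum((A(x) - b) ** 2)  # smooth part of the objective only
    loss.backward()
    with torch.no_grad():
        x -= step * x.grad  # gradient step on the smooth part
        x.copy_(prox(x))    # proximal (soft-thresholding) step for the l1 term
    x.grad.zero_()          # reset the gradient before the next iteration
```

My understanding is that FISTA would only add a Nesterov-style extrapolation between iterations (with the gradient evaluated at the extrapolated point), so is a loop like the one above the right starting point?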


