GPU batch-processing of a large stack of matrices in PyTorch
I am quite new to PyTorch, so I am trying to write some examples. One of them is the matrix-matrix product of M matrices of size N×N, i.e. a certain tensor contraction of two tensors of shape (M, N, N). The simplest way to do that (on the CPU) is torch.bmm, which operates directly on stacks of matrices. There is also an example at the link, which I want to extend to processing on the GPU. If M is sufficiently small, it is possible to "dump" everything directly to the GPU by calling torch.Tensor.to:
```python
import torch

N = 5
M = 10
input = torch.randn(M, N, N).to("cuda")
mat2 = torch.randn(M, N, N).to("cuda")
res = torch.bmm(input, mat2)
res
# In an interactive session this prints the entire tensor. The bmm kernel is
# launched asynchronously; printing waits for the result to be ready.
res.size() == torch.Size([M, N, N])
# prints True
res.is_cuda
# prints True
```
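To make the "certain tensor contraction" concrete, here is a small CPU sketch (sizes chosen arbitrarily) checking that torch.bmm computes res[m] = a[m] @ b[m] independently for every m, which is the einsum contraction "mij,mjk->mik":

```python
import torch

M, N = 4, 3
a = torch.randn(M, N, N)
b = torch.randn(M, N, N)

# torch.bmm multiplies the M matrix pairs independently
res = torch.bmm(a, b)

# The same contraction written with einsum (summing over j) ...
res_einsum = torch.einsum("mij,mjk->mik", a, b)
# ... and as an explicit Python loop over the stack
res_loop = torch.stack([a[m] @ b[m] for m in range(M)])

print(torch.allclose(res, res_einsum, atol=1e-6), torch.allclose(res, res_loop, atol=1e-6))
```

Because every slice along M is independent, the stack can be split along M into chunks without changing the result.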
However, if M and N are on the order of 1000, this becomes infeasible: in float32, input and mat2 each need 1000 × 1000 × 1000 × 4 bytes ≈ 3.73 GiB, which is more than the 2 GiB of my GPU, so the following exception is raised:

```
RuntimeError: CUDA out of memory. Tried to allocate 3.73 GiB (GPU 0; 2.00 GiB total capacity; 0 bytes already allocated; 1.56 GiB bytes free; 0 bytes reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
```
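The 3.73 GiB in the error message follows directly from the tensor shape. A small sketch (the helper name stack_bytes is hypothetical) computing the footprint of one (M, N, N) float32 tensor and of all three tensors involved (input, mat2 and the result):

```python
def stack_bytes(m: int, n: int, element_size: int = 4) -> int:
    """Bytes needed for one (m, n, n) tensor; float32 uses 4 bytes per element."""
    return m * n * n * element_size

GiB = 2 ** 30
per_tensor = stack_bytes(1000, 1000)
print(f"{per_tensor / GiB:.2f} GiB per tensor")
print(f"{3 * per_tensor / GiB:.2f} GiB for input, mat2 and res together")
```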
I have read parts of the documentation referenced in the message, but it does not seem to mention a solution to this problem directly (its advice about max_split_size_mb targets fragmentation, which does not help when a single tensor is larger than the whole GPU). My idea would be to process these stacks of matrices along M in chunks that fit into GPU memory, using as much of the available space as possible, and to copy the processed data back to CPU memory after each torch.bmm call.
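One way to "use as much available space as possible" is to ask the driver how much memory is free and derive the chunk size from it. A sketch, assuming each chunk needs room for three (chunk, N, N) tensors (the input slice, the mat2 slice and the output slice); the helper max_chunk and the 0.8 safety factor are hypothetical choices meant to leave headroom for the caching allocator and cuBLAS workspace. torch.cuda.mem_get_info is available in recent PyTorch versions:

```python
import torch

def max_chunk(free_bytes: int, n: int, element_size: int = 4, safety: float = 0.8) -> int:
    """Largest number of (n, n) matrices per chunk such that the input slice,
    the mat2 slice and the output slice (3 * chunk * n * n elements) fit in
    safety * free_bytes. Returns at least 1, even if that does not fit."""
    per_matrix = 3 * n * n * element_size
    return max(1, int(safety * free_bytes) // per_matrix)

if torch.cuda.is_available():
    free, total = torch.cuda.mem_get_info()  # (free, total) in bytes for the current device
    print(max_chunk(free, n=1000))
```

With the 1.56 GiB reported free in the error above and N = 1000, this gives chunks of 111 matrices.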
The most straightforward way to do so would be to use slicing to split input and mat2, copy the resulting tensors to device memory, call torch.bmm with out set to a preallocated device tensor, and copy the data back from there. This code goes inside a loop, with an appropriately chosen chunk size.
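The loop described above can be sketched as follows; chunked_bmm is a hypothetical helper, and the inputs are assumed to be CPU tensors of shape (M, N, N). The device output buffer is allocated once and reused, and the last chunk may be shorter than chunk:

```python
import torch

def chunked_bmm(a: torch.Tensor, b: torch.Tensor, chunk: int, device: str = "cuda") -> torch.Tensor:
    """Compute torch.bmm(a, b) for CPU tensors of shape (M, N, N), moving
    `chunk` matrices at a time to `device`. The full result stays on the CPU."""
    M, N, _ = a.shape
    res = torch.empty(M, N, N, dtype=a.dtype)  # full result lives in host memory
    # Device output buffer, allocated once and reused for every chunk
    out_dev = torch.empty(chunk, N, N, dtype=a.dtype, device=device)
    for start in range(0, M, chunk):
        stop = min(start + chunk, M)
        k = stop - start                        # the last chunk may be smaller
        a_dev = a[start:stop].to(device)        # host -> device copy
        b_dev = b[start:stop].to(device)
        torch.bmm(a_dev, b_dev, out=out_dev[:k])
        res[start:stop].copy_(out_dev[:k])      # device -> host copy
    return res
```

This version is fully synchronous for clarity. Pinning the host tensors (Tensor.pin_memory()) and passing non_blocking=True to .to() would be the usual next step if copies and computation should overlap, typically together with separate CUDA streams.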
However, I would expect something like this to be implemented already, since it seems like a common use case, even if it is rather "low-level" compared to PyTorch's main purpose, machine learning. As I am new to this, I have no hands-on experience with it.
Secondly, with this splitting it would also be possible to use the GPU(s) and CPU(s) in tandem to process the matrix stack, possibly for even better performance. What would be a first step towards implementing something like this, or is it perhaps already implemented in PyTorch?
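A possible first step for the CPU/GPU tandem: split the stack at a fixed index, multiply the tail on the CPU in a worker thread, and stream the head through the GPU in chunks from the main thread. This relies on most PyTorch CPU kernels releasing the GIL, so the two parts can run concurrently. The helper hybrid_bmm and the gpu_fraction and chunk defaults are hypothetical; a good split ratio would have to be measured on the actual hardware:

```python
import torch
from concurrent.futures import ThreadPoolExecutor

def hybrid_bmm(a: torch.Tensor, b: torch.Tensor, gpu_fraction: float = 0.8, chunk: int = 64) -> torch.Tensor:
    """Hypothetical helper: the first gpu_fraction of the (M, N, N) stack is
    processed on the GPU in chunks, the rest on the CPU in a worker thread.
    Without a GPU, everything is computed on the CPU."""
    M, N, _ = a.shape
    split = int(M * gpu_fraction) if torch.cuda.is_available() else 0
    res = torch.empty(M, N, N, dtype=a.dtype)

    def cpu_part():
        # Writes directly into the CPU share of the result
        torch.bmm(a[split:], b[split:], out=res[split:])

    with ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(cpu_part)
        for start in range(0, split, chunk):
            stop = min(start + chunk, split)
            out = torch.bmm(a[start:stop].to("cuda"), b[start:stop].to("cuda"))
            res[start:stop].copy_(out)  # GPU share, copied back chunk by chunk
        future.result()  # wait for the CPU share and re-raise any exception
    return res
```

The two parts write to disjoint slices of res, so no locking is needed.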
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow