Pytorch: Difference between indexing and torch.stack

Is there any advantage to using indexing over torch.stack when constructing a tensor in PyTorch?

Indexing

out = torch.empty(length, n)
for ii in range(length):
    out[ii] = f(ii)

torch.stack

out = [f(ii) for ii in range(length)]
out = torch.stack(out)

Benchmark

Benchmarking the two, torch.stack is consistently around twice as fast as indexing, or more:

[------------ 8 -------------]
             |  nojit  |  jit 
1 threads: -------------------
      index  |  113.9  |  99.1
      stack  |   60.5  |  51.6

Times are in microseconds (us).

[------------- 32 ------------]
             |  nojit  |   jit 
1 threads: --------------------
      index  |  450.4  |  385.8
      stack  |  198.1  |  174.5

Times are in microseconds (us).

[------------- 128 -------------]
             |  nojit   |   jit  
1 threads: ----------------------
      index  |  1805.9  |  1555.8
      stack  |   779.2  |   688.7

Times are in microseconds (us).

[----------- 256 -----------]
             |  nojit  |  jit
1 threads: ------------------
      index  |   3.6   |  3.0
      stack  |   1.5   |  1.4

Times are in milliseconds (ms).

[----------- 512 -----------]
             |  nojit  |  jit
1 threads: ------------------
      index  |   7.9   |  6.0
      stack  |   3.0   |  2.7

Times are in milliseconds (ms).

[----------- 1024 -----------]
             |  nojit  |  jit 
1 threads: -------------------
      index  |   14.4  |  12.6
      stack  |    6.1  |   5.1

Times are in milliseconds (ms).

Benchmark script

import torch
import torch.utils.benchmark as benchmark


def index(arr: torch.Tensor, length: int):
    # arr: (n, )
    (n,) = arr.shape

    out = torch.empty(length, n, device=arr.device)
    out[0] = arr

    for ii in range(1, length):
        out[ii] = out[ii - 1] + arr

    return out


def stack(arr: torch.Tensor, length: int) -> torch.Tensor:
    # arr: (n, )
    arrs = [arr]
    for ii in range(1, length):
        arrs.append(arrs[-1] + arr)

    return torch.stack(arrs, dim=0)


index_jit = torch.jit.script(index)
stack_jit = torch.jit.script(stack)


def main():
    torch.random.manual_seed(1234)
    n = 256

    x = torch.randn(n).cuda()

    lengths = [8, 32, 128, 256, 512, 1024]

    timers = []
    for length in lengths:
        label = f"{length}"
        globals_dict = {"x": x, "length": length}

        t_index = benchmark.Timer(
            stmt="index(x, length)",
            setup="from __main__ import index",
            label=label,
            sub_label="index",
            description="nojit",
            globals=globals_dict,
        )
        t_stack = benchmark.Timer(
            stmt="stack(x, length)",
            setup="from __main__ import stack",
            label=label,
            sub_label="stack",
            description="nojit",
            globals=globals_dict,
        )
        t_index_jit = benchmark.Timer(
            stmt="index_jit(x, length)",
            setup="from __main__ import index_jit",
            label=label,
            sub_label="index",
            description="jit",
            globals=globals_dict,
        )
        t_stack_jit = benchmark.Timer(
            stmt="stack_jit(x, length)",
            setup="from __main__ import stack_jit",
            label=label,
            sub_label="stack",
            description="jit",
            globals=globals_dict,
        )
        timers.extend([t_index, t_stack, t_index_jit, t_stack_jit])

    results = [t.blocked_autorange() for t in timers]

    compare = benchmark.Compare(results)
    compare.print()


if __name__ == "__main__":
    main()


Solution 1:[1]

As your benchmark shows, stack is generally faster.

Usually we already have some tensors (output from a data loader or from another torch operation) and can perform further torch operations on them. We try to avoid Python for-loops in PyTorch so that the heavily optimized PyTorch functions can operate on whole tensors on the GPU; much of the overhead lies in the iteration of the for-loop itself.

You can try another benchmark where you have already generated out = [f(ii) for ii in range(length)] and only time the stack call. Similarly, we try to write PyTorch functions f() that accept a batch as input instead of a single sample (without using for-loops internally). If you do that, you only need a single index assignment in one call, which will be significantly faster as well.
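As a sketch of that idea (f_single and f_batched here are made-up illustrative functions, not from the question), a function written to accept a whole batch replaces a Python loop plus a stack with one batched call:

```python
import torch


def f_single(x: torch.Tensor) -> torch.Tensor:
    # hypothetical per-sample function, for illustration only
    return x * 2.0 + 1.0


def f_batched(xs: torch.Tensor) -> torch.Tensor:
    # the same computation written to accept a whole batch at once:
    # one set of kernel launches instead of one per sample
    return xs * 2.0 + 1.0


xs = torch.randn(128, 16)
looped = torch.stack([f_single(x) for x in xs])  # Python loop + stack
batched = f_batched(xs)                          # single batched call
assert torch.allclose(looped, batched)
```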

This is the PyTorch way.
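For the running sum in the benchmark above specifically, the loop can be removed entirely: out[ii] = (ii + 1) * arr is just a cumulative sum, so a fully batched version (a sketch using torch.cumsum) produces the same result as index() and stack() in a single call:

```python
import torch


def batched(arr: torch.Tensor, length: int) -> torch.Tensor:
    # arr: (n,) -> out: (length, n) with out[ii] = (ii + 1) * arr,
    # the same output index() and stack() build iteratively.
    # expand() creates a (length, n) view without copying; cumsum
    # then accumulates along dim 0 in one kernel.
    return torch.cumsum(arr.unsqueeze(0).expand(length, -1), dim=0)


arr = torch.randn(256)
out = batched(arr, 32)
assert out.shape == (32, 256)
assert torch.allclose(out[-1], 32.0 * arr)
```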

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution source
Solution 1: Zoom