What does swapaxes do, and how is it implemented?

Can someone explain this code to me?

c = self.config

assert len(pair_act.shape) == 3
assert len(pair_mask.shape) == 2
assert c.orientation in ['per_row', 'per_column']

if c.orientation == 'per_column':
  pair_act = jnp.swapaxes(pair_act, -2, -3)
  pair_mask = jnp.swapaxes(pair_mask, -1, -2)

It looks like pair_act is a 3D array and pair_mask is a 2D array. What do the axis numbers -1, -2, and -3 mean? For a 3D array, my initial thought was that the array is 0, the column is 1, and the row is 2, so where does the minus sign come from? Any array examples would be appreciated. Thanks for the help.
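To make the axis numbering concrete, here is a small sketch. It uses NumPy rather than JAX, on the assumption that jnp.swapaxes follows the same semantics as np.swapaxes; the shapes (2, 3, 4) and (2, 3) are made up for illustration and are not taken from the original code. Axes are numbered by their position in the array's shape, starting at 0, and a negative number counts from the end, so for a 3D array -1 is axis 2, -2 is axis 1, and -3 is axis 0.

```python
import numpy as np

# Hypothetical shapes for illustration only: 2 rows, 3 columns, 4 channels.
pair_act = np.arange(2 * 3 * 4).reshape(2, 3, 4)
pair_mask = np.arange(2 * 3).reshape(2, 3)

# Negative axes count from the end of the shape tuple.
# For a 3D array: -3 -> axis 0, -2 -> axis 1, -1 -> axis 2.
swapped_act = np.swapaxes(pair_act, -2, -3)    # same as np.swapaxes(pair_act, 1, 0)
# For a 2D array: -2 -> axis 0, -1 -> axis 1, so this is a plain transpose.
swapped_mask = np.swapaxes(pair_mask, -1, -2)

print(swapped_act.shape)   # (3, 2, 4)
print(swapped_mask.shape)  # (3, 2)

# Element [i, j, k] of the original sits at [j, i, k] after the swap.
assert swapped_act[2, 1, 3] == pair_act[1, 2, 3]

# Swapping two axes is equivalent to a transpose whose permutation
# exchanges those two positions and leaves the rest alone.
assert np.array_equal(swapped_act, np.transpose(pair_act, (1, 0, 2)))
assert np.array_equal(swapped_mask, pair_mask.T)
```

Writing the axes as negative numbers means the same call still swaps the last dimensions even if extra leading dimensions (a batch axis, for example) are added later, which is probably why the code is written that way.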



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
