What does swapaxes do, and how is it implemented?
Can someone explain this code to me?
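```python
import jax.numpy as jnp

c = self.config
assert len(pair_act.shape) == 3  # pair_act must be a 3-D array
assert len(pair_mask.shape) == 2  # pair_mask must be a 2-D array
assert c.orientation in ['per_row', 'per_column']
if c.orientation == 'per_column':
    # For a 3-D array, axes -3 and -2 are axes 0 and 1: swap them.
    pair_act = jnp.swapaxes(pair_act, -2, -3)
    # For a 2-D array, axes -2 and -1 are axes 0 and 1: a plain transpose.
    pair_mask = jnp.swapaxes(pair_mask, -1, -2)
```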
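The `per_column` branch above can be illustrated with NumPy, whose `np.swapaxes` has the semantics that `jnp.swapaxes` follows. The arrays below are made-up stand-ins with distinct axis sizes so the swap is visible; only their number of dimensions matches the original code. The function `swapaxes_via_transpose` is hypothetical, written only to show that swapping two axes is the same as a transpose whose axis permutation has two entries exchanged. It is a sketch of the idea, not the NumPy or JAX source.

```python
import numpy as np

def swapaxes_via_transpose(x, axis1, axis2):
    # Hypothetical re-implementation for illustration: start from the
    # identity permutation (0, 1, ..., ndim-1) and exchange two entries.
    perm = list(range(x.ndim))
    # Map in-range negative axes to 0..ndim-1 (no range check here).
    i, j = axis1 % x.ndim, axis2 % x.ndim
    perm[i], perm[j] = perm[j], perm[i]
    return np.transpose(x, perm)

# Hypothetical stand-ins: a 3-D "pair_act" and a 2-D "pair_mask".
pair_act = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # shape (2, 3, 4)
pair_mask = np.arange(2 * 3).reshape(2, 3)        # shape (2, 3)

# per_column branch: for ndim == 3, axes -3 and -2 are axes 0 and 1.
act_t = np.swapaxes(pair_act, -2, -3)
# For the 2-D mask, axes -2 and -1 are axes 0 and 1, i.e. a transpose.
mask_t = np.swapaxes(pair_mask, -1, -2)

print(act_t.shape)   # (3, 2, 4)
print(mask_t.shape)  # (3, 2)

# Element [j, i, k] of the result is element [i, j, k] of the input.
assert act_t[2, 1, 3] == pair_act[1, 2, 3]
assert np.array_equal(mask_t, pair_mask.T)
assert np.array_equal(act_t, swapaxes_via_transpose(pair_act, -2, -3))
```

For an ndarray, NumPy's `swapaxes` returns a view, so no data is copied; only the shape and strides change. Reading the `per_row`/`per_column` names, swapping the first two axes plausibly lets the same downstream code handle columns by treating them as rows, but that is an inference from the names, not from code shown here.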
It looks like pair_act is a 3D array and pair_mask is a 2D array. What do the numbers -1, -2, and -3 mean? For 3D arrays, my initial thought was that axis 0 is the array, axis 1 is the column, and axis 2 is the row, so where does the minus sign come from? Any array examples would be appreciated. Thanks for the help.
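Negative axis numbers count from the end, just like negative list indices in Python: for an array with `ndim` dimensions, axis `-k` is axis `ndim - k`. So `-1` is always the last axis and `-2` the one before it, however many dimensions the array has. In NumPy's convention a 3-D array prints as a stack of 2-D matrices: axis 0 selects the matrix, axis 1 the row, and axis 2 the column. The sketch below makes this concrete with NumPy; the helper `normalize_axis` is hypothetical, written only to show the arithmetic.

```python
import numpy as np

def normalize_axis(axis, ndim):
    """Hypothetical helper: map a possibly negative axis to 0..ndim-1."""
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} out of range for ndim={ndim}")
    return axis + ndim if axis < 0 else axis

a3 = np.arange(24).reshape(2, 3, 4)  # 2 matrices, each 3 rows x 4 columns
a2 = np.arange(6).reshape(2, 3)      # 2 rows x 3 columns

print([normalize_axis(ax, a3.ndim) for ax in (-1, -2, -3)])  # [2, 1, 0]
print([normalize_axis(ax, a2.ndim) for ax in (-1, -2)])      # [1, 0]

# Summing over axis -1 collapses the last axis (the columns of each matrix).
print(a3.sum(axis=-1).shape)  # (2, 3)

# With an extra leading (e.g. batch) axis, -2 and -3 still name the
# two axes just before the last one.
b4 = np.zeros((5, 2, 3, 4))
print(np.swapaxes(b4, -2, -3).shape)  # (5, 3, 2, 4)
```

A practical reason to write `swapaxes(x, -2, -3)` rather than `swapaxes(x, 0, 1)` is that the negative form keeps referring to the same two axes if leading dimensions are added, as the 4-D example shows.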
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
