# Demystifying the Pytorch Memory Model: reshape(), permute(), contiguous() and More


Have you ever wondered what side effects you might get from casting magic spells like `permute()` on your tensors?

## The Flexibility of Tensors

With the growing popularity of autograd frameworks (such as Pytorch, TensorFlow, MXNet, etc.) among researchers and practitioners, it’s not uncommon to see people build their ever-progressive models and pipelines on top of tons of tensor flipping: reshaping, swapping axes, adding new axes, and so on. A modest example of chaining tensor operations might look like:

```
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
```

while some carried-away examples go to extremes, like:

```
x.permute(...).reshape(...).unsqueeze(...).permute(...)
```

Although manipulating tensors this way is totally legit, it does hurt readability. What’s more, since the framework abstracts away the underlying data allocation for us by providing N-D array interfaces (e.g., `torch.Tensor`), if we don’t understand how exactly the data is being manipulated, we might unknowingly end up with sub-optimal performance even after some clever vectorization. Sometimes the code will even break in a mysterious way, for example:

```
>>> y = x.view(-1)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
```
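A minimal sketch of how this error can arise (don’t worry about the details yet; `permute()` is the culprit, and we’ll unpack why below):

```
import torch

x = torch.rand(3, 4).permute(1, 0)  # a transposed view: same underlying data, different strides
assert not x.is_contiguous()
# x.view(-1)                        # would raise the RuntimeError shown above
y = x.reshape(-1)                   # works: reshape() falls back to copying the data
```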

Recently I’ve been making extensive use of detectron2, a popular research framework on top of Pytorch maintained by Facebook AI Research (FAIR). I found this interesting line in `detectron2.data.dataset_mapper.DatasetMapper`:

```
dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
```

Again, this is just another such line, but one that makes good use of the Pytorch memory model to squeeze out performance. If you find it a bit daunting but are curious about why it’s written this way, as I was, read on!

## Contiguous vs. Non-Contiguous Tensors

### `torch.Storage`

Although tensors in Pytorch are generic N-D arrays, under the hood they use `torch.Storage`, a 1-D array structure that stores the data with all elements laid out next to each other. Each tensor has an associated `torch.Storage`, and we can view the storage by calling `.storage()` directly on the tensor, for example:

```
>>> x = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> x
tensor([[1, 2, 3],
        [4, 5, 6]])
>>> x.storage()
1
2
3
4
5
6
[torch.LongStorage of size 6]
```
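Note that the shape lives in the tensor, not in the storage: a tensor holding the same data in a different shape has an identical flat storage. A small sketch reusing the `x` from above:

```
>>> z = torch.tensor([1, 2, 3, 4, 5, 6])
>>> list(z.storage()) == list(x.storage())  # same 1-D data, different shapes
True
```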

### `x.stride()`

Then how is the 1-D data mapped to the N-D array and vice versa? Simple. Each tensor has a `.stride()` method that returns a tuple of integers specifying how many jumps to make along the 1-D storage in order to reach the next element along each of the N axes of the N-D array. In the case of `x`, its strides are `(3, 1)`: from `x[0, 1]` to `x[0, 2]` we only need 1 jump (from `2` to `3`), but from `x[0, 1]` to `x[1, 1]` we need 3 jumps (from `2` to `5`).

```
>>> x.stride()
(3, 1)
```
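We can verify the mapping directly: the position of `x[i, j]` within the storage is the dot product of the indices with the strides (plus `x.storage_offset()`, which is 0 here). A small sketch reusing the same `x`:

```
>>> i, j = 1, 1
>>> x.stride(0) * i + x.stride(1) * j  # offset into the 1-D storage
4
>>> x.storage()[4]
5
>>> x[i, j]
tensor(5)
```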

As another example, creating an N-D random tensor with `torch.rand` gives the tensor strides that look like:

```
>>> y = torch.rand(3, 4, 5)
>>> y.stride()
(20, 5, 1)
>>> y.stride(0)
20
```

with the k-th stride being the product of the sizes of all dimensions that come after the k-th axis, e.g., `y.stride(0) == y.size(1) * y.size(2)`, `y.stride(1) == y.size(2)`, `y.stride(2) == 1`. Mathematically, we have \(S_k = \prod_{i = k + 1}^{N} D_i\). In other words, when unrolling the tensor from the last axis, going from right to left, its elements fall onto the 1-D storage one by one. This feels natural, since the strides are determined solely by the sizes of the axes. In fact, this is the definition of being “contiguous”.
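To make the formula concrete, here is a small sketch (reusing the `y` from above; `math.prod` needs Python 3.8+) that recomputes the strides from the shape and compares them with what Pytorch reports:

```
>>> import math
>>> shape = y.size()  # torch.Size([3, 4, 5])
>>> strides = tuple(math.prod(shape[k + 1:]) for k in range(len(shape)))
>>> strides           # S_k is the product of the sizes after axis k
(20, 5, 1)
>>> strides == y.stride()
True
```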

### `x.is_contiguous()`

By definition, a contiguous array (or more precisely, a C-contiguous one) is one whose 1-D data representation corresponds to unrolling itself starting from the last (innermost) axis; in Pytorch, that axis corresponds to the rightmost index. Put mathematically, for an N-D array `X` with shape `(D_1, D_2, ..., D_N)` and its associated 1-D representation `X_flat`, the elements are laid out (with indices starting from 0) such that

\[X[i_1, i_2, \ldots, i_N] = X_{flat}\Big[\sum_{k=1}^{N} i_k S_k\Big], \quad \text{where } S_k = \prod_{j = k + 1}^{N} D_j.\]
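We can check this with `.is_contiguous()`. A small sketch reusing the `x` from before, with a transposed view thrown in as a contrast:

```
>>> x.is_contiguous()
True
>>> x.t().stride()  # the transpose reuses the same storage; only the strides change
(1, 3)
>>> x.t().is_contiguous()
False
```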

To visualize the layout, given a contiguous array `x = torch.arange(12).reshape(3, 4)`, it should look like this (picture credit to Alex):