Demystifying the PyTorch Memory Model: reshape(), permute(), contiguous() and More

(in editing.)

Ever wondered what side effects you might get from magic spells like permute() on your tensors?

The Flexibility of Tensors

With the growing popularity of autograd frameworks (such as PyTorch, TensorFlow, MXNet, etc.) among researchers and practitioners, it’s not uncommon to see people build ever more complex models and pipelines out of tons of tensor manipulations, e.g., reshaping, switching axes, and adding new axes. A modest example of chained tensor operations might look like:

bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
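To make the shapes concrete, here is a small sketch that assumes bbox_pred comes from a detection head with layout (N, A * 4, H, W); the sizes N, A, H, W (batch, anchors per location, feature-map height and width) are made up for illustration. It shows that the permute step alone moves no data, while the final reshape has to copy.

```python
import torch

# Hypothetical sizes: batch N, A anchors per location, an H x W feature map.
N, A, H, W = 2, 3, 5, 7
# Assumed layout of a detection head's box output: (N, A * 4, H, W).
bbox_pred = torch.randn(N, A * 4, H, W)

permuted = bbox_pred.permute(0, 2, 3, 1)  # (N, H, W, A * 4): a view, only strides change
print(permuted.shape, permuted.is_contiguous())  # torch.Size([2, 5, 7, 12]) False
print(permuted.data_ptr() == bbox_pred.data_ptr())  # True: same memory

boxes = permuted.reshape(-1, 4)  # (N * H * W * A, 4): one row of 4 values per anchor
print(boxes.shape)  # torch.Size([210, 4])

# permuted cannot be viewed as (-1, 4), so reshape() copied the data.
print(boxes.data_ptr() == bbox_pred.data_ptr())  # False
```

Under the assumed layout, each row of boxes holds the 4 consecutive channels belonging to one anchor at one spatial location.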

while some carried-away examples might go to the extreme like:


Although manipulating tensors this way is totally legit, it does hurt readability. What’s more, since the framework abstracts away the underlying data allocation by exposing N-D array interfaces (e.g., torch.Tensor), we might unknowingly end up with sub-optimal performance, even after coming up with some clever vectorization, if we don’t understand exactly how the data is being manipulated. Sometimes the code even breaks in mysterious ways, for example:

>>> import torch
>>> x = torch.arange(6).reshape(2, 3).t()  # e.g., a transposed tensor
>>> y = x.view(-1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
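The error message itself hints at the fix. A minimal sketch, assuming x is a transposed and therefore non-contiguous tensor, reproduces the failure and compares the two usual workarounds, .reshape() and .contiguous().view():

```python
import torch

# Assumed setup: a transposed tensor, which is not contiguous.
x = torch.arange(6).reshape(2, 3).t()  # shape (3, 2), strides (1, 3)
print(x.is_contiguous())  # False

try:
    x.view(-1)
    view_ok = True
except RuntimeError as err:
    view_ok = False
    print("view() failed:", err)

y1 = x.reshape(-1)            # returns a view when possible, otherwise a copy
y2 = x.contiguous().view(-1)  # explicit copy into contiguous memory, then a view
print(y1.tolist(), torch.equal(y1, y2))  # [0, 3, 1, 4, 2, 5] True
```

Both workarounds produce the same values; the difference is that .reshape() decides on its own whether a copy is needed, while .contiguous() makes the copy explicit.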

Recently I’ve been making extensive use of detectron2, a popular research framework built on top of PyTorch and maintained by Facebook AI Research (FAIR). I found this interesting line in its codebase:

dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
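A small sketch of what this line does, assuming image is an H x W x C uint8 NumPy array as returned by common image loaders (the 4 x 6 x 3 random image here is made up): the transpose yields a non-contiguous CHW view, np.ascontiguousarray pays for one copy, and torch.as_tensor then wraps that buffer without copying again.

```python
import numpy as np
import torch

# Hypothetical HWC uint8 image standing in for a decoded picture.
image = np.random.randint(0, 256, size=(4, 6, 3), dtype=np.uint8)

chw_view = image.transpose(2, 0, 1)     # (3, 4, 6) view: no copy, strides permuted
print(chw_view.flags["C_CONTIGUOUS"])   # False

chw = np.ascontiguousarray(chw_view)    # one copy into C-contiguous memory
print(chw.flags["C_CONTIGUOUS"])        # True

tensor = torch.as_tensor(chw)           # shares memory with chw, no further copy
print(tensor.shape, tensor.is_contiguous())   # torch.Size([3, 4, 6]) True
print(np.shares_memory(tensor.numpy(), chw))  # True
```

One plausible reading of the design is that it pays a single explicit copy up front, so that every downstream operation on dataset_dict["image"] sees a contiguous tensor.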

This is another example of chained tensor operations, but this time one that makes good use of PyTorch’s memory model to squeeze out performance. If you find this line a bit daunting but are curious about why it’s written this way, as I was, read on!

Contiguous vs. Non-Contiguous Tensors


Although tensors in PyTorch are generic N-D arrays, under the hood they use torch.Storage, a 1-D array structure that stores the data with the elements laid out next to each other. Each tensor has an associated torch.Storage, and we can view it by calling .storage() directly on the tensor, for example:

>>> x = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> x
tensor([[1, 2, 3],
        [4, 5, 6]])
>>> x.storage()
 1
 2
 3
 4
 5
 6
[torch.LongStorage of size 6]
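Recent PyTorch releases may warn that the typed storage returned by .storage() is deprecated, so the sketch below makes the same point without it: a flat view of x reuses x’s 1-D storage, so both report the same data_ptr(), and a write through one is visible through the other.

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
flat = x.view(-1)  # a new tensor object over the same 1-D storage

print(flat.tolist())                    # [1, 2, 3, 4, 5, 6]
print(flat.data_ptr() == x.data_ptr())  # True: same underlying buffer

flat[4] = 50       # write through the 1-D view (position 4 is x[1, 1]) ...
print(x.tolist())  # [[1, 2, 3], [4, 50, 6]] ... and x sees the change
```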


Then how is the 1-D data mapped to the N-D array and vice versa? Simple. Each tensor has a .stride() method that returns a tuple of integers specifying, for each of the N axes, how many jumps to make along the 1-D storage to reach the next element along that axis. In the case of x, the strides should be (3, 1): going from x[0, 1] to x[0, 2] takes only 1 jump (from 2 to 3), whereas going from x[0, 1] to x[1, 1] takes 3 jumps (from 2 to 5).

>>> x.stride()
(3, 1)
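The stride rule can be checked mechanically. This sketch uses a hypothetical helper, flat_index, that sums index × stride over the axes; since x is contiguous and starts at storage offset 0, x.view(-1) stands in for its 1-D storage.

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# x is contiguous and starts at storage offset 0, so this 1-D view
# lists the elements in storage order.
storage_order = x.view(-1)

def flat_index(index, strides):
    """Position of an N-D index in the 1-D storage: sum of index * stride."""
    return sum(i * s for i, s in zip(index, strides))

for i in range(x.size(0)):
    for j in range(x.size(1)):
        assert storage_order[flat_index((i, j), x.stride())] == x[i, j]

print(flat_index((0, 2), x.stride()) - flat_index((0, 1), x.stride()))  # 1
print(flat_index((1, 1), x.stride()) - flat_index((0, 1), x.stride()))  # 3
```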

As another example, a 3-D random tensor created with torch.rand has strides that look like:

>>> y = torch.rand(3, 4, 5)
>>> y.stride()
(20, 5, 1)
>>> y.stride(0)
20

with the k-th stride being the product of the sizes of all axes after the k-th axis, e.g., y.stride(0) == y.size(1) * y.size(2), y.stride(1) == y.size(2), and y.stride(2) == 1. Mathematically, numbering the axes from 1 to N with sizes \(D_1, \ldots, D_N\), we have \(S_k = \prod_{i=k+1}^{N} D_i\), where the empty product for \(k = N\) equals 1. If we unroll the tensor with the last (rightmost) axis varying fastest, its elements fall onto the 1-D storage one after another. This feels natural, since the strides are determined by the shape alone. In fact, this is exactly what it means for a tensor to be “contiguous”.
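A sketch of the stride formula, using a hypothetical helper contiguous_strides that computes \(S_k\) from the shape alone (math.prod needs Python 3.8 or later) and compares the result with what torch.rand actually produced:

```python
import math
import torch

def contiguous_strides(shape):
    """C-contiguous strides: S_k is the product of the sizes of all later axes."""
    return tuple(math.prod(shape[k + 1:]) for k in range(len(shape)))

y = torch.rand(3, 4, 5)
print(contiguous_strides(y.shape), y.stride())  # (20, 5, 1) (20, 5, 1)
assert contiguous_strides(y.shape) == y.stride()
print(y.is_contiguous())  # True
```

For the last axis the slice of later sizes is empty, and math.prod of an empty sequence is 1, which matches the convention that the innermost stride is 1.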


By definition, a contiguous array (or, more precisely, a C-contiguous one) is one whose 1-D data representation corresponds to unrolling it starting from the innermost axis, i.e., with the last axis varying fastest. In PyTorch the innermost axis corresponds to the rightmost index. Put mathematically, for an N-D array X with shape \((D_1, D_2, \ldots, D_N)\) and its associated 1-D representation \(X_{flat}\), the elements are laid out such that

\[X[k_1, k_2, \ldots, k_N] = X_{flat}\left[k_1 \times \prod_{i=2}^{N} D_i + k_2 \times \prod_{i=3}^{N} D_i + \cdots + k_N\right] = X_{flat}\left[\sum_{j=1}^{N} k_j S_j\right]\]
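The layout formula can also be checked by brute force. Filling a contiguous tensor with torch.arange makes every element equal to its own position in \(X_{flat}\), so a hypothetical flat_position helper that evaluates the right-hand side term by term (with 0-based \(k_j\)) should reproduce each value. The shape D = (2, 3, 4) is an arbitrary choice.

```python
import itertools
import math
import torch

D = (2, 3, 4)
X = torch.arange(math.prod(D)).reshape(D)  # each value equals its position in X_flat

def flat_position(k, dims):
    """k_1 * prod(D_2..D_N) + k_2 * prod(D_3..D_N) + ... + k_N, with 0-based k_j."""
    return sum(k_j * math.prod(dims[j + 1:]) for j, k_j in enumerate(k))

# Visit every multi-index (k_1, k_2, k_3) and compare with the stored value.
for k in itertools.product(*(range(d) for d in D)):
    assert X[k].item() == flat_position(k, D)

print(flat_position((1, 2, 3), D))  # 23 = 1*12 + 2*4 + 3
```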

To visualize the layout, the contiguous array x = torch.arange(12).reshape(3, 4) looks like this (picture credit to Alex):

Transposing is the Devil

Reshaping Does No “Harm”