PyTorch Tensor Reshaping

Tensor reshaping is one of the most frequently used operations in data preparation and model training. PyTorch provides built-in functions for reshaping tensors. In this chapter of the PyTorch Tutorial, you will learn about tensor reshaping in PyTorch.

view()

view(*shape), when called on a tensor, returns a view of the original tensor with the required shape. However, the number of elements in the requested view must be equal to the number of elements in the original tensor.

The returned tensor has the required shape but shares its data with the original tensor; view() never returns a copy of the data. Any change to the data of the original tensor is therefore reflected in every view of that tensor.

Example

import torch

tensor_1 = torch.tensor([[1, 2], [3, 4]])

# Create tensor_2 as a view of tensor_1
tensor_2 = tensor_1.view(4, 1)

print(tensor_2)
# Outputs- tensor([[1], [2], [3], [4]])

# Changing data in the original tensor
tensor_1[0][0] = -1

print(tensor_2)
# Outputs- tensor([[-1], [2], [3], [4]])

Note: Changing the data of the original tensor also changes the data of the new tensor created with view(). Hence, tensor_2 shares its underlying data with the original tensor tensor_1.
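
As a quick aside, view() also accepts -1 for at most one dimension, in which case PyTorch infers that dimension's size from the total element count, while a shape whose element count does not match raises an error. A minimal sketch (the exact error text may vary across PyTorch versions):

import torch

tensor_1 = torch.tensor([[1, 2], [3, 4]])

# Passing -1 lets view() infer that dimension's size from the element count
print(tensor_1.view(-1))
# Outputs- tensor([1, 2, 3, 4])

# A shape whose element count does not match raises an error
# tensor_1.view(3, 1)
# Raises- RuntimeError: shape '[3, 1]' is invalid for input of size 4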

Example

Calling view() on a non-contiguous tensor raises an error.

tensor_1 = torch.tensor([[1, 2], [3, 4]])

# Storing the transpose of tensor_1 in a new tensor
# This makes the new tensor (tensor_2) non-contiguous
tensor_2 = tensor_1.t()

print(tensor_2.view(4, 1))
# Outputs- RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

Example

You can create a contiguous copy of a non-contiguous tensor using the contiguous() function, and then call view() on the resulting tensor.

tensor_1 = torch.tensor([[1, 2], [3, 4]])

# Storing the transpose of tensor_1 in a new tensor
# This makes the new tensor (tensor_2) non-contiguous
tensor_2 = tensor_1.t()

# tensor_contg is a contiguous tensor created from tensor_2 
tensor_contg = tensor_2.contiguous()

print(tensor_contg.view(4, 1))
# Outputs- tensor([[1], [3], [2], [4]])
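
If you are unsure whether a tensor is contiguous before calling view(), you can check it with the is_contiguous() method, as in this short sketch:

import torch

tensor_1 = torch.tensor([[1, 2], [3, 4]])
tensor_2 = tensor_1.t()

print(tensor_1.is_contiguous())
# Outputs- True

print(tensor_2.is_contiguous())
# Outputs- False

print(tensor_2.contiguous().is_contiguous())
# Outputs- True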

torch.reshape()

torch.reshape(input, shape) returns a tensor with the same data as input but reshaped to the required shape. As with view(), the number of elements in the new tensor has to be the same as that of the original tensor.

reshape() returns a view of the original tensor whenever the tensor is contiguous (or, more generally, whenever the new shape is compatible with the tensor's existing strides). Otherwise, it creates and returns a copy of the data with the required shape. Hence, reshape() may return either a view of the original tensor or a copy of it.

reshape() can also be called directly on the tensor.

Example

tensor_1 = torch.tensor([[1, 2], [3, 4]])

# Create tensor_2 from tensor_1 using reshape() function
tensor_2 = tensor_1.reshape(4, 1)

print(tensor_2)
# Outputs- tensor([[1], [2], [3], [4]])

# Changing data in the original tensor
tensor_1[0][0] = -1

print(tensor_2)
# Outputs- tensor([[-1], [2], [3], [4]])

Note: Because tensor_1 is contiguous, reshape() returns a view, so the change to tensor_1 is reflected in tensor_2.

Example

Calling reshape() on a non-contiguous tensor returns a copy of the data in a new tensor.

tensor_1 = torch.tensor([[1, 2], [3, 4]])

# Storing the transpose of tensor_1 in a new tensor
# This makes the new tensor (tensor_2) non-contiguous
tensor_2 = tensor_1.t()

# Calling the reshape() function on a non-contiguous tensor 
tensor_3 = tensor_2.reshape(4, 1)

# Changing data in the original tensor
tensor_1[0][0] = -1

print(tensor_1)
# Outputs- tensor([[-1,  2], [3,  4]])

print(tensor_3)
# Outputs- tensor([[1], [3], [2], [4]])

Note: Calling reshape() on the non-contiguous tensor tensor_2 creates a new tensor tensor_3. tensor_3 is not a view of tensor_1 and therefore does not share data with it, so changing the data in the original tensor tensor_1 leaves the data of tensor_3 unchanged.
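
If you want to verify in your own code whether reshape() returned a view or a copy, one simple check (sketched below) is to compare the underlying data pointers with data_ptr():

import torch

tensor_1 = torch.tensor([[1, 2], [3, 4]])

# Contiguous input: reshape() returns a view that shares storage with tensor_1
view_result = tensor_1.reshape(4, 1)
print(view_result.data_ptr() == tensor_1.data_ptr())
# Outputs- True

# Non-contiguous input: reshape() returns a copy with its own storage
copy_result = tensor_1.t().reshape(4, 1)
print(copy_result.data_ptr() == tensor_1.data_ptr())
# Outputs- False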


Difference between reshape() and view()

Both view() and reshape() return a tensor of the desired shape whenever that is possible, and both raise an error when the requested shape is incompatible with the number of elements in the tensor. There are, however, a few differences between the two functions, which are compared in the table below.

view() | reshape()
Returns a view of the original tensor. | Returns a view if the tensor is contiguous; otherwise returns a copy of the tensor.
Always shares the same underlying data with the original tensor. | Shares the same underlying data with the original tensor only when a view is returned; otherwise the data is stored in a new tensor.
Raises an error if the original tensor is not contiguous. | If the original tensor is not contiguous, creates and returns a contiguous copy with the required shape.
Can only be called directly on the tensor. | Can be called using torch.reshape() or directly on the tensor.
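
As a small illustration of the last row of the table, reshape() can be called either as torch.reshape() or as a tensor method, while view() is available only as a tensor method; a minimal sketch:

import torch

tensor_1 = torch.tensor([[1, 2], [3, 4]])

# reshape() as a module-level function
print(torch.reshape(tensor_1, (4, 1)))

# reshape() as a tensor method
print(tensor_1.reshape(4, 1))

# view() is only available as a tensor method
print(tensor_1.view(4, 1))

# All three calls print tensor([[1], [2], [3], [4]])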

torch.permute()

permute(*dims) is used to rearrange the dimensions of a tensor. It is quite useful for rearranging the dimensions of a tensor before feeding it to a network, for example, to rearrange a tensor storing an image in the form [height, width, channel] into [channel, height, width] before passing it to a neural network.

You pass the indices of the tensor's dimensions as arguments to permute(), in the order in which you want them to appear in the new tensor.

Example

# Creating a random tensor to denote a 299*299 RGB image
img_tensor = torch.rand(299, 299, 3)

print(img_tensor.shape)
# Outputs- torch.Size([299, 299, 3])

# Re-arranging the tensor from [height, width, channel] to [channel, height, width]
rearranged_img_tensor = img_tensor.permute(2, 0, 1)

print(rearranged_img_tensor.shape)
# Outputs- torch.Size([3, 299, 299])

We wanted the last dimension (the channel dimension) to be the first dimension in the new tensor, so we passed the index of the last dimension, 2, as the first argument to permute(). Similarly, we wanted the height and width dimensions to come second and third, so we passed their indices, 0 and 1, in that order to permute().
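
Note that permute() also returns a view with rearranged strides, so the result is usually non-contiguous. If you want to flatten it afterwards, call contiguous() first (or use reshape()), as in this short sketch:

import torch

img_tensor = torch.rand(299, 299, 3)
rearranged_img_tensor = img_tensor.permute(2, 0, 1)

print(rearranged_img_tensor.is_contiguous())
# Outputs- False

# view() would fail here, so make the tensor contiguous first (or use reshape())
flat_tensor = rearranged_img_tensor.contiguous().view(3, -1)

print(flat_tensor.shape)
# Outputs- torch.Size([3, 89401])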