
How to squeeze and unsqueeze a tensor in PyTorch?


In this article, we will learn how to squeeze and unsqueeze a PyTorch tensor.

To squeeze a tensor, we use the torch.squeeze() method, and to unsqueeze a tensor, we use the torch.unsqueeze() method. Let's look at these methods in detail.

Squeeze a Tensor:

When we squeeze a tensor, its dimensions of size 1 are removed. The elements of the original tensor keep their order and are arranged along the remaining dimensions; only the shape changes. For example, if the input tensor has shape (m×1×n×1), then the output tensor after squeezing has shape (m×n). The syntax of the torch.squeeze() method is shown below.
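The (m×1×n×1) → (m×n) case can be checked with a small sketch. It assumes m = 2 and n = 3 purely for illustration. It also shows that squeezing does not reorder the elements. Finally, it confirms that the result shares storage with the input, which the PyTorch documentation states for torch.squeeze().

```python
import torch

# A 4-D tensor of shape (m, 1, n, 1) with m = 2 and n = 3
x = torch.arange(6).reshape(2, 1, 3, 1)
y = torch.squeeze(x)

print(x.shape)  # torch.Size([2, 1, 3, 1])
print(y.shape)  # torch.Size([2, 3])

# The elements keep their order: flattening both tensors gives the same sequence
print(torch.equal(x.flatten(), y.flatten()))  # True

# The squeezed tensor is a view: it points at the same underlying memory
print(y.data_ptr() == x.data_ptr())  # True
```

Because the output shares storage with the input, modifying one tensor in place also changes the other.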

Syntax: torch.squeeze(input, dim=None, *, out=None)

Parameters:

  • input: the input tensor.
  • dim: an optional integer value. If given, the input is squeezed only in this dimension, and only if that dimension has size 1.
  • out: the output tensor, an optional key argument.

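Return: If dim is not given, it returns a tensor with all the dimensions of size 1 of the input tensor removed. If dim is given, only that dimension is removed, provided it has size 1; otherwise the shape is unchanged.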

Note that we can squeeze the input tensor in a particular dimension dim. In this case, the other dimensions of size 1 remain unchanged. Example 2 illustrates this in more detail.
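Sometimes we want to squeeze several specific dimensions while leaving other size-1 dimensions alone. A minimal sketch of one way to do this is to call torch.squeeze() once per dimension, going from the highest index to the lowest. That order matters because removing a dimension never shifts the dimensions before it. The helper squeeze_dims below is hypothetical and not part of PyTorch. Recent PyTorch versions (2.0 and later, per the PyTorch docs) also accept a tuple for dim, which makes this helper unnecessary there.

```python
import torch

def squeeze_dims(x, dims):
    """Hypothetical helper: squeeze only the listed dimensions of x.

    Negative indices are normalized against the original number of
    dimensions. Dimensions are then processed from highest to lowest index,
    so removing one does not shift the index of any dimension still pending.
    Out-of-range indices are passed through, so torch.squeeze raises for them.
    """
    normalized = sorted({d + x.dim() if d < 0 else d for d in dims}, reverse=True)
    for d in normalized:
        x = torch.squeeze(x, dim=d)
    return x

x = torch.randn(1, 3, 1, 2, 1)
print(squeeze_dims(x, [0, 4]).shape)   # torch.Size([3, 1, 2])
print(squeeze_dims(x, [-1, 2]).shape)  # torch.Size([1, 3, 2])
```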

Example 1:

In the example below, we squeeze a 5-D tensor using the torch.squeeze() method. The input tensor has two dimensions of size 1.

Python3




# Python program to squeeze the tensor
# importing torch
import torch
 
# creating the input tensor
input = torch.randn(3,1,2,1,4)
# print the input tensor
print("Input tensor Size:\n",input.size())
 
# squeeze the tensor
output = torch.squeeze(input)
# print the squeezed tensor
print("Size after squeeze:\n",output.size())


Output:

Input tensor Size:
 torch.Size([3, 1, 2, 1, 4])
Size after squeeze:
 torch.Size([3, 2, 4])

Notice that both dimensions of size 1 are removed in the squeezed tensor.
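Because squeeze() without dim removes every size-1 dimension, it can also remove a dimension we wanted to keep. A typical case is a batch dimension that happens to be 1. The sketch below assumes a model-style output of shape (batch_size, 1), with batch_size chosen purely for illustration. It compares torch.squeeze() with and without dim=1.

```python
import torch

# Output of shape (batch_size, 1), tried with a normal batch and a batch of one
for batch_size in (4, 1):
    preds = torch.randn(batch_size, 1)
    print(batch_size,
          torch.squeeze(preds).shape,         # drops the batch dim when batch_size == 1
          torch.squeeze(preds, dim=1).shape)  # always gives shape (batch_size,)
```

With batch_size = 1, the plain squeeze() returns a 0-D tensor (torch.Size([])). Passing dim=1 keeps the result 1-D in both cases.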

Example 2:

In this example, we squeeze the tensor along different dimensions.

Python3




# Python program to squeeze the tensor in
# different dimensions
 
# importing torch
import torch
# creating the input tensor
input = torch.randn(3,1,2,1,4)
print("Dimension of input tensor:", input.dim())
print("Input tensor Size:\n",input.size())
 
# squeeze the tensor in dimension 0
output = torch.squeeze(input,dim=0)
print("Size after squeeze with dim=0:\n",
      output.size())
 
# squeeze the tensor in dimension 1
output = torch.squeeze(input,dim=1)
print("Size after squeeze with dim=1:\n",
      output.size())
 
# squeeze the tensor in dimension 2
output = torch.squeeze(input,dim=2)
print("Size after squeeze with dim=2:\n",
      output.size())
 
# squeeze the tensor in dimension 3
output = torch.squeeze(input,dim=3)
print("Size after squeeze with dim=3:\n",
      output.size())
 
# squeeze the tensor in dimension 4
output = torch.squeeze(input,dim=4)
print("Size after squeeze with dim=4:\n",
      output.size())
# output = torch.squeeze(input, dim=5)  # IndexError: dim 5 is out of range for a 5-D tensor


Output:

Dimension of input tensor: 5
Input tensor Size:
 torch.Size([3, 1, 2, 1, 4])
Size after squeeze with dim=0:
 torch.Size([3, 1, 2, 1, 4])
Size after squeeze with dim=1:
 torch.Size([3, 2, 1, 4])
Size after squeeze with dim=2:
 torch.Size([3, 1, 2, 1, 4])
Size after squeeze with dim=3:
 torch.Size([3, 1, 2, 4])
Size after squeeze with dim=4:
 torch.Size([3, 1, 2, 1, 4])

Notice that when we squeeze the tensor in dimension 0, the shape of the output tensor does not change, because that dimension has size 3. When we squeeze in dimension 1 or dimension 3 (both of size 1), only that dimension is removed from the output tensor. When we squeeze in dimension 2 or dimension 4 (of sizes 2 and 4), the shape again does not change. Squeezing in dimension 5 raises an error, because a 5-D tensor only has dimensions 0 to 4.
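The dim argument of torch.squeeze() also accepts negative values, which count from the last dimension. The sketch below reuses the shape from Example 2. It checks that dim=-2 and dim=-4 refer to dimensions 3 and 1. It also shows that the tensor method x.squeeze(dim) gives the same result as torch.squeeze(x, dim).

```python
import torch

x = torch.randn(3, 1, 2, 1, 4)

# For a 5-D tensor, dim=-2 is the same axis as dim=3, and dim=-4 is dim=1
print(torch.squeeze(x, dim=-2).shape)  # torch.Size([3, 1, 2, 4])
print(torch.squeeze(x, dim=-4).shape)  # torch.Size([3, 2, 1, 4])

# The method form is equivalent to the function form
print(torch.equal(x.squeeze(-2), torch.squeeze(x, dim=3)))  # True
```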

Unsqueeze a Tensor:

When we unsqueeze a tensor, a new dimension of size 1 is inserted at the specified position. An unsqueeze operation therefore always increases the number of dimensions of the output tensor by one. For example, if the input tensor has shape (m×n) and we insert a new dimension at position 1, then the output tensor after unsqueezing has shape (m×1×n). The syntax of the torch.unsqueeze() method is shown below.
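The (m×n) → (m×1×n) case can be sketched as follows, assuming m = 2 and n = 3 for illustration. It compares torch.unsqueeze() with two common equivalents: the tensor method x.unsqueeze(), and indexing with None, which also inserts a size-1 dimension.

```python
import torch

m, n = 2, 3
x = torch.arange(m * n).reshape(m, n)

# Three ways to insert a new size-1 dimension at position 1
a = torch.unsqueeze(x, dim=1)
b = x.unsqueeze(1)
c = x[:, None, :]  # None in an index inserts a dimension of size 1

print(a.shape)                                  # torch.Size([2, 1, 3])
print(torch.equal(a, b) and torch.equal(a, c))  # True
```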

Syntax: torch.unsqueeze(input, dim)

Parameters:

  • input: the input tensor.
  • dim: an integer value, the index at which the singleton dimension is inserted.

Return: It returns a new tensor with a dimension of size one inserted at the specified position dim. The returned tensor shares the same underlying data with the input tensor.

Note that the dim value must be chosen from the range [-input.dim() - 1, input.dim() + 1). A negative dim corresponds to dim = dim + input.dim() + 1.
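To make the valid range concrete, the sketch below walks through every allowed dim for a 2-D tensor. Here input.dim() is 2, so dim runs from -3 to 2. For each dim, it prints the equivalent non-negative position given by the formula above and the resulting shape. The last part assumes that an out-of-range dim raises IndexError, which is how PyTorch reports "dimension out of range".

```python
import torch

x = torch.zeros(2, 3)  # x.dim() == 2, so valid dims are -3 .. 2

for d in range(-x.dim() - 1, x.dim() + 1):
    # A negative d inserts at the same position as d + x.dim() + 1
    pos = d if d >= 0 else d + x.dim() + 1
    print(d, pos, tuple(torch.unsqueeze(x, d).shape))

# dim=3 lies outside [-3, 3), so PyTorch rejects it
try:
    torch.unsqueeze(x, 3)
except IndexError as err:
    print("dim=3 is out of range:", type(err).__name__)
```

Each negative dim produces the same shape as its non-negative counterpart. For example, -3 and 0 both give (1, 2, 3).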

Example 3:

In the example below, we unsqueeze a 1-D tensor into a 2-D tensor, first with dim=0 and then with dim=1.

Python3




# Python program to unsqueeze the input tensor
 
# importing torch
import torch
 
# define the input tensor
input = torch.arange(8, dtype=torch.float)
print("Input tensor:\n", input)
print("Size of input Tensor before unsqueeze:\n",
      input.size())
 
output = torch.unsqueeze(input, dim=0)
print("Tensor after unsqueeze with dim=0:\n", output)
print("Size after unsqueeze with dim=0:\n",
      output.size())
 
output = torch.unsqueeze(input, dim=1)
print("Tensor after unsqueeze with dim=1:\n", output)
print("Size after unsqueeze with dim=1:\n",
      output.size())


Output:

Input tensor:
 tensor([0., 1., 2., 3., 4., 5., 6., 7.])
Size of input Tensor before unsqueeze:
 torch.Size([8])
Tensor after unsqueeze with dim=0:
 tensor([[0., 1., 2., 3., 4., 5., 6., 7.]])
Size after unsqueeze with dim=0:
 torch.Size([1, 8])
Tensor after unsqueeze with dim=1:
 tensor([[0.],
        [1.],
        [2.],
        [3.],
        [4.],
        [5.],
        [6.],
        [7.]])
Size after unsqueeze with dim=1:
 torch.Size([8, 1])
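A common reason to unsqueeze is to line up tensors for broadcasting. The sketch below builds an outer product of two 1-D tensors by turning one into a column of shape (3, 1) and the other into a row of shape (1, 4). Their product then broadcasts to (3, 4). The sketch checks the result against torch.outer() and also checks that squeezing the same dimension undoes an unsqueeze.

```python
import torch

a = torch.arange(1, 4, dtype=torch.float)  # shape (3,)
b = torch.arange(1, 5, dtype=torch.float)  # shape (4,)

# (3, 1) * (1, 4) broadcasts to (3, 4): entry [i, j] is a[i] * b[j]
outer = torch.unsqueeze(a, dim=1) * torch.unsqueeze(b, dim=0)
print(outer.shape)                            # torch.Size([3, 4])
print(torch.equal(outer, torch.outer(a, b)))  # True

# Squeezing the inserted dimension restores the original tensor
print(torch.equal(torch.squeeze(torch.unsqueeze(a, 1), 1), a))  # True
```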


Last Updated: 23 May, 2022