How to squeeze and unsqueeze a tensor in PyTorch?
In this article, we will understand how to squeeze and unsqueeze a PyTorch Tensor.
To squeeze a tensor, we use the torch.squeeze() method, and to unsqueeze a tensor, we use the torch.unsqueeze() method. Let’s look at these methods in detail.
Squeeze a Tensor:
When we squeeze a tensor, its dimensions of size 1 are removed. The elements of the original tensor keep their order and are arranged over the remaining dimensions. For example, if the input tensor has shape (m×1×n×1), the output tensor after squeezing has shape (m×n). The following is the syntax of the torch.squeeze() method:
Syntax: torch.squeeze(input, dim=None, *, out=None)
Parameters:
- input: the input tensor.
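- dim: an optional integer. If given, the input is squeezed only in this dimension, and only if that dimension has size 1.
- out: the output tensor, an optional keyword argument.
Return: It returns a tensor with all dimensions of size 1 removed from the input tensor. If dim is given, only that dimension is removed, and only when its size is 1.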
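Here is a minimal sketch that checks the (m×1×n×1) to (m×n) example described above. It assumes m = 2 and n = 3, values chosen only for illustration. It also confirms that squeezing keeps the elements in their original order.

```python
import torch

# Hypothetical sizes m=2 and n=3 for the (m x 1 x n x 1) case
m, n = 2, 3
x = torch.arange(m * n, dtype=torch.float).reshape(m, 1, n, 1)
y = torch.squeeze(x)  # removes both size-1 dimensions

print(x.shape)  # torch.Size([2, 1, 3, 1])
print(y.shape)  # torch.Size([2, 3])

# Same elements in the same order; only the shape differs
print(torch.equal(x.flatten(), y.flatten()))  # True
```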
Note that we can squeeze the input tensor in a particular dimension dim. In this case, the other dimensions of size 1 remain unchanged. We discuss this in more detail in Example 2.
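The following minimal sketch squeezes a single dimension of a small (2×1×3×1) tensor, with sizes chosen for illustration. It also passes a negative dim, which counts from the last dimension, just as in ordinary Python indexing.

```python
import torch

x = torch.zeros(2, 1, 3, 1)

# Squeeze only dimension 1; the trailing size-1 dimension is kept
a = torch.squeeze(x, dim=1)
# Negative indices count from the end, so dim=-1 is the last dimension
b = torch.squeeze(x, dim=-1)

print(a.shape)  # torch.Size([2, 3, 1])
print(b.shape)  # torch.Size([2, 1, 3])
```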
Example 1:
In the example below, we squeeze a 5-D tensor using the torch.squeeze() method. The input tensor has two dimensions of size 1.
Python3
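import torch
# Random 5-D tensor with two dimensions of size 1 (dims 1 and 3)
input_tensor = torch.randn(3, 1, 2, 1, 4)
print("Input tensor Size:\n", input_tensor.size())
# Remove every dimension of size 1
output = torch.squeeze(input_tensor)
print("Size after squeeze:\n", output.size())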
Output:
Input tensor Size:
torch.Size([3, 1, 2, 1, 4])
Size after squeeze:
torch.Size([3, 2, 4])
Notice that both dimensions of size 1 are removed in the squeezed tensor.
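The same operation is also available as a tensor method. The sketch below uses illustrative variable names. It checks that x.squeeze() matches torch.squeeze(x). It also shows that the squeezed tensor shares storage with its input, as the PyTorch documentation notes, so writing through one changes the other.

```python
import torch

x = torch.randn(3, 1, 2, 1, 4)

# Tensor.squeeze() is the method form of torch.squeeze()
y = x.squeeze()
print(torch.equal(y, torch.squeeze(x)))  # True
print(y.shape)  # torch.Size([3, 2, 4])

# y is a view of x: writing through y is visible in x
y[0, 0, 0] = 100.0
print(x[0, 0, 0, 0, 0].item())  # 100.0
```

Because the result is a view, no data is copied. Call .clone() on it if you need an independent tensor.

Squeezing without dim also removes a batch dimension of size 1. This is a common reason to pass dim explicitly.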
Example 2:
In this example, we squeeze the same tensor along each of its dimensions in turn.
Python3
import torch
# Random 5-D tensor; only dimensions 1 and 3 have size 1
input_tensor = torch.randn(3, 1, 2, 1, 4)
print("Dimension of input tensor:", input_tensor.dim())
print("Input tensor Size:\n", input_tensor.size())
# Squeeze one dimension at a time; a dimension whose size is not 1 is left as is
output = torch.squeeze(input_tensor, dim=0)
print("Size after squeeze with dim=0:\n", output.size())
output = torch.squeeze(input_tensor, dim=1)
print("Size after squeeze with dim=1:\n", output.size())
output = torch.squeeze(input_tensor, dim=2)
print("Size after squeeze with dim=2:\n", output.size())
output = torch.squeeze(input_tensor, dim=3)
print("Size after squeeze with dim=3:\n", output.size())
output = torch.squeeze(input_tensor, dim=4)
print("Size after squeeze with dim=4:\n", output.size())
Output:
Dimension of input tensor: 5
Input tensor Size:
torch.Size([3, 1, 2, 1, 4])
Size after squeeze with dim=0:
torch.Size([3, 1, 2, 1, 4])
Size after squeeze with dim=1:
torch.Size([3, 2, 1, 4])
Size after squeeze with dim=2:
torch.Size([3, 1, 2, 1, 4])
Size after squeeze with dim=3:
torch.Size([3, 1, 2, 4])
Size after squeeze with dim=4:
torch.Size([3, 1, 2, 1, 4])
Notice that squeezing in dimension 0, 2, or 4 leaves the shape unchanged, because those dimensions have sizes 3, 2, and 4 rather than 1. Squeezing in dimension 1 or dimension 3 (both of size 1) removes only that dimension from the output tensor.
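We can check the rule from Example 2 for every dimension with a short loop. This sketch builds the expected shape by hand: it drops dimension d only if its size is 1. It then asserts that torch.squeeze produces the same shape.

```python
import torch

x = torch.randn(3, 1, 2, 1, 4)

for d in range(x.dim()):
    out = torch.squeeze(x, dim=d)
    # Expected shape: dimension d is dropped only when its size is 1
    expected = list(x.shape)
    if x.shape[d] == 1:
        del expected[d]
    assert list(out.shape) == expected
    print(f"dim={d}: {tuple(out.shape)}")
```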
Unsqueeze a Tensor:
When we unsqueeze a tensor, a new dimension of size 1 is inserted at the specified position. An unsqueeze operation therefore always increases the number of dimensions by one. For example, if the input tensor has shape (m×n) and we insert a new dimension at position 1, the output tensor after unsqueezing has shape (m×1×n). The following is the syntax of the torch.unsqueeze() method:
Syntax: torch.unsqueeze(input, dim)
Parameters:
- input: the input tensor.
- dim: an integer value, the index at which the singleton dimension is inserted.
Return: It returns a new tensor with a dimension of size one inserted at the specified position dim.
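Here is a minimal sketch of the (m×n) to (m×1×n) example described above. It assumes m = 2 and n = 3 for illustration.

```python
import torch

# Hypothetical sizes m=2 and n=3
m, n = 2, 3
x = torch.arange(m * n).reshape(m, n)
y = torch.unsqueeze(x, dim=1)  # insert a size-1 dimension at position 1

print(x.shape)  # torch.Size([2, 3])
print(y.shape)  # torch.Size([2, 1, 3])
print(y.dim() == x.dim() + 1)  # True
```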
Note that dim can take any value in the range [-input.dim() - 1, input.dim() + 1). A negative dim corresponds to dim = dim + input.dim() + 1.
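The sketch below makes the allowed range concrete using an 8-element 1-D tensor like the one in Example 3. For this tensor input.dim() is 1, so the valid range is [-2, 2). The sketch checks that each negative dim gives the same result as dim + input.dim() + 1. It also checks that a value outside the range is rejected. The exact exception type is an assumption here, so the sketch catches both IndexError and RuntimeError.

```python
import torch

x = torch.arange(8, dtype=torch.float)  # 1-D, so x.dim() == 1
n = x.dim()

# Every dim in [-n - 1, n + 1) is accepted; a negative dim acts like dim + n + 1
for d in range(-n - 1, n + 1):
    out = torch.unsqueeze(x, dim=d)
    equivalent = d + n + 1 if d < 0 else d
    assert torch.equal(out, torch.unsqueeze(x, dim=equivalent))
    print(d, tuple(out.shape))

# A dim outside the range is rejected (exception type assumed, see above)
rejected = False
try:
    torch.unsqueeze(x, dim=n + 1)
except (IndexError, RuntimeError):
    rejected = True
print("out-of-range dim rejected:", rejected)
```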
Example 3:
In the example below, we unsqueeze a 1-D tensor into a 2-D tensor.
Python3
import torch
# 1-D tensor holding 0., 1., ..., 7.
input_tensor = torch.arange(8, dtype=torch.float)
print("Input tensor:\n", input_tensor)
print("Size of input Tensor before unsqueeze:\n", input_tensor.size())
# dim=0 inserts the new dimension in front: shape (1, 8), a single row
output = torch.unsqueeze(input_tensor, dim=0)
print("Tensor after unsqueeze with dim=0:\n", output)
print("Size after unsqueeze with dim=0:\n", output.size())
# dim=1 inserts the new dimension at the end: shape (8, 1), a single column
output = torch.unsqueeze(input_tensor, dim=1)
print("Tensor after unsqueeze with dim=1:\n", output)
print("Size after unsqueeze with dim=1:\n", output.size())
Output:
Input tensor:
tensor([0., 1., 2., 3., 4., 5., 6., 7.])
Size of input Tensor before unsqueeze:
torch.Size([8])
Tensor after unsqueeze with dim=0:
tensor([[0., 1., 2., 3., 4., 5., 6., 7.]])
Size after unsqueeze with dim=0:
torch.Size([1, 8])
Tensor after unsqueeze with dim=1:
tensor([[0.],
        [1.],
        [2.],
        [3.],
        [4.],
        [5.],
        [6.],
        [7.]])
Size after unsqueeze with dim=1:
torch.Size([8, 1])
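Unsqueezing and then squeezing the same dimension undo each other. This short sketch reuses the input from Example 3 and confirms that squeezing the inserted dimension restores the original 1-D tensor.

```python
import torch

x = torch.arange(8, dtype=torch.float)

for d in (0, 1):
    y = torch.unsqueeze(x, dim=d)
    # Squeezing the same dimension removes the inserted size-1 dimension
    back = torch.squeeze(y, dim=d)
    print(d, tuple(y.shape), tuple(back.shape), torch.equal(back, x))
```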
Last Updated: 23 May, 2022