How to Make a Grid of Images in PyTorch?
In this article, we will see how to make a grid of images in PyTorch. We can make a grid of images using the make_grid() function from the torchvision.utils package.
make_grid() function:
The make_grid() function accepts a 4D tensor of shape [B, C, H, W], where B is the batch size, C is the number of channels, and H and W are the height and width, respectively. It also accepts a list of [C, H, W] image tensors. All images must have the same height and width. The function returns a single tensor that contains the grid of input images. We can also set the number of images displayed in each row with the nrow parameter. The syntax below is used to make a grid of images in PyTorch.
Syntax: torchvision.utils.make_grid(tensor, nrow=8, padding=2)
Parameter:
- tensor (Tensor or list) – 4D tensor of shape (B x C x H x W), or a list of images that all have the same size.
- nrow (int, optional) – Number of images displayed in each row of the grid. Default: 8.
- padding (int, optional) – Amount of padding, in pixels, around and between the images. Default: 2.
Returns: This function returns the tensor that contains a grid of input images.
Example 1:
The following example shows how to make a grid of four images in PyTorch.
Python3
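import torchvision
from torchvision.io import read_image
from torchvision.utils import make_grid
# Read each image as a uint8 tensor of shape [C, H, W];
# all images must have the same number of channels, height and width
a = read_image('a.jpg')
b = read_image('b.jpg')
c = read_image('c.jpg')
d = read_image('d.jpg')
# Build the grid; the default nrow=8 puts all four images in one row
grid = make_grid([a, b, c, d])
# Convert the [C, H, W] grid tensor to a PIL image and display it
img = torchvision.transforms.ToPILImage()(grid)
img.show()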
Output: the four images side by side in a single row.
Example 2:
In the following example, we make a grid of images and use the nrow parameter to set the number of images displayed in each row.
Python3
import torchvision
from torchvision.io import read_image
from torchvision.utils import make_grid
# Read six same-sized images as uint8 tensors of shape [C, H, W]
a = read_image('a.jpg')
b = read_image('b.jpg')
c = read_image('c.jpg')
d = read_image('d.jpg')
e = read_image('e.jpg')
f = read_image('f.jpg')
# nrow=3 puts three images in each row, so six images form two rows
grid = make_grid([a, b, c, d, e, f], nrow=3)
# Convert the grid tensor to a PIL image and display it
img = torchvision.transforms.ToPILImage()(grid)
img.show()
Output: the six images arranged in two rows of three.
Example 3:
In the following example, we make a grid of images and use the padding parameter to set the space between the images.
Python3
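import torchvision
from torchvision.io import read_image
from torchvision.utils import make_grid
# Read six same-sized images; PNG files may load with an alpha channel,
# so all six must also have the same number of channels
a = read_image('a.png')
b = read_image('b.png')
c = read_image('c.png')
d = read_image('d.png')
e = read_image('e.png')
f = read_image('f.png')
# padding=25 adds 25 pixels around and between the images (default: 2)
grid = make_grid([a, b, c, d, e, f], nrow=3, padding=25)
# Convert the grid tensor to a PIL image and display it
img = torchvision.transforms.ToPILImage()(grid)
img.show()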
Output: the six images in two rows of three, separated by 25 pixels of padding.
Last Updated: 03 Jun, 2022