
How Does torch.argmax Work for 4-Dimensions in PyTorch


In this article, we are going to discuss how torch.argmax works on 4-dimensional tensors, with detailed examples.

torch.argmax() Method

The torch.argmax() method accepts a tensor and returns the indices of the maximum values of the input tensor along a specified dimension (axis). If the input tensor contains multiple maximal values, the function returns the index of the first one. Let's look at the syntax of torch.argmax() along with its parameters.
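The tie-breaking rule can be checked with a small 1-D tensor whose maximum appears twice. This is a minimal sketch assuming a recent PyTorch version, where the documentation states that the index of the first maximal value is returned; the tensor values are arbitrary.

```python
import torch

# A 1-D tensor whose maximum value (7) appears twice, at positions 1 and 2
t = torch.tensor([3, 7, 7, 1])

# argmax reports the position of the first occurrence of the maximum
first_max = torch.argmax(t)
print(first_max)  # tensor(1)
```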

Syntax: torch.argmax(input, dim=None, keepdim=False)

Parameters

  • input is the tensor whose maximum-value indices we want to find.
  • dim is an integer that specifies the dimension to reduce. If it is not specified (None), the index of the maximum value in the flattened input is returned.
  • keepdim is a Boolean value that specifies whether the output tensor retains the reduced dimension (with size 1). It defaults to False.

PyTorch also accepts the NumPy-style aliases axis and keepdims for dim and keepdim.

Returns a tensor containing the indices of the maximum values of input along the given dimension.
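When dim is omitted, the returned index refers to the flattened tensor. The sketch below uses the same 2-D tensor as the example in the next section. The divmod step that turns the flat index back into a (row, column) pair is only an illustration for the 2-D, row-major case. It is not part of the argmax API.

```python
import torch

x = torch.tensor([[1, 10],
                  [20, 15]])

# With dim omitted, argmax works on the flattened tensor [1, 10, 20, 15]
flat_idx = torch.argmax(x)
print(flat_idx)  # tensor(2)

# Convert the flat index back to (row, column) using the number of columns
row, col = divmod(flat_idx.item(), x.shape[1])
print(row, col)  # 1 0
print(x[row, col])  # tensor(20)
```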

Working with argmax

In higher dimensions, torch.argmax() returns a tensor of indices of the maximum values along the specified axis. We can understand how it works by first looking at a 2-dimensional tensor.

Example: [[1,10], [20,15]]

A two-dimensional tensor has only two axes, 0 and 1 (rows and columns).

  • Along axis 0, argmax() looks for the maximum value in each column and returns the row index at which it occurs. In the first column, 20 is the maximum value and its index is 1; in the second column, 15 is the maximum value and its index is also 1. So argmax() returns [1, 1].
  • Along axis 1, argmax() looks for the maximum value in each row and returns the column index at which it occurs. In the first row, 10 is the maximum value and its index is 1; in the second row, 20 is the maximum value and its index is 0. So argmax() returns [1, 0].
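The two bullet points above can be reproduced directly. This short sketch runs argmax() on the same 2-D tensor along each axis.

```python
import torch

x = torch.tensor([[1, 10],
                  [20, 15]])

# Along dim 0: for each column, the row index of its largest value
col_argmax = torch.argmax(x, dim=0)
print(col_argmax)  # tensor([1, 1])

# Along dim 1: for each row, the column index of its largest value
row_argmax = torch.argmax(x, dim=1)
print(row_argmax)  # tensor([1, 0])
```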

How Does torch.argmax Work for 4-Dimensions

If we don't set keepdim=True, argmax() removes the dimension it reduces over. For a 4-dimensional input tensor with shape [1, 2, 3, 4], argmax() along axis 0 returns a tensor of shape [2, 3, 4], while along axis 1 it returns a tensor of shape [1, 3, 4]; the other axes behave the same way. In other words, by default argmax() collapses the chosen axis, because all the values along it are replaced by a single index.
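The resulting shapes can be listed for every axis of a [1, 2, 3, 4] tensor. In this sketch the shapes depend only on the input's shape and not on the random values, so the output is the same on every run.

```python
import torch

A = torch.randn(1, 2, 3, 4)

# Without keepdim, the reduced dimension disappears from the result
shapes = {d: tuple(torch.argmax(A, dim=d).shape) for d in range(A.dim())}
for d, s in shapes.items():
    print(f"dim={d}: {s}")
# dim=0: (2, 3, 4)
# dim=1: (1, 3, 4)
# dim=2: (1, 2, 4)
# dim=3: (1, 2, 3)
```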

If we set keepdim=True, argmax() does not remove that dimension; instead, it keeps it with size 1. For example, for a 4-D tensor with shape [1, 2, 3, 4], argmax() along axis 1 returns a tensor with shape [1, 1, 3, 4].
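The sketch below repeats the shape check with keepdim=True. It also shows that the indices themselves are unchanged: squeezing the size-1 axis gives exactly the keepdim=False result. The values are random, but the shapes and the equality do not depend on them.

```python
import torch

A = torch.randn(1, 2, 3, 4)

kept_shapes = {}
for d in range(A.dim()):
    kept = torch.argmax(A, dim=d, keepdim=True)
    dropped = torch.argmax(A, dim=d)
    # keepdim=True leaves a size-1 axis at position d ...
    kept_shapes[d] = tuple(kept.shape)
    # ... and squeezing that axis gives the keepdim=False result
    assert torch.equal(kept.squeeze(d), dropped)
    print(f"dim={d}: {kept_shapes[d]}")
# dim=0: (1, 2, 3, 4)
# dim=1: (1, 1, 3, 4)
# dim=2: (1, 2, 1, 4)
# dim=3: (1, 2, 3, 1)
```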

Example 1

In the program below, we generate a 4-dimensional random tensor with randn(), pass it to argmax(), and check the results along different axes with keepdim left at its default value, False. Because randn() draws new random values on every run, your numbers will differ from the output shown.

Python3




# import necessary libraries
import torch

# define a random 4D tensor of shape (1, 2, 3, 4)
A = torch.randn(1, 2, 3, 4)
print("Tensor-A:", A)
print(A.shape)

# use argmax on the 4D tensor along dim 0 (keepdim=False is the default)
print('---Output tensor along axis-0---')
out0 = torch.argmax(A, dim=0, keepdim=False)
print(out0)
print(out0.shape)

# use argmax on the 4D tensor along dim 2
print('---Output tensor along axis-2---')
out2 = torch.argmax(A, dim=2)
print(out2)
print(out2.shape)


Output

Tensor-A: tensor([[[[ 0.2672,  0.6414, -0.7371, -0.8712],
          [ 0.9414, -1.2926, -1.0787,  1.7124],
          [-1.1063, -1.7132,  1.5767, -1.7195]],

         [[-0.7871, -1.3260,  0.1592, -0.0543],
          [ 1.8193, -1.8586, -0.6683,  0.3800],
          [ 1.8769, -0.9481, -0.4193,  0.4439]]]])
torch.Size([1, 2, 3, 4])
---Output tensor along axis-0---
tensor([[[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],

        [[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]]])
torch.Size([2, 3, 4])
---Output tensor along axis-2---
tensor([[[1, 0, 2, 1],
         [2, 2, 0, 2]]])
torch.Size([1, 2, 4])
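Because randn() produces different values each time, the printed output above cannot be reproduced directly. This sketch hard-codes the Tensor-A values from that printout (rounded to four decimals, which does not change any argmax result here). Along axis 0 every index is 0 because that axis has size 1. Along axis 2, each column of three values is reduced to the row index of its maximum. For example, the first column of the first channel is [0.2672, 0.9414, -1.1063], so its index is 1.

```python
import torch

# Values copied from the Tensor-A printout of Example 1
A = torch.tensor([[[[ 0.2672,  0.6414, -0.7371, -0.8712],
                    [ 0.9414, -1.2926, -1.0787,  1.7124],
                    [-1.1063, -1.7132,  1.5767, -1.7195]],
                   [[-0.7871, -1.3260,  0.1592, -0.0543],
                    [ 1.8193, -1.8586, -0.6683,  0.3800],
                    [ 1.8769, -0.9481, -0.4193,  0.4439]]]])

# Axis 0 has size 1, so the only possible index is 0
out0 = torch.argmax(A, dim=0)
print(bool((out0 == 0).all()), tuple(out0.shape))  # True (2, 3, 4)

# Axis 2: row index of the maximum in each column of each channel
out2 = torch.argmax(A, dim=2)
print(out2.tolist())  # [[[1, 0, 2, 1], [2, 2, 0, 2]]]
```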

Example 2

In this program, we again generate a 4-dimensional random tensor with randn(), pass it to argmax(), and check the results along different axes, this time with keepdim set to True.

Python3




# import necessary libraries
import torch

# define a random 4D tensor of shape (1, 2, 3, 4)
A = torch.randn(1, 2, 3, 4)
print("Tensor-A:", A)
print(A.shape)

# use argmax on the 4D tensor along dim 2, keeping the reduced dim (size 1)
print('---Output tensor along axis-2---')
out2 = torch.argmax(A, dim=2, keepdim=True)
print(out2)
print(out2.shape)

# use argmax on the 4D tensor along dim 3, keeping the reduced dim (size 1)
print('---Output tensor along axis-3---')
out3 = torch.argmax(A, dim=3, keepdim=True)
print(out3)
print(out3.shape)


Output

Tensor-A: tensor([[[[ 0.8328, -0.6209,  0.0998,  0.4570],
          [ 0.1988, -0.2921,  1.7013, -0.8665],
          [ 0.6360,  0.0828,  0.3932,  0.2918]],

         [[ 0.0380, -0.0488,  1.0596,  0.8984],
          [-1.5110, -0.1987,  1.0706,  1.5212],
          [-0.0235,  0.3309,  0.8487, -1.9038]]]])
torch.Size([1, 2, 3, 4])
---Output tensor along axis-2---
tensor([[[[0, 2, 1, 0]],

         [[0, 2, 1, 1]]]])
torch.Size([1, 2, 1, 4])
---Output tensor along axis-3---
tensor([[[[0],
          [2],
          [0]],

         [[2],
          [3],
          [2]]]])
torch.Size([1, 2, 3, 1])
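As with Example 1, the printed Tensor-A values can be hard-coded to make the keepdim=True results reproducible. One practical consequence of keeping the size-1 axis is that the index tensor has the same number of dimensions as the input. It can therefore be passed to torch.gather() to pull out the maximum values themselves. This is a minimal sketch of that idea; it uses torch.amax() only as a cross-check.

```python
import torch

# Values copied from the Tensor-A printout of Example 2
A = torch.tensor([[[[ 0.8328, -0.6209,  0.0998,  0.4570],
                    [ 0.1988, -0.2921,  1.7013, -0.8665],
                    [ 0.6360,  0.0828,  0.3932,  0.2918]],
                   [[ 0.0380, -0.0488,  1.0596,  0.8984],
                    [-1.5110, -0.1987,  1.0706,  1.5212],
                    [-0.0235,  0.3309,  0.8487, -1.9038]]]])

out2 = torch.argmax(A, dim=2, keepdim=True)  # shape (1, 2, 1, 4)
out3 = torch.argmax(A, dim=3, keepdim=True)  # shape (1, 2, 3, 1)
print(out2.tolist())  # [[[[0, 2, 1, 0]], [[0, 2, 1, 1]]]]
print(out3.tolist())  # [[[[0], [2], [0]], [[2], [3], [2]]]]

# The kept axis lets the indices line up with A for gather along dim 3
row_max = torch.gather(A, 3, out3)
print(torch.equal(row_max, A.amax(dim=3, keepdim=True)))  # True
```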



Last Updated : 03 Jun, 2022