【Pytorch】torch.max() 函数详解

导读:本篇文章讲解 【Pytorch】torch.max() 函数详解,希望对大家有帮助,欢迎收藏,转发!站点地址:www.bmabk.com


一、一个参数时的 torch.max()

1. 函数介绍

torch.max(input)

参数:

  • input (Tensor) – 输入张量
  • 返回输入张量所有元素中的最大值。

2. 实例

import torch

# 返回张量中的最大值
a = torch.randn(2, 3)
print(a)
print(torch.max(a))

输出结果:

tensor([[ 0.0031, -0.5391, -0.9214],
        [-0.4647, -1.9750,  0.6924]])
tensor(0.6924)

二、增加指定维度时的 torch.max()

1. 函数介绍

torch.max(input, dim, max=None, max_indices=None) -> (Tensor, LongTensor)

返回张量 input 在压缩指定维度 dim 时的最大值及其下标。

2. 实例

import torch

# 返回张量在压缩指定维度时的最大值及其下标
b = torch.randn(4, 4)
print(b)
print(torch.max(b, 0))  # 指定0维,压缩0维,0维消失,也就是行消失,返回列最大值及其下标
print(torch.max(b, 1))  # 指定1维,压缩1维,1维消失,也就是列消失,返回行最大值及其下标

输出结果:

tensor([[-0.8862,  0.3502,  0.0223,  0.6035],
        [-2.0135, -0.1346,  2.0575,  1.4203],
        [ 1.0107,  0.9302, -0.1321,  0.0704],
        [-1.4540, -0.4780,  0.7016,  0.3029]])
torch.return_types.max(
values=tensor([1.0107, 0.9302, 2.0575, 1.4203]),
indices=tensor([2, 2, 1, 1]))
torch.return_types.max(
values=tensor([0.6035, 2.0575, 1.0107, 0.7016]),
indices=tensor([3, 2, 0, 2]))

三、两个输入张量时的 torch.max()

1. 函数介绍

torch.max(input, other_input, out=None) → Tensor

返回两张量 input 和 other_input 在对应位置上的最大值形成的新张量。

2. 实例

import torch

# 返回两张量对应位置上的最大值
c = torch.randn(4,2)
d = torch.randn(4,2)
print(c)
print(d)
print(torch.max(c, d))

输出结果:

tensor([[ 0.6778,  1.2714],
        [-0.9020, -1.3789],
        [ 0.8541,  1.2193],
        [-0.8481, -0.8211]])
tensor([[ 2.4616, -1.2502],
        [ 0.0173, -0.5501],
        [ 1.0224, -1.5892],
        [ 1.3325,  0.2587]])
tensor([[ 2.4616,  1.2714],
        [ 0.0173, -0.5501],
        [ 1.0224,  1.2193],
        [ 1.3325,  0.2587]])

参考链接

  1. 详解 torch.max 函数

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

文章由极客之音整理,本文链接:https://www.bmabk.com/index.php/post/118860.html

(0)
seven_的头像seven_bm

相关推荐

发表回复

登录后才能评论
极客之音——专业性很强的中文编程技术网站,欢迎收藏到浏览器,订阅我们!