一、torch.max(input, dim) 函数
output = torch.max(input, dim)
输入:
input 是一个tensor
dim 是 max 函数索引的维度,dim 为 0 时返回每列最大值,dim 为 1 时返回每行最大值
输出:
函数会返回两个tensor,第一个 tensor 是某维度(dim)上的最大值;第二个 tensor 是最大值的索引(位置)。
二、实例
import torch
a = torch.tensor([[1,5,62,54], [2,6,2,6], [2,65,2,6]])
print(a)
# dim 为 1 时返回每行最大值
print(torch.max(a, 1))
print(torch.max(a, 1)[1].numpy())
输出结果:
tensor([[ 1, 5, 62, 54],
[ 2, 6, 2, 6],
[ 2, 65, 2, 6]])
torch.return_types.max(
values=tensor([62, 6, 65]),
indices=tensor([2, 1, 1]))
[2 1 1]
参考链接
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
文章由极客之音整理,本文链接:https://www.bmabk.com/index.php/post/118937.html