a=torch.Tensor([
[[1,1],[2,2],[3,3]],
[[4,4],[5,5],[6,6]],
])
b=torch.Tensor([
[1.5,2.5],
[3.5,4.5],
])
print(a.size()) # torch.Size([2, 3, 2])
print(b.size()) # torch.Size([2, 2])
c=a-b
有没有不需要 unsqueeze 和 repeat ,直接相减的方法,听说可以节省内存
就是想让[1,1],[2,2],[3,3]都减去[1.5,2.5],[4,4],[5,5],[6,6]都减去[3.5,4.5]
1
bravecarrot 2021-11-13 18:59:38 +08:00
broadcasting
|
2
ekidona 2021-11-14 13:35:11 +08:00 via iPhone
a.permute(1,0,2)-b
|