Skip to content

Commit b048c4d

Browse files
author
Dolly Ye
authored
torch.dot does not broadcast
torch.dot() can only work for 1 dimension tensor. http://pytorch.org/docs/master/torch.html pytorch/pytorch#2313
1 parent 0b09c05 commit b048c4d

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tutorial-contents-notebooks/201_torch_numpy.ipynb

+4-2
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,11 @@
279279
"source": [
280280
"# incorrect method\n",
281281
"data = np.array(data)\n",
282+
"tensor = torch.Tensor([1,2,3,4]\n",
282283
"print(\n",
283284
" '\\nmatrix multiplication (dot)',\n",
284285
" '\\nnumpy: ', data.dot(data), # [[7, 10], [15, 22]]\n",
285-
" '\\ntorch: ', tensor.dot(tensor) # this will convert tensor to [1,2,3,4], you'll get 30.0\n",
286+
" '\\ntorch: ', torch.dot(tensor.dot(tensor) # 30.0. Beware that torch.dot does not broadcast, only works for 1-dimensional tensor\n",
286287
")"
287288
]
288289
},
@@ -360,7 +361,8 @@
360361
}
361362
],
362363
"source": [
363-
"tensor.dot(tensor)"
364+
"torch.dot(torch.Tensor([2, 3]), torch.Tensor([2, 1]))
365+
7.0"
364366
]
365367
},
366368
{

0 commit comments

Comments
 (0)