Skip to content

Commit 6f8b7b8

Browse files
committed
More DLDB Comprehensions
1 parent 4a92f75 commit 6f8b7b8

File tree

1 file changed

+68
-0
lines changed

1 file changed

+68
-0
lines changed

pytorch/Tensor Comprehensions - Getting Started.ipynb

+68
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,74 @@
844844
"\"\"\""
845845
]
846846
},
847+
{
848+
"cell_type": "markdown",
849+
"metadata": {},
850+
"source": [
851+
"#### Strided Conv Gradient"
852+
]
853+
},
854+
{
855+
"cell_type": "code",
856+
"execution_count": 23,
857+
"metadata": {},
858+
"outputs": [],
859+
"source": [
860+
"lang = \"\"\"\n",
861+
"def convolution_grad(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(N, M, H, W) O_grad) -> (I_grad, W1_grad) {{\n",
862+
" I_grad(n, c, h, w) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * W1(m, c, kh, kw)\n",
863+
" W1_grad(m, c, kh, kw) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * I(n, c, h, w)\n",
864+
"}}\n",
865+
"\"\"\""
866+
]
867+
},
868+
{
869+
"cell_type": "markdown",
870+
"metadata": {},
871+
"source": [
872+
"#### Simple Group Convolution"
873+
]
874+
},
875+
{
876+
"cell_type": "code",
877+
"execution_count": 25,
878+
"metadata": {
879+
"collapsed": true
880+
},
881+
"outputs": [],
882+
"source": [
883+
"lang = \"\"\"\n",
884+
"\n",
885+
"def group_convolution(float(N, G, C, H, W) I, float(G, F, C, KH, KW) W1, float(G, F) B) -> (O) {\n",
886+
" O(n, g, f, h, w) +=! I(n, g, c, h + kh, w + kw) * W1(g, f, c, kh, kw)\n",
887+
" O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f)\n",
888+
"}\n",
889+
"\"\"\""
890+
]
891+
},
892+
{
893+
"cell_type": "markdown",
894+
"metadata": {},
895+
"source": [
896+
"#### Group Conv. Strided"
897+
]
898+
},
899+
{
900+
"cell_type": "code",
901+
"execution_count": 26,
902+
"metadata": {
903+
"collapsed": true
904+
},
905+
"outputs": [],
906+
"source": [
907+
"lang = \"\"\"\n",
908+
"def group_convolution_strided(float(N, G, C, H, W) I, float(G, F, C, KH, KW) W1, float(G, F) B) -> (O) {{\n",
909+
" O(n, g, f, h, w) +=! I(n, g, c, {sh} * h + kh, {sw} * w + kw) * W1(g, f, c, kh, kw)\n",
910+
" O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f)\n",
911+
"}}\n",
912+
"\"\"\""
913+
]
914+
},
847915
{
848916
"cell_type": "code",
849917
"execution_count": null,

0 commit comments

Comments
 (0)