|
844 | 844 | "\"\"\""
|
845 | 845 | ]
|
846 | 846 | },
|
| 847 | + { |
| 848 | + "cell_type": "markdown", |
| 849 | + "metadata": {}, |
| 850 | + "source": [ |
| 851 | + "#### Strided Convolution Gradient" |
| 852 | + ] |
| 853 | + }, |
| 854 | + { |
| 855 | + "cell_type": "code", |
| 856 | + "execution_count": 23, |
| 857 | + "metadata": {}, |
| 858 | + "outputs": [], |
| 859 | + "source": [ |
| 860 | + "lang = \"\"\"\n", |
| 861 | + "def convolution_grad(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(N, M, H, W) O_grad) -> (I_grad, W1_grad) {{\n", |
| 862 | + " I_grad(n, c, h, w) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * W1(m, c, kh, kw)\n", |
| 863 | + " W1_grad(m, c, kh, kw) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * I(n, c, h, w)\n", |
| 864 | + "}}\n", |
| 865 | + "\"\"\"" |
| 866 | + ] |
| 867 | + }, |
| 868 | + { |
| 869 | + "cell_type": "markdown", |
| 870 | + "metadata": {}, |
| 871 | + "source": [ |
| 872 | + "#### Simple Group Convolution" |
| 873 | + ] |
| 874 | + }, |
| 875 | + { |
| 876 | + "cell_type": "code", |
| 877 | + "execution_count": 25, |
| 878 | + "metadata": { |
| 879 | + "collapsed": true |
| 880 | + }, |
| 881 | + "outputs": [], |
| 882 | + "source": [ |
| 883 | + "lang = \"\"\"\n", |
| 884 | + "\n", |
| 885 | + "def group_convolution(float(N, G, C, H, W) I, float(G, F, C, KH, KW) W1, float(G, F) B) -> (O) {\n", |
| 886 | + " O(n, g, f, h, w) +=! I(n, g, c, h + kh, w + kw) * W1(g, f, c, kh, kw)\n", |
| 887 | + " O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f)\n", |
| 888 | + "}\n", |
| 889 | + "\"\"\"" |
| 890 | + ] |
| 891 | + }, |
| 892 | + { |
| 893 | + "cell_type": "markdown", |
| 894 | + "metadata": {}, |
| 895 | + "source": [ |
| 896 | + "#### Strided Group Convolution" |
| 897 | + ] |
| 898 | + }, |
| 899 | + { |
| 900 | + "cell_type": "code", |
| 901 | + "execution_count": 26, |
| 902 | + "metadata": { |
| 903 | + "collapsed": true |
| 904 | + }, |
| 905 | + "outputs": [], |
| 906 | + "source": [ |
| 907 | + "lang = \"\"\"\n", |
| 908 | + "def group_convolution_strided(float(N, G, C, H, W) I, float(G, F, C, KH, KW) W1, float(G, F) B) -> (O) {{\n", |
| 909 | + " O(n, g, f, h, w) +=! I(n, g, c, {sh} * h + kh, {sw} * w + kw) * W1(g, f, c, kh, kw)\n", |
| 910 | + " O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f)\n", |
| 911 | + "}}\n", |
| 912 | + "\"\"\"" |
| 913 | + ] |
| 914 | + }, |
847 | 915 | {
|
848 | 916 | "cell_type": "code",
|
849 | 917 | "execution_count": null,
|
|
0 commit comments