|
54 | 54 | "print(f'k2i={k2i}')"
|
55 | 55 | ]
|
56 | 56 | },
|
| 57 | + { |
| 58 | + "cell_type": "code", |
| 59 | + "execution_count": 73, |
| 60 | + "metadata": {}, |
| 61 | + "outputs": [ |
| 62 | + { |
| 63 | + "data": { |
| 64 | + "text/plain": [ |
| 65 | + "0 18\n", |
| 66 | + "1 16\n", |
| 67 | + "2 14\n", |
| 68 | + "3 13\n", |
| 69 | + "Name: y, dtype: int64" |
| 70 | + ] |
| 71 | + }, |
| 72 | + "execution_count": 73, |
| 73 | + "metadata": {}, |
| 74 | + "output_type": "execute_result" |
| 75 | + } |
| 76 | + ], |
| 77 | + "source": [ |
| 78 | + "df.y.value_counts()" |
| 79 | + ] |
| 80 | + }, |
57 | 81 | {
|
58 | 82 | "cell_type": "code",
|
59 | 83 | "execution_count": 50,
|
|
771 | 795 | "execution_count": 51,
|
772 | 796 | "metadata": {},
|
773 | 797 | "outputs": [],
|
774 |
| - "source": [ |
775 |
| - "from sklearn.linear_model import LogisticRegression\n", |
776 |
| - "\n", |
777 |
| - "X = df[[c for c in df.columns if c != 'y']]\n", |
778 |
| - "y = df.y\n", |
779 |
| - "\n", |
780 |
| - "clf = LogisticRegression(random_state=37, max_iter=1000).fit(X, y)" |
781 |
| - ] |
782 |
| - }, |
783 |
| - { |
784 |
| - "cell_type": "code", |
785 |
| - "execution_count": 52, |
786 |
| - "metadata": {}, |
787 |
| - "outputs": [ |
788 |
| - { |
789 |
| - "data": { |
790 |
| - "text/plain": [ |
791 |
| - "array([[9.99857751e-01, 7.14828650e-05, 1.78647482e-07, 7.05878308e-05],\n", |
792 |
| - " [9.99999873e-01, 1.10509918e-09, 1.74693620e-16, 1.25734063e-07],\n", |
793 |
| - " [9.99996341e-01, 1.32614673e-09, 7.42270612e-12, 3.65778682e-06],\n", |
794 |
| - " [9.99990065e-01, 4.32103437e-08, 6.42518012e-10, 9.89129664e-06],\n", |
795 |
| - " [9.99969350e-01, 2.54127838e-07, 2.94577247e-10, 3.03959933e-05],\n", |
796 |
| - " [9.99960952e-01, 4.76835060e-07, 1.90433708e-06, 3.66669567e-05],\n", |
797 |
| - " [9.99971594e-01, 2.12520998e-05, 6.02742194e-09, 7.14765181e-06],\n", |
798 |
| - " [9.99749888e-01, 2.49903134e-04, 1.91304507e-09, 2.06685335e-07],\n", |
799 |
| - " [9.99953476e-01, 4.36434926e-05, 3.99399209e-09, 2.87674889e-06],\n", |
800 |
| - " [9.99974708e-01, 2.63393212e-08, 3.91232222e-10, 2.52653882e-05],\n", |
801 |
| - " [9.99956720e-01, 4.44109840e-09, 3.37111338e-08, 4.32422851e-05],\n", |
802 |
| - " [9.99988844e-01, 1.10322220e-05, 3.68991538e-10, 1.23561436e-07],\n", |
803 |
| - " [9.99992981e-01, 2.67520562e-09, 1.49354589e-08, 7.00089754e-06],\n", |
804 |
| - " [9.99999427e-01, 4.40712129e-09, 2.06789051e-07, 3.62172519e-07],\n", |
805 |
| - " [9.99999940e-01, 1.34161358e-11, 3.95076232e-11, 6.04124503e-08],\n", |
806 |
| - " [9.99960640e-01, 1.39934737e-05, 3.09996241e-06, 2.22668243e-05],\n", |
807 |
| - " [9.99799809e-01, 1.64692960e-06, 2.85038654e-07, 1.98259008e-04],\n", |
808 |
| - " [9.99603280e-01, 1.33109265e-05, 1.40506103e-06, 3.82003515e-04],\n", |
809 |
| - " [2.09229849e-12, 9.99999977e-01, 1.05025421e-16, 2.32760723e-08],\n", |
810 |
| - " [1.71961605e-12, 9.99998688e-01, 3.56516527e-17, 1.31154561e-06],\n", |
811 |
| - " [1.11761695e-12, 9.99999977e-01, 7.12318625e-17, 2.30887938e-08],\n", |
812 |
| - " [1.57326986e-12, 9.99998921e-01, 3.44324441e-17, 1.07925189e-06],\n", |
813 |
| - " [1.97076943e-10, 9.99948654e-01, 7.90461918e-13, 5.13462660e-05],\n", |
814 |
| - " [2.43709302e-05, 9.99331510e-01, 4.31265625e-04, 2.12853863e-04],\n", |
815 |
| - " [1.00831498e-11, 9.99997244e-01, 1.26767860e-16, 2.75565003e-06],\n", |
816 |
| - " [5.96857500e-05, 9.98668899e-01, 5.10298703e-12, 1.27141501e-03],\n", |
817 |
| - " [5.08951663e-10, 9.99998919e-01, 3.86519351e-14, 1.08070169e-06],\n", |
818 |
| - " [2.04403247e-09, 9.99995505e-01, 7.40928347e-13, 4.49265969e-06],\n", |
819 |
| - " [2.62203638e-06, 9.98512730e-01, 5.44955709e-07, 1.48410273e-03],\n", |
820 |
| - " [1.50683309e-10, 9.99985434e-01, 1.93992680e-13, 1.45662527e-05],\n", |
821 |
| - " [9.03486290e-18, 9.99999337e-01, 9.82848687e-28, 6.63374159e-07],\n", |
822 |
| - " [9.33108527e-08, 9.99996483e-01, 9.94055120e-13, 3.42327558e-06],\n", |
823 |
| - " [5.94780494e-06, 9.99925512e-01, 1.08401219e-11, 6.85402587e-05],\n", |
824 |
| - " [2.15986496e-05, 9.99952242e-01, 3.02417560e-09, 2.61558715e-05],\n", |
825 |
| - " [8.21561097e-12, 8.32749629e-16, 9.99998609e-01, 1.39052160e-06],\n", |
826 |
| - " [1.35601995e-12, 2.28917607e-15, 9.99999239e-01, 7.60802957e-07],\n", |
827 |
| - " [2.33754103e-04, 6.34382468e-05, 9.99557187e-01, 1.45620457e-04],\n", |
828 |
| - " [4.83552266e-11, 1.70808956e-14, 9.99994960e-01, 5.03955585e-06],\n", |
829 |
| - " [6.80387431e-08, 5.55408412e-11, 9.99880669e-01, 1.19262462e-04],\n", |
830 |
| - " [9.61223274e-06, 1.41735509e-07, 9.99521777e-01, 4.68468995e-04],\n", |
831 |
| - " [5.27033040e-13, 2.11060231e-16, 9.99938114e-01, 6.18855911e-05],\n", |
832 |
| - " [7.19939600e-19, 4.58281087e-24, 9.99999853e-01, 1.47432720e-07],\n", |
833 |
| - " [6.91717948e-10, 2.33101504e-12, 9.99863531e-01, 1.36468264e-04],\n", |
834 |
| - " [2.02543459e-15, 1.99977575e-18, 9.99988672e-01, 1.13277278e-05],\n", |
835 |
| - " [5.18429014e-09, 6.87558131e-10, 9.99933920e-01, 6.60739062e-05],\n", |
836 |
| - " [5.37818923e-09, 3.33002212e-07, 9.99803469e-01, 1.96192448e-04],\n", |
837 |
| - " [1.03173145e-09, 6.08761519e-11, 9.98799245e-01, 1.20075404e-03],\n", |
838 |
| - " [7.17781813e-08, 1.07326334e-10, 9.99370198e-01, 6.29730479e-04],\n", |
839 |
| - " [1.17034725e-07, 6.11213703e-07, 1.03821363e-05, 9.99988890e-01],\n", |
840 |
| - " [5.60834135e-12, 6.34017167e-11, 1.59124098e-07, 9.99999841e-01],\n", |
841 |
| - " [3.27531509e-07, 1.30276882e-06, 2.81841700e-04, 9.99716528e-01],\n", |
842 |
| - " [2.85187496e-06, 4.27417539e-06, 2.25183478e-06, 9.99990622e-01],\n", |
843 |
| - " [2.26845682e-04, 2.14858622e-04, 1.30080904e-05, 9.99545288e-01],\n", |
844 |
| - " [1.91071927e-07, 2.18341367e-06, 6.99319392e-08, 9.99997556e-01],\n", |
845 |
| - " [1.52912405e-16, 7.51115551e-16, 5.97233580e-10, 9.99999999e-01],\n", |
846 |
| - " [4.28842325e-07, 3.58739061e-06, 1.15789937e-04, 9.99880194e-01],\n", |
847 |
| - " [1.24358496e-12, 1.61846276e-14, 4.09872243e-06, 9.99995901e-01],\n", |
848 |
| - " [1.79063038e-05, 2.35134166e-03, 1.45770845e-04, 9.97484981e-01],\n", |
849 |
| - " [4.94813568e-06, 1.94485993e-05, 5.28329480e-04, 9.99447274e-01],\n", |
850 |
| - " [4.32998150e-06, 2.40290270e-07, 9.96027246e-04, 9.98999402e-01],\n", |
851 |
| - " [6.32718815e-04, 5.04958392e-04, 6.85676138e-04, 9.98176647e-01]])" |
852 |
| - ] |
853 |
| - }, |
854 |
| - "execution_count": 52, |
855 |
| - "metadata": {}, |
856 |
| - "output_type": "execute_result" |
857 |
| - } |
858 |
| - ], |
859 |
| - "source": [ |
860 |
| - "clf.predict_proba(X)" |
861 |
| - ] |
| 798 | + "source": [] |
862 | 799 | },
|
863 | 800 | {
|
864 | 801 | "cell_type": "code",
|
|
882 | 819 | },
|
883 | 820 | {
|
884 | 821 | "cell_type": "code",
|
885 |
| - "execution_count": 54, |
| 822 | + "execution_count": 82, |
886 | 823 | "metadata": {},
|
887 | 824 | "outputs": [
|
888 | 825 | {
|
|
974 | 911 | " -1.38081190e-02, -9.82462294e-03, 1.39165390e-06]])"
|
975 | 912 | ]
|
976 | 913 | },
|
977 |
| - "execution_count": 54, |
| 914 | + "execution_count": 82, |
978 | 915 | "metadata": {},
|
979 | 916 | "output_type": "execute_result"
|
980 | 917 | }
|
|
983 | 920 | "clf.coef_"
|
984 | 921 | ]
|
985 | 922 | },
|
| 923 | + { |
| 924 | + "cell_type": "markdown", |
| 925 | + "metadata": {}, |
| 926 | + "source": [ |
| 927 | + "# K-fold cross-validation" |
| 928 | + ] |
| 929 | + }, |
| 930 | + { |
| 931 | + "cell_type": "code", |
| 932 | + "execution_count": 76, |
| 933 | + "metadata": {}, |
| 934 | + "outputs": [ |
| 935 | + { |
| 936 | + "data": { |
| 937 | + "text/html": [ |
| 938 | + "<div>\n", |
| 939 | + "<style scoped>\n", |
| 940 | + " .dataframe tbody tr th:only-of-type {\n", |
| 941 | + " vertical-align: middle;\n", |
| 942 | + " }\n", |
| 943 | + "\n", |
| 944 | + " .dataframe tbody tr th {\n", |
| 945 | + " vertical-align: top;\n", |
| 946 | + " }\n", |
| 947 | + "\n", |
| 948 | + " .dataframe thead th {\n", |
| 949 | + " text-align: right;\n", |
| 950 | + " }\n", |
| 951 | + "</style>\n", |
| 952 | + "<table border=\"1\" class=\"dataframe\">\n", |
| 953 | + " <thead>\n", |
| 954 | + " <tr style=\"text-align: right;\">\n", |
| 955 | + " <th></th>\n", |
| 956 | + " <th>macro</th>\n", |
| 957 | + " <th>micro</th>\n", |
| 958 | + " <th>weighted</th>\n", |
| 959 | + " </tr>\n", |
| 960 | + " </thead>\n", |
| 961 | + " <tbody>\n", |
| 962 | + " <tr>\n", |
| 963 | + " <th>0</th>\n", |
| 964 | + " <td>1.00000</td>\n", |
| 965 | + " <td>1.000000</td>\n", |
| 966 | + " <td>1.000000</td>\n", |
| 967 | + " </tr>\n", |
| 968 | + " <tr>\n", |
| 969 | + " <th>1</th>\n", |
| 970 | + " <td>1.00000</td>\n", |
| 971 | + " <td>1.000000</td>\n", |
| 972 | + " <td>1.000000</td>\n", |
| 973 | + " </tr>\n", |
| 974 | + " <tr>\n", |
| 975 | + " <th>2</th>\n", |
| 976 | + " <td>0.82500</td>\n", |
| 977 | + " <td>0.805556</td>\n", |
| 978 | + " <td>0.800000</td>\n", |
| 979 | + " </tr>\n", |
| 980 | + " <tr>\n", |
| 981 | + " <th>3</th>\n", |
| 982 | + " <td>0.85625</td>\n", |
| 983 | + " <td>0.833333</td>\n", |
| 984 | + " <td>0.841667</td>\n", |
| 985 | + " </tr>\n", |
| 986 | + " <tr>\n", |
| 987 | + " <th>4</th>\n", |
| 988 | + " <td>1.00000</td>\n", |
| 989 | + " <td>1.000000</td>\n", |
| 990 | + " <td>1.000000</td>\n", |
| 991 | + " </tr>\n", |
| 992 | + " <tr>\n", |
| 993 | + " <th>5</th>\n", |
| 994 | + " <td>1.00000</td>\n", |
| 995 | + " <td>1.000000</td>\n", |
| 996 | + " <td>1.000000</td>\n", |
| 997 | + " </tr>\n", |
| 998 | + " <tr>\n", |
| 999 | + " <th>6</th>\n", |
| 1000 | + " <td>1.00000</td>\n", |
| 1001 | + " <td>0.990741</td>\n", |
| 1002 | + " <td>1.000000</td>\n", |
| 1003 | + " </tr>\n", |
| 1004 | + " <tr>\n", |
| 1005 | + " <th>7</th>\n", |
| 1006 | + " <td>1.00000</td>\n", |
| 1007 | + " <td>1.000000</td>\n", |
| 1008 | + " <td>1.000000</td>\n", |
| 1009 | + " </tr>\n", |
| 1010 | + " <tr>\n", |
| 1011 | + " <th>8</th>\n", |
| 1012 | + " <td>0.96875</td>\n", |
| 1013 | + " <td>0.962963</td>\n", |
| 1014 | + " <td>0.958333</td>\n", |
| 1015 | + " </tr>\n", |
| 1016 | + " <tr>\n", |
| 1017 | + " <th>9</th>\n", |
| 1018 | + " <td>1.00000</td>\n", |
| 1019 | + " <td>1.000000</td>\n", |
| 1020 | + " <td>1.000000</td>\n", |
| 1021 | + " </tr>\n", |
| 1022 | + " </tbody>\n", |
| 1023 | + "</table>\n", |
| 1024 | + "</div>" |
| 1025 | + ], |
| 1026 | + "text/plain": [ |
| 1027 | + " macro micro weighted\n", |
| 1028 | + "0 1.00000 1.000000 1.000000\n", |
| 1029 | + "1 1.00000 1.000000 1.000000\n", |
| 1030 | + "2 0.82500 0.805556 0.800000\n", |
| 1031 | + "3 0.85625 0.833333 0.841667\n", |
| 1032 | + "4 1.00000 1.000000 1.000000\n", |
| 1033 | + "5 1.00000 1.000000 1.000000\n", |
| 1034 | + "6 1.00000 0.990741 1.000000\n", |
| 1035 | + "7 1.00000 1.000000 1.000000\n", |
| 1036 | + "8 0.96875 0.962963 0.958333\n", |
| 1037 | + "9 1.00000 1.000000 1.000000" |
| 1038 | + ] |
| 1039 | + }, |
| 1040 | + "execution_count": 76, |
| 1041 | + "metadata": {}, |
| 1042 | + "output_type": "execute_result" |
| 1043 | + } |
| 1044 | + ], |
| 1045 | + "source": [ |
| 1046 | + "from sklearn.model_selection import StratifiedKFold\n", |
| 1047 | + "from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score\n", |
| 1048 | + "from sklearn.preprocessing import label_binarize\n", |
| 1049 | + "\n", |
| 1050 | + "X = df[[c for c in df.columns if c != 'y']]\n", |
| 1051 | + "y = df.y\n", |
| 1052 | + "\n", |
| 1053 | + "skf = StratifiedKFold(n_splits=10)\n", |
| 1054 | + "\n", |
| 1055 | + "results = []\n", |
| 1056 | + "for train_index, test_index in skf.split(X, y):\n", |
| 1057 | + " X_train, y_train = X.iloc[train_index], y[train_index]\n", |
| 1058 | + " X_test, y_test = X.iloc[test_index], y[test_index]\n", |
| 1059 | + " \n", |
| 1060 | + " y_test = label_binarize(y_test, classes=[0.0, 1.0, 2.0, 3.0])\n", |
| 1061 | + " \n", |
| 1062 | + " clf = LogisticRegression(random_state=37, max_iter=1000).fit(X_train, y_train)\n", |
| 1063 | + " y_prob = clf.predict_proba(X_test)\n", |
| 1064 | + " \n", |
| 1065 | + " roc_macro = roc_auc_score(y_test, y_prob, average='macro')\n", |
| 1066 | + " roc_micro = roc_auc_score(y_test, y_prob, average='micro')\n", |
| 1067 | + " roc_weighted = roc_auc_score(y_test, y_prob, average='weighted')\n", |
| 1068 | + " results.append({\n", |
| 1069 | + " 'macro': roc_macro,\n", |
| 1070 | + " 'micro': roc_micro,\n", |
| 1071 | + " 'weighted': roc_weighted\n", |
| 1072 | + " })\n", |
| 1073 | + " \n", |
| 1074 | + "results = pd.DataFrame(results)\n", |
| 1075 | + "results" |
| 1076 | + ] |
| 1077 | + }, |
| 1078 | + { |
| 1079 | + "cell_type": "code", |
| 1080 | + "execution_count": 77, |
| 1081 | + "metadata": {}, |
| 1082 | + "outputs": [ |
| 1083 | + { |
| 1084 | + "data": { |
| 1085 | + "text/plain": [ |
| 1086 | + "macro 0.965000\n", |
| 1087 | + "micro 0.959259\n", |
| 1088 | + "weighted 0.960000\n", |
| 1089 | + "dtype: float64" |
| 1090 | + ] |
| 1091 | + }, |
| 1092 | + "execution_count": 77, |
| 1093 | + "metadata": {}, |
| 1094 | + "output_type": "execute_result" |
| 1095 | + } |
| 1096 | + ], |
| 1097 | + "source": [ |
| 1098 | + "results.mean()" |
| 1099 | + ] |
| 1100 | + }, |
| 1101 | + { |
| 1102 | + "cell_type": "markdown", |
| 1103 | + "metadata": {}, |
| 1104 | + "source": [ |
| 1105 | + "# Learn model on full data" |
| 1106 | + ] |
| 1107 | + }, |
| 1108 | + { |
| 1109 | + "cell_type": "code", |
| 1110 | + "execution_count": 78, |
| 1111 | + "metadata": {}, |
| 1112 | + "outputs": [], |
| 1113 | + "source": [ |
| 1114 | + "from sklearn.linear_model import LogisticRegression\n", |
| 1115 | + "\n", |
| 1116 | + "X = df[[c for c in df.columns if c != 'y']]\n", |
| 1117 | + "y = df.y\n", |
| 1118 | + "\n", |
| 1119 | + "clf = LogisticRegression(random_state=37, max_iter=1000).fit(X, y)" |
| 1120 | + ] |
| 1121 | + }, |
986 | 1122 | {
|
987 | 1123 | "cell_type": "code",
|
988 | 1124 | "execution_count": null,
|
|
0 commit comments