Skip to content

Commit 6157f96

Browse files
committed
update
1 parent aa3b868 commit 6157f96

File tree

1 file changed

+226
-90
lines changed

1 file changed

+226
-90
lines changed

html/rps-tf/data-capture/python/model.ipynb

Lines changed: 226 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,30 @@
5454
"print(f'k2i={k2i}')"
5555
]
5656
},
57+
{
58+
"cell_type": "code",
59+
"execution_count": 73,
60+
"metadata": {},
61+
"outputs": [
62+
{
63+
"data": {
64+
"text/plain": [
65+
"0 18\n",
66+
"1 16\n",
67+
"2 14\n",
68+
"3 13\n",
69+
"Name: y, dtype: int64"
70+
]
71+
},
72+
"execution_count": 73,
73+
"metadata": {},
74+
"output_type": "execute_result"
75+
}
76+
],
77+
"source": [
78+
"df.y.value_counts()"
79+
]
80+
},
5781
{
5882
"cell_type": "code",
5983
"execution_count": 50,
@@ -771,94 +795,7 @@
771795
"execution_count": 51,
772796
"metadata": {},
773797
"outputs": [],
774-
"source": [
775-
"from sklearn.linear_model import LogisticRegression\n",
776-
"\n",
777-
"X = df[[c for c in df.columns if c != 'y']]\n",
778-
"y = df.y\n",
779-
"\n",
780-
"clf = LogisticRegression(random_state=37, max_iter=1000).fit(X, y)"
781-
]
782-
},
783-
{
784-
"cell_type": "code",
785-
"execution_count": 52,
786-
"metadata": {},
787-
"outputs": [
788-
{
789-
"data": {
790-
"text/plain": [
791-
"array([[9.99857751e-01, 7.14828650e-05, 1.78647482e-07, 7.05878308e-05],\n",
792-
" [9.99999873e-01, 1.10509918e-09, 1.74693620e-16, 1.25734063e-07],\n",
793-
" [9.99996341e-01, 1.32614673e-09, 7.42270612e-12, 3.65778682e-06],\n",
794-
" [9.99990065e-01, 4.32103437e-08, 6.42518012e-10, 9.89129664e-06],\n",
795-
" [9.99969350e-01, 2.54127838e-07, 2.94577247e-10, 3.03959933e-05],\n",
796-
" [9.99960952e-01, 4.76835060e-07, 1.90433708e-06, 3.66669567e-05],\n",
797-
" [9.99971594e-01, 2.12520998e-05, 6.02742194e-09, 7.14765181e-06],\n",
798-
" [9.99749888e-01, 2.49903134e-04, 1.91304507e-09, 2.06685335e-07],\n",
799-
" [9.99953476e-01, 4.36434926e-05, 3.99399209e-09, 2.87674889e-06],\n",
800-
" [9.99974708e-01, 2.63393212e-08, 3.91232222e-10, 2.52653882e-05],\n",
801-
" [9.99956720e-01, 4.44109840e-09, 3.37111338e-08, 4.32422851e-05],\n",
802-
" [9.99988844e-01, 1.10322220e-05, 3.68991538e-10, 1.23561436e-07],\n",
803-
" [9.99992981e-01, 2.67520562e-09, 1.49354589e-08, 7.00089754e-06],\n",
804-
" [9.99999427e-01, 4.40712129e-09, 2.06789051e-07, 3.62172519e-07],\n",
805-
" [9.99999940e-01, 1.34161358e-11, 3.95076232e-11, 6.04124503e-08],\n",
806-
" [9.99960640e-01, 1.39934737e-05, 3.09996241e-06, 2.22668243e-05],\n",
807-
" [9.99799809e-01, 1.64692960e-06, 2.85038654e-07, 1.98259008e-04],\n",
808-
" [9.99603280e-01, 1.33109265e-05, 1.40506103e-06, 3.82003515e-04],\n",
809-
" [2.09229849e-12, 9.99999977e-01, 1.05025421e-16, 2.32760723e-08],\n",
810-
" [1.71961605e-12, 9.99998688e-01, 3.56516527e-17, 1.31154561e-06],\n",
811-
" [1.11761695e-12, 9.99999977e-01, 7.12318625e-17, 2.30887938e-08],\n",
812-
" [1.57326986e-12, 9.99998921e-01, 3.44324441e-17, 1.07925189e-06],\n",
813-
" [1.97076943e-10, 9.99948654e-01, 7.90461918e-13, 5.13462660e-05],\n",
814-
" [2.43709302e-05, 9.99331510e-01, 4.31265625e-04, 2.12853863e-04],\n",
815-
" [1.00831498e-11, 9.99997244e-01, 1.26767860e-16, 2.75565003e-06],\n",
816-
" [5.96857500e-05, 9.98668899e-01, 5.10298703e-12, 1.27141501e-03],\n",
817-
" [5.08951663e-10, 9.99998919e-01, 3.86519351e-14, 1.08070169e-06],\n",
818-
" [2.04403247e-09, 9.99995505e-01, 7.40928347e-13, 4.49265969e-06],\n",
819-
" [2.62203638e-06, 9.98512730e-01, 5.44955709e-07, 1.48410273e-03],\n",
820-
" [1.50683309e-10, 9.99985434e-01, 1.93992680e-13, 1.45662527e-05],\n",
821-
" [9.03486290e-18, 9.99999337e-01, 9.82848687e-28, 6.63374159e-07],\n",
822-
" [9.33108527e-08, 9.99996483e-01, 9.94055120e-13, 3.42327558e-06],\n",
823-
" [5.94780494e-06, 9.99925512e-01, 1.08401219e-11, 6.85402587e-05],\n",
824-
" [2.15986496e-05, 9.99952242e-01, 3.02417560e-09, 2.61558715e-05],\n",
825-
" [8.21561097e-12, 8.32749629e-16, 9.99998609e-01, 1.39052160e-06],\n",
826-
" [1.35601995e-12, 2.28917607e-15, 9.99999239e-01, 7.60802957e-07],\n",
827-
" [2.33754103e-04, 6.34382468e-05, 9.99557187e-01, 1.45620457e-04],\n",
828-
" [4.83552266e-11, 1.70808956e-14, 9.99994960e-01, 5.03955585e-06],\n",
829-
" [6.80387431e-08, 5.55408412e-11, 9.99880669e-01, 1.19262462e-04],\n",
830-
" [9.61223274e-06, 1.41735509e-07, 9.99521777e-01, 4.68468995e-04],\n",
831-
" [5.27033040e-13, 2.11060231e-16, 9.99938114e-01, 6.18855911e-05],\n",
832-
" [7.19939600e-19, 4.58281087e-24, 9.99999853e-01, 1.47432720e-07],\n",
833-
" [6.91717948e-10, 2.33101504e-12, 9.99863531e-01, 1.36468264e-04],\n",
834-
" [2.02543459e-15, 1.99977575e-18, 9.99988672e-01, 1.13277278e-05],\n",
835-
" [5.18429014e-09, 6.87558131e-10, 9.99933920e-01, 6.60739062e-05],\n",
836-
" [5.37818923e-09, 3.33002212e-07, 9.99803469e-01, 1.96192448e-04],\n",
837-
" [1.03173145e-09, 6.08761519e-11, 9.98799245e-01, 1.20075404e-03],\n",
838-
" [7.17781813e-08, 1.07326334e-10, 9.99370198e-01, 6.29730479e-04],\n",
839-
" [1.17034725e-07, 6.11213703e-07, 1.03821363e-05, 9.99988890e-01],\n",
840-
" [5.60834135e-12, 6.34017167e-11, 1.59124098e-07, 9.99999841e-01],\n",
841-
" [3.27531509e-07, 1.30276882e-06, 2.81841700e-04, 9.99716528e-01],\n",
842-
" [2.85187496e-06, 4.27417539e-06, 2.25183478e-06, 9.99990622e-01],\n",
843-
" [2.26845682e-04, 2.14858622e-04, 1.30080904e-05, 9.99545288e-01],\n",
844-
" [1.91071927e-07, 2.18341367e-06, 6.99319392e-08, 9.99997556e-01],\n",
845-
" [1.52912405e-16, 7.51115551e-16, 5.97233580e-10, 9.99999999e-01],\n",
846-
" [4.28842325e-07, 3.58739061e-06, 1.15789937e-04, 9.99880194e-01],\n",
847-
" [1.24358496e-12, 1.61846276e-14, 4.09872243e-06, 9.99995901e-01],\n",
848-
" [1.79063038e-05, 2.35134166e-03, 1.45770845e-04, 9.97484981e-01],\n",
849-
" [4.94813568e-06, 1.94485993e-05, 5.28329480e-04, 9.99447274e-01],\n",
850-
" [4.32998150e-06, 2.40290270e-07, 9.96027246e-04, 9.98999402e-01],\n",
851-
" [6.32718815e-04, 5.04958392e-04, 6.85676138e-04, 9.98176647e-01]])"
852-
]
853-
},
854-
"execution_count": 52,
855-
"metadata": {},
856-
"output_type": "execute_result"
857-
}
858-
],
859-
"source": [
860-
"clf.predict_proba(X)"
861-
]
798+
"source": []
862799
},
863800
{
864801
"cell_type": "code",
@@ -882,7 +819,7 @@
882819
},
883820
{
884821
"cell_type": "code",
885-
"execution_count": 54,
822+
"execution_count": 82,
886823
"metadata": {},
887824
"outputs": [
888825
{
@@ -974,7 +911,7 @@
974911
" -1.38081190e-02, -9.82462294e-03, 1.39165390e-06]])"
975912
]
976913
},
977-
"execution_count": 54,
914+
"execution_count": 82,
978915
"metadata": {},
979916
"output_type": "execute_result"
980917
}
@@ -983,6 +920,205 @@
983920
"clf.coef_"
984921
]
985922
},
923+
{
924+
"cell_type": "markdown",
925+
"metadata": {},
926+
"source": [
927+
"# K-fold cross-validation"
928+
]
929+
},
930+
{
931+
"cell_type": "code",
932+
"execution_count": 76,
933+
"metadata": {},
934+
"outputs": [
935+
{
936+
"data": {
937+
"text/html": [
938+
"<div>\n",
939+
"<style scoped>\n",
940+
" .dataframe tbody tr th:only-of-type {\n",
941+
" vertical-align: middle;\n",
942+
" }\n",
943+
"\n",
944+
" .dataframe tbody tr th {\n",
945+
" vertical-align: top;\n",
946+
" }\n",
947+
"\n",
948+
" .dataframe thead th {\n",
949+
" text-align: right;\n",
950+
" }\n",
951+
"</style>\n",
952+
"<table border=\"1\" class=\"dataframe\">\n",
953+
" <thead>\n",
954+
" <tr style=\"text-align: right;\">\n",
955+
" <th></th>\n",
956+
" <th>macro</th>\n",
957+
" <th>micro</th>\n",
958+
" <th>weighted</th>\n",
959+
" </tr>\n",
960+
" </thead>\n",
961+
" <tbody>\n",
962+
" <tr>\n",
963+
" <th>0</th>\n",
964+
" <td>1.00000</td>\n",
965+
" <td>1.000000</td>\n",
966+
" <td>1.000000</td>\n",
967+
" </tr>\n",
968+
" <tr>\n",
969+
" <th>1</th>\n",
970+
" <td>1.00000</td>\n",
971+
" <td>1.000000</td>\n",
972+
" <td>1.000000</td>\n",
973+
" </tr>\n",
974+
" <tr>\n",
975+
" <th>2</th>\n",
976+
" <td>0.82500</td>\n",
977+
" <td>0.805556</td>\n",
978+
" <td>0.800000</td>\n",
979+
" </tr>\n",
980+
" <tr>\n",
981+
" <th>3</th>\n",
982+
" <td>0.85625</td>\n",
983+
" <td>0.833333</td>\n",
984+
" <td>0.841667</td>\n",
985+
" </tr>\n",
986+
" <tr>\n",
987+
" <th>4</th>\n",
988+
" <td>1.00000</td>\n",
989+
" <td>1.000000</td>\n",
990+
" <td>1.000000</td>\n",
991+
" </tr>\n",
992+
" <tr>\n",
993+
" <th>5</th>\n",
994+
" <td>1.00000</td>\n",
995+
" <td>1.000000</td>\n",
996+
" <td>1.000000</td>\n",
997+
" </tr>\n",
998+
" <tr>\n",
999+
" <th>6</th>\n",
1000+
" <td>1.00000</td>\n",
1001+
" <td>0.990741</td>\n",
1002+
" <td>1.000000</td>\n",
1003+
" </tr>\n",
1004+
" <tr>\n",
1005+
" <th>7</th>\n",
1006+
" <td>1.00000</td>\n",
1007+
" <td>1.000000</td>\n",
1008+
" <td>1.000000</td>\n",
1009+
" </tr>\n",
1010+
" <tr>\n",
1011+
" <th>8</th>\n",
1012+
" <td>0.96875</td>\n",
1013+
" <td>0.962963</td>\n",
1014+
" <td>0.958333</td>\n",
1015+
" </tr>\n",
1016+
" <tr>\n",
1017+
" <th>9</th>\n",
1018+
" <td>1.00000</td>\n",
1019+
" <td>1.000000</td>\n",
1020+
" <td>1.000000</td>\n",
1021+
" </tr>\n",
1022+
" </tbody>\n",
1023+
"</table>\n",
1024+
"</div>"
1025+
],
1026+
"text/plain": [
1027+
" macro micro weighted\n",
1028+
"0 1.00000 1.000000 1.000000\n",
1029+
"1 1.00000 1.000000 1.000000\n",
1030+
"2 0.82500 0.805556 0.800000\n",
1031+
"3 0.85625 0.833333 0.841667\n",
1032+
"4 1.00000 1.000000 1.000000\n",
1033+
"5 1.00000 1.000000 1.000000\n",
1034+
"6 1.00000 0.990741 1.000000\n",
1035+
"7 1.00000 1.000000 1.000000\n",
1036+
"8 0.96875 0.962963 0.958333\n",
1037+
"9 1.00000 1.000000 1.000000"
1038+
]
1039+
},
1040+
"execution_count": 76,
1041+
"metadata": {},
1042+
"output_type": "execute_result"
1043+
}
1044+
],
1045+
"source": [
1046+
"from sklearn.model_selection import StratifiedKFold\n",
1047+
"from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score\n",
1048+
"from sklearn.preprocessing import label_binarize\n",
1049+
"\n",
1050+
"X = df[[c for c in df.columns if c != 'y']]\n",
1051+
"y = df.y\n",
1052+
"\n",
1053+
"skf = StratifiedKFold(n_splits=10)\n",
1054+
"\n",
1055+
"results = []\n",
1056+
"for train_index, test_index in skf.split(X, y):\n",
1057+
" X_train, y_train = X.iloc[train_index], y[train_index]\n",
1058+
" X_test, y_test = X.iloc[test_index], y[test_index]\n",
1059+
" \n",
1060+
" y_test = label_binarize(y_test, classes=[0.0, 1.0, 2.0, 3.0])\n",
1061+
" \n",
1062+
" clf = LogisticRegression(random_state=37, max_iter=1000).fit(X_train, y_train)\n",
1063+
" y_prob = clf.predict_proba(X_test)\n",
1064+
" \n",
1065+
" roc_macro = roc_auc_score(y_test, y_prob, average='macro')\n",
1066+
" roc_micro = roc_auc_score(y_test, y_prob, average='micro')\n",
1067+
" roc_weighted = roc_auc_score(y_test, y_prob, average='weighted')\n",
1068+
" results.append({\n",
1069+
" 'macro': roc_macro,\n",
1070+
" 'micro': roc_micro,\n",
1071+
" 'weighted': roc_weighted\n",
1072+
" })\n",
1073+
" \n",
1074+
"results = pd.DataFrame(results)\n",
1075+
"results"
1076+
]
1077+
},
1078+
{
1079+
"cell_type": "code",
1080+
"execution_count": 77,
1081+
"metadata": {},
1082+
"outputs": [
1083+
{
1084+
"data": {
1085+
"text/plain": [
1086+
"macro 0.965000\n",
1087+
"micro 0.959259\n",
1088+
"weighted 0.960000\n",
1089+
"dtype: float64"
1090+
]
1091+
},
1092+
"execution_count": 77,
1093+
"metadata": {},
1094+
"output_type": "execute_result"
1095+
}
1096+
],
1097+
"source": [
1098+
"results.mean()"
1099+
]
1100+
},
1101+
{
1102+
"cell_type": "markdown",
1103+
"metadata": {},
1104+
"source": [
1105+
"# Learn model on full data"
1106+
]
1107+
},
1108+
{
1109+
"cell_type": "code",
1110+
"execution_count": 78,
1111+
"metadata": {},
1112+
"outputs": [],
1113+
"source": [
1114+
"from sklearn.linear_model import LogisticRegression\n",
1115+
"\n",
1116+
"X = df[[c for c in df.columns if c != 'y']]\n",
1117+
"y = df.y\n",
1118+
"\n",
1119+
"clf = LogisticRegression(random_state=37, max_iter=1000).fit(X, y)"
1120+
]
1121+
},
9861122
{
9871123
"cell_type": "code",
9881124
"execution_count": null,

0 commit comments

Comments
 (0)