Skip to content

Commit 79a0be2

Browse files
committed
add return_train_score
1 parent ed75dec commit 79a0be2

3 files changed

+7
-4
lines changed

01-introduction.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@
304304
],
305305
"source": [
306306
"import pandas as pd\n",
307-
"from IPython.display import display\n",
307+
"from IPython import display\n",
308308
"\n",
309309
"# create a simple dataset of people\n",
310310
"data = {'Name': [\"John\", \"Anna\", \"Peter\", \"Linda\"],\n",

05-model-evaluation-and-improvement.ipynb

+4-2
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,8 @@
677677
"source": [
678678
"from sklearn.model_selection import GridSearchCV\n",
679679
"from sklearn.svm import SVC\n",
680-
"grid_search = GridSearchCV(SVC(), param_grid, cv=5)"
680+
"grid_search = GridSearchCV(SVC(), param_grid, cv=5,\n",
681+
" return_train_score=True)"
681682
]
682683
},
683684
{
@@ -1045,7 +1046,8 @@
10451046
}
10461047
],
10471048
"source": [
1048-
"grid_search = GridSearchCV(SVC(), param_grid, cv=5)\n",
1049+
"grid_search = GridSearchCV(SVC(), param_grid, cv=5,\n",
1050+
" return_train_score=True)\n",
10491051
"grid_search.fit(X_train, y_train)\n",
10501052
"print(\"Best parameters: {}\".format(grid_search.best_params_))\n",
10511053
"print(\"Best cross-validation score: {:.2f}\".format(grid_search.best_score_))"

mglearn/plot_grid_search.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ def plot_cross_val_selection():
1414

1515
param_grid = {'C': [0.001, 0.01, 0.1, 1, 10, 100],
1616
'gamma': [0.001, 0.01, 0.1, 1, 10, 100]}
17-
grid_search = GridSearchCV(SVC(), param_grid, cv=5)
17+
grid_search = GridSearchCV(SVC(), param_grid, cv=5,
18+
return_train_score=True)
1819
grid_search.fit(X_trainval, y_trainval)
1920
results = pd.DataFrame(grid_search.cv_results_)[15:]
2021

0 commit comments

Comments
 (0)