|
674 | 674 | },
|
675 | 675 | {
|
676 | 676 | "cell_type": "code",
|
677 |
| - "execution_count": 124, |
| 677 | + "execution_count": 50, |
678 | 678 | "metadata": {},
|
679 | 679 | "outputs": [],
|
680 | 680 | "source": [
|
|
683 | 683 | },
|
684 | 684 | {
|
685 | 685 | "cell_type": "code",
|
686 |
| - "execution_count": 145, |
| 686 | + "execution_count": 51, |
687 | 687 | "metadata": {},
|
688 | 688 | "outputs": [],
|
689 | 689 | "source": [
|
690 | 690 | "params = {\n",
|
691 | 691 | " 'module__dropout': [0, 0.5, 0.8],\n",
|
692 |
| - " 'max_epochs': [1],\n", |
| 692 | + " 'max_epochs': [2],\n", |
693 | 693 | " 'verbose': [False],\n",
|
| 694 | + " 'train_split': [False],\n", |
694 | 695 | "}"
|
695 | 696 | ]
|
696 | 697 | },
|
| 698 | + { |
| 699 | + "cell_type": "markdown", |
| 700 | + "metadata": {}, |
| 701 | + "source": [ |
| 702 | + "The parameter we are interested in here is the dropout rate. We want to see which of the values (no dropout, 50%, 80%) is the best choice for our network.\n", |
| 703 | + "\n", |
| 704 | + "Additionally:\n", |
| 705 | + "\n", |
| 706 | + "- We use only two epochs (`max_epochs: [2]`) for each `.fit` (only to reduce execution time, normally we wouldn't change this and possibly add an `EarlyStopping` callback).\n", |
| 707 | + "- Disable the network output (`verbose: [False]`)\n", |
| 708 | + "- Disable the internal train/validation split (`train_split: [False]`) since the grid search will do k-fold validation anyway" |
| 709 | + ] |
| 710 | + }, |
697 | 711 | {
|
698 | 712 | "cell_type": "code",
|
699 |
| - "execution_count": 151, |
| 713 | + "execution_count": 52, |
700 | 714 | "metadata": {},
|
701 | 715 | "outputs": [
|
702 | 716 | {
|
|
714 | 728 | },
|
715 | 729 | {
|
716 | 730 | "cell_type": "code",
|
717 |
| - "execution_count": 147, |
| 731 | + "execution_count": 53, |
718 | 732 | "metadata": {},
|
719 | 733 | "outputs": [],
|
720 | 734 | "source": [
|
|
723 | 737 | },
|
724 | 738 | {
|
725 | 739 | "cell_type": "code",
|
726 |
| - "execution_count": 148, |
| 740 | + "execution_count": 54, |
727 | 741 | "metadata": {},
|
728 | 742 | "outputs": [],
|
729 | 743 | "source": [
|
|
732 | 746 | },
|
733 | 747 | {
|
734 | 748 | "cell_type": "code",
|
735 |
| - "execution_count": 149, |
| 749 | + "execution_count": 55, |
736 | 750 | "metadata": {
|
737 | 751 | "scrolled": false
|
738 | 752 | },
|
|
748 | 762 | "name": "stderr",
|
749 | 763 | "output_type": "stream",
|
750 | 764 | "text": [
|
751 |
| - "/home/marian/anaconda3/envs/skorch/lib/python3.6/site-packages/sklearn/model_selection/_split.py:1943: FutureWarning: You should specify a value for 'cv' instead of relying on the default value. The default value will change from 3 to 5 in version 0.22.\n", |
752 |
| - " warnings.warn(CV_WARNING, FutureWarning)\n", |
753 | 765 | "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n",
|
754 |
| - "[Parallel(n_jobs=1)]: Done 9 out of 9 | elapsed: 2.1min finished\n" |
| 766 | + "[Parallel(n_jobs=1)]: Done 9 out of 9 | elapsed: 7.0min finished\n" |
755 | 767 | ]
|
756 | 768 | },
|
757 | 769 | {
|
758 | 770 | "data": {
|
759 | 771 | "text/plain": [
|
760 |
| - "GridSearchCV(cv='warn', error_score='raise-deprecating',\n", |
| 772 | + "GridSearchCV(cv=3, error_score='raise-deprecating',\n", |
761 | 773 | " estimator=<class 'skorch.classifier.NeuralNetClassifier'>[initialized](\n",
|
762 | 774 | " module_=Cnn(\n",
|
763 | 775 | " (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n",
|
|
769 | 781 | " ),\n",
|
770 | 782 | "),\n",
|
771 | 783 | " fit_params=None, iid='warn', n_jobs=None,\n",
|
772 |
| - " param_grid={'module__dropout': [0, 0.5, 0.8], 'max_epochs': [1], 'verbose': [False]},\n", |
| 784 | + " param_grid={'module__dropout': [0, 0.5, 0.8], 'max_epochs': [2], 'verbose': [False], 'train_split': [False]},\n", |
773 | 785 | " pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',\n",
|
774 | 786 | " scoring='accuracy', verbose=1)"
|
775 | 787 | ]
|
776 | 788 | },
|
777 |
| - "execution_count": 149, |
| 789 | + "execution_count": 55, |
778 | 790 | "metadata": {},
|
779 | 791 | "output_type": "execute_result"
|
780 | 792 | }
|
|
792 | 804 | },
|
793 | 805 | {
|
794 | 806 | "cell_type": "code",
|
795 |
| - "execution_count": 150, |
| 807 | + "execution_count": 56, |
796 | 808 | "metadata": {},
|
797 | 809 | "outputs": [
|
798 | 810 | {
|
799 | 811 | "data": {
|
800 | 812 | "text/plain": [
|
801 |
| - "{'max_epochs': 1, 'module__dropout': 0, 'verbose': False}" |
| 813 | + "{'max_epochs': 2, 'module__dropout': 0, 'train_split': False, 'verbose': False}" |
802 | 814 | ]
|
803 | 815 | },
|
804 |
| - "execution_count": 150, |
| 816 | + "execution_count": 56, |
805 | 817 | "metadata": {},
|
806 | 818 | "output_type": "execute_result"
|
807 | 819 | }
|
|
0 commit comments