|
| 1 | +import streamlit as st |
| 2 | +import pandas as pd |
| 3 | +import numpy as np |
| 4 | +from sklearn.svm import SVC |
| 5 | +from sklearn.linear_model import LogisticRegression |
| 6 | +from sklearn.ensemble import RandomForestClassifier |
| 7 | +from sklearn.preprocessing import LabelEncoder |
| 8 | +from sklearn.model_selection import train_test_split |
| 9 | +from sklearn.metrics import plot_confusion_matrix, plot_roc_curve, plot_precision_recall_curve |
| 10 | +from sklearn.metrics import precision_score, recall_score |
| 11 | + |
| 12 | +def main(): |
| 13 | + st.title("Binary Classification Web App") |
| 14 | + st.sidebar.title("Binary Classification Web App") |
| 15 | + st.markdown("Are your mushrooms edible or poisonous? 🍄") |
| 16 | + st.sidebar.markdown("Are your mushrooms edible or poisonous? 🍄") |
| 17 | + |
| 18 | + @st.cache(persist=True) |
| 19 | + def load_data(): |
| 20 | + data = pd.read_csv("C:\Users\SANJAY N T\Desktop\project\streamlit-ml\mushrooms.csv") |
| 21 | + labelencoder=LabelEncoder() |
| 22 | + for col in data.columns: |
| 23 | + data[col] = labelencoder.fit_transform(data[col]) |
| 24 | + return data |
| 25 | + |
| 26 | + @st.cache(persist=True) |
| 27 | + def split(df): |
| 28 | + y = df.type |
| 29 | + x = df.drop(columns=['type']) |
| 30 | + x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=0) |
| 31 | + return x_train, x_test, y_train, y_test |
| 32 | + |
| 33 | + def plot_metrics(metrics_list): |
| 34 | + if 'Confusion Matrix' in metrics_list: |
| 35 | + st.subheader("Confusion Matrix") |
| 36 | + plot_confusion_matrix(model, x_test, y_test, display_labels=class_names) |
| 37 | + st.pyplot() |
| 38 | + |
| 39 | + if 'ROC Curve' in metrics_list: |
| 40 | + st.subheader("ROC Curve") |
| 41 | + plot_roc_curve(model, x_test, y_test) |
| 42 | + st.pyplot() |
| 43 | + |
| 44 | + if 'Precision-Recall Curve' in metrics_list: |
| 45 | + st.subheader('Precision-Recall Curve') |
| 46 | + plot_precision_recall_curve(model, x_test, y_test) |
| 47 | + st.pyplot() |
| 48 | + |
| 49 | + df = load_data() |
| 50 | + class_names = ['edible', 'poisonous'] |
| 51 | + |
| 52 | + x_train, x_test, y_train, y_test = split(df) |
| 53 | + |
| 54 | + st.sidebar.subheader("Choose Classifier") |
| 55 | + classifier = st.sidebar.selectbox("Classifier", ("Support Vector Machine (SVM)", "Logistic Regression", "Random Forest")) |
| 56 | + |
| 57 | + if classifier == 'Support Vector Machine (SVM)': |
| 58 | + st.sidebar.subheader("Model Hyperparameters") |
| 59 | + #choose parameters |
| 60 | + C = st.sidebar.number_input("C (Regularization parameter)", 0.01, 10.0, step=0.01, key='C_SVM') |
| 61 | + kernel = st.sidebar.radio("Kernel", ("rbf", "linear"), key='kernel') |
| 62 | + gamma = st.sidebar.radio("Gamma (Kernel Coefficient)", ("scale", "auto"), key='gamma') |
| 63 | + |
| 64 | + metrics = st.sidebar.multiselect("What metrics to plot?", ('Confusion Matrix', 'ROC Curve', 'Precision-Recall Curve')) |
| 65 | + |
| 66 | + if st.sidebar.button("Classify", key='classify'): |
| 67 | + st.subheader("Support Vector Machine (SVM) Results") |
| 68 | + model = SVC(C=C, kernel=kernel, gamma=gamma) |
| 69 | + model.fit(x_train, y_train) |
| 70 | + accuracy = model.score(x_test, y_test) |
| 71 | + y_pred = model.predict(x_test) |
| 72 | + st.write("Accuracy: ", accuracy.round(2)) |
| 73 | + st.write("Precision: ", precision_score(y_test, y_pred, labels=class_names).round(2)) |
| 74 | + st.write("Recall: ", recall_score(y_test, y_pred, labels=class_names).round(2)) |
| 75 | + plot_metrics(metrics) |
| 76 | + |
| 77 | + if classifier == 'Logistic Regression': |
| 78 | + st.sidebar.subheader("Model Hyperparameters") |
| 79 | + C = st.sidebar.number_input("C (Regularization parameter)", 0.01, 10.0, step=0.01, key='C_LR') |
| 80 | + max_iter = st.sidebar.slider("Maximum number of iterations", 100, 500, key='max_iter') |
| 81 | + |
| 82 | + metrics = st.sidebar.multiselect("What metrics to plot?", ('Confusion Matrix', 'ROC Curve', 'Precision-Recall Curve')) |
| 83 | + |
| 84 | + if st.sidebar.button("Classify", key='classify'): |
| 85 | + st.subheader("Logistic Regression Results") |
| 86 | + model = LogisticRegression(C=C, penalty='l2', max_iter=max_iter) |
| 87 | + model.fit(x_train, y_train) |
| 88 | + accuracy = model.score(x_test, y_test) |
| 89 | + y_pred = model.predict(x_test) |
| 90 | + st.write("Accuracy: ", accuracy.round(2)) |
| 91 | + st.write("Precision: ", precision_score(y_test, y_pred, labels=class_names).round(2)) |
| 92 | + st.write("Recall: ", recall_score(y_test, y_pred, labels=class_names).round(2)) |
| 93 | + plot_metrics(metrics) |
| 94 | + |
| 95 | + if classifier == 'Random Forest': |
| 96 | + st.sidebar.subheader("Model Hyperparameters") |
| 97 | + n_estimators = st.sidebar.number_input("The number of trees in the forest", 100, 5000, step=10, key='n_estimators') |
| 98 | + max_depth = st.sidebar.number_input("The maximum depth of the tree", 1, 20, step=1, key='n_estimators') |
| 99 | + bootstrap = st.sidebar.radio("Bootstrap samples when building trees", ('True', 'False'), key='bootstrap') |
| 100 | + metrics = st.sidebar.multiselect("What metrics to plot?", ('Confusion Matrix', 'ROC Curve', 'Precision-Recall Curve')) |
| 101 | + |
| 102 | + if st.sidebar.button("Classify", key='classify'): |
| 103 | + st.subheader("Random Forest Results") |
| 104 | + model = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth, bootstrap=bootstrap, n_jobs=-1) |
| 105 | + model.fit(x_train, y_train) |
| 106 | + accuracy = model.score(x_test, y_test) |
| 107 | + y_pred = model.predict(x_test) |
| 108 | + st.write("Accuracy: ", accuracy.round(2)) |
| 109 | + st.write("Precision: ", precision_score(y_test, y_pred, labels=class_names).round(2)) |
| 110 | + st.write("Recall: ", recall_score(y_test, y_pred, labels=class_names).round(2)) |
| 111 | + plot_metrics(metrics) |
| 112 | + |
| 113 | + if st.sidebar.checkbox("Show raw data", False): |
| 114 | + st.subheader("Mushroom Data Set (Classification)") |
| 115 | + st.write(df) |
| 116 | + st.markdown("This [data set](https://archive.ics.uci.edu/ml/datasets/Mushroom) includes descriptions of hypothetical samples corresponding to 23 species of gilled mushrooms " |
| 117 | + "in the Agaricus and Lepiota Family (pp. 500-525). Each species is identified as definitely edible, definitely poisonous, " |
| 118 | + "or of unknown edibility and not recommended. This latter class was combined with the poisonous one.") |
| 119 | + |
| 120 | +if __name__ == '__main__': |
| 121 | + main() |
0 commit comments