Commit 01b469f3 authored by Remi's avatar Remi

(3) Add N-grams

parent e6ced94e
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ***Setup***"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.naive_bayes import MultinomialNB\n",
"from sklearn.metrics import accuracy_score, classification_report\n",
"\n",
"from tabulate import tabulate\n",
"\n",
"import pandas as pd\n",
"\n",
"from constants import CONSTANTS"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The NLP process we're practicing is using N-grams"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***Some helper functions***"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def print_clf_report_as_table(report):\n",
" data = []\n",
" for key, value in report.items():\n",
" if key != 'accuracy' and key != 'macro avg' and key != 'weighted avg':\n",
" data.append([key, value['precision'], value['recall'], value['f1-score'], value['support']])\n",
"\n",
" data.append(['accuracy', '', '', report['accuracy'], ''])\n",
"\n",
" data.append(['macro avg', report['macro avg']['precision'], report['macro avg']['recall'], report['macro avg']['f1-score'], ''])\n",
"\n",
" data.append(['weighted avg', report['weighted avg']['precision'], report['weighted avg']['recall'], report['weighted avg']['f1-score'], ''])\n",
"\n",
" print(tabulate(data, headers=['Class', 'Precision', 'Recall', 'F1-score', 'Support'], tablefmt='psql'))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def get_prepared_data():\n",
" train_set = pd.read_csv(CONSTANTS.TRAIN_SET_PATH)\n",
" X_train = train_set['text']\n",
" y_train = train_set['label']\n",
" test_set = pd.read_csv(CONSTANTS.TEST_SET_PATH)\n",
" X_test = test_set['text']\n",
" y_test = test_set['label'] \n",
" return X_train, X_test, y_train, y_test "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def evaluate_method(n):\n",
" # Read data\n",
" X_train, X_test, y_train, y_test = get_prepared_data()\n",
" # Apply TF_IDF\n",
" vectorizer = TfidfVectorizer(ngram_range=(n, n))\n",
" X_train = vectorizer.fit_transform(X_train)\n",
" X_test = vectorizer.transform(X_test)\n",
" # Train Naive bayes classifier\n",
" classifier = MultinomialNB()\n",
" classifier.fit(X_train, y_train) \n",
" # Evaluate model\n",
" # training set\n",
" y_train_pred = classifier.predict(X_train)\n",
" train_accuracy = accuracy_score(y_train, y_train_pred)\n",
"\n",
" # test set\n",
" y_pred = classifier.predict(X_test)\n",
" test_accuracy = accuracy_score(y_test, y_pred)\n",
"\n",
" print(f'Train Accuracy: {train_accuracy}')\n",
" print(f'Test Accuracy: {test_accuracy}')\n",
" # Print classification report\n",
" report = classification_report(y_test, y_pred, output_dict=True)\n",
" print_clf_report_as_table(report) \n",
" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ***1- 1Gram***"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Accuracy: 0.9971056439942113\n",
"Test Accuracy: 0.8982683982683982\n",
"+---------------------------------+--------------------+--------------------+------------+-----------+\n",
"| Class | Precision | Recall | F1-score | Support |\n",
"|---------------------------------+--------------------+--------------------+------------+-----------|\n",
"| acne | 1.0 | 1.0 | 1 | 19.0 |\n",
"| allergy | 0.8888888888888888 | 0.8421052631578947 | 0.864865 | 19.0 |\n",
"| arthritis | 0.875 | 1.0 | 0.933333 | 14.0 |\n",
"| bronchial asthma | 0.68 | 1.0 | 0.809524 | 17.0 |\n",
"| cervical spondylosis | 1.0 | 1.0 | 1 | 21.0 |\n",
"| chicken pox | 0.8333333333333334 | 0.7894736842105263 | 0.810811 | 19.0 |\n",
"| common cold | 0.7777777777777778 | 0.7777777777777778 | 0.777778 | 18.0 |\n",
"| dengue | 0.5172413793103449 | 0.9375 | 0.666667 | 16.0 |\n",
"| diabetes | 1.0 | 0.631578947368421 | 0.774194 | 19.0 |\n",
"| dimorphic hemorrhoids | 1.0 | 1.0 | 1 | 17.0 |\n",
"| drug reaction | 0.875 | 0.875 | 0.875 | 16.0 |\n",
"| fungal infection | 0.782608695652174 | 1.0 | 0.878049 | 18.0 |\n",
"| gastroesophageal reflux disease | 0.8461538461538461 | 1.0 | 0.916667 | 22.0 |\n",
"| hypertension | 1.0 | 1.0 | 1 | 18.0 |\n",
"| impetigo | 1.0 | 0.84 | 0.913043 | 25.0 |\n",
"| jaundice | 1.0 | 1.0 | 1 | 17.0 |\n",
"| malaria | 1.0 | 1.0 | 1 | 23.0 |\n",
"| migraine | 1.0 | 0.8947368421052632 | 0.944444 | 19.0 |\n",
"| peptic ulcer disease | 1.0 | 0.7727272727272727 | 0.871795 | 22.0 |\n",
"| pneumonia | 1.0 | 0.8333333333333334 | 0.909091 | 24.0 |\n",
"| psoriasis | 1.0 | 0.7727272727272727 | 0.871795 | 22.0 |\n",
"| typhoid | 1.0 | 0.75 | 0.857143 | 24.0 |\n",
"| urinary tract infection | 0.8888888888888888 | 1.0 | 0.941176 | 16.0 |\n",
"| varicose veins | 0.9444444444444444 | 1.0 | 0.971429 | 17.0 |\n",
"| accuracy | | | 0.898268 | |\n",
"| macro avg | 0.9128890522687373 | 0.9048733497253233 | 0.89945 | |\n",
"| weighted avg | 0.9216194874590677 | 0.8982683982683982 | 0.900663 | |\n",
"+---------------------------------+--------------------+--------------------+------------+-----------+\n"
]
}
],
"source": [
"evaluate_method(1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ***2- 2Gram***"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Accuracy: 0.9985528219971056\n",
"Test Accuracy: 0.8852813852813853\n",
"+---------------------------------+--------------------+--------------------+------------+-----------+\n",
"| Class | Precision | Recall | F1-score | Support |\n",
"|---------------------------------+--------------------+--------------------+------------+-----------|\n",
"| acne | 0.9473684210526315 | 0.9473684210526315 | 0.947368 | 19.0 |\n",
"| allergy | 0.7619047619047619 | 0.8421052631578947 | 0.8 | 19.0 |\n",
"| arthritis | 0.8235294117647058 | 1.0 | 0.903226 | 14.0 |\n",
"| bronchial asthma | 0.7083333333333334 | 1.0 | 0.829268 | 17.0 |\n",
"| cervical spondylosis | 0.9545454545454546 | 1.0 | 0.976744 | 21.0 |\n",
"| chicken pox | 0.7391304347826086 | 0.8947368421052632 | 0.809524 | 19.0 |\n",
"| common cold | 0.7619047619047619 | 0.8888888888888888 | 0.820513 | 18.0 |\n",
"| dengue | 0.5652173913043478 | 0.8125 | 0.666667 | 16.0 |\n",
"| diabetes | 1.0 | 0.631578947368421 | 0.774194 | 19.0 |\n",
"| dimorphic hemorrhoids | 0.9444444444444444 | 1.0 | 0.971429 | 17.0 |\n",
"| drug reaction | 0.8 | 0.75 | 0.774194 | 16.0 |\n",
"| fungal infection | 0.8181818181818182 | 1.0 | 0.9 | 18.0 |\n",
"| gastroesophageal reflux disease | 0.95 | 0.8636363636363636 | 0.904762 | 22.0 |\n",
"| hypertension | 0.8571428571428571 | 1.0 | 0.923077 | 18.0 |\n",
"| impetigo | 1.0 | 0.92 | 0.958333 | 25.0 |\n",
"| jaundice | 1.0 | 1.0 | 1 | 17.0 |\n",
"| malaria | 1.0 | 1.0 | 1 | 23.0 |\n",
"| migraine | 1.0 | 0.8947368421052632 | 0.944444 | 19.0 |\n",
"| peptic ulcer disease | 1.0 | 0.7272727272727273 | 0.842105 | 22.0 |\n",
"| pneumonia | 1.0 | 0.875 | 0.933333 | 24.0 |\n",
"| psoriasis | 1.0 | 0.7727272727272727 | 0.871795 | 22.0 |\n",
"| typhoid | 1.0 | 0.5833333333333334 | 0.736842 | 24.0 |\n",
"| urinary tract infection | 0.8421052631578947 | 1.0 | 0.914286 | 16.0 |\n",
"| varicose veins | 1.0 | 1.0 | 1 | 17.0 |\n",
"| accuracy | | | 0.885281 | |\n",
"| macro avg | 0.8947420147299842 | 0.8918285375686693 | 0.883421 | |\n",
"| weighted avg | 0.9053659082488493 | 0.8852813852813853 | 0.885116 | |\n",
"+---------------------------------+--------------------+--------------------+------------+-----------+\n"
]
}
],
"source": [
"evaluate_method(2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ***3- 3Gram***"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Accuracy: 0.9971056439942113\n",
"Test Accuracy: 0.8679653679653679\n",
"+---------------------------------+--------------------+--------------------+------------+-----------+\n",
"| Class | Precision | Recall | F1-score | Support |\n",
"|---------------------------------+--------------------+--------------------+------------+-----------|\n",
"| acne | 0.9473684210526315 | 0.9473684210526315 | 0.947368 | 19.0 |\n",
"| allergy | 0.7619047619047619 | 0.8421052631578947 | 0.8 | 19.0 |\n",
"| arthritis | 0.875 | 1.0 | 0.933333 | 14.0 |\n",
"| bronchial asthma | 0.7083333333333334 | 1.0 | 0.829268 | 17.0 |\n",
"| cervical spondylosis | 1.0 | 1.0 | 1 | 21.0 |\n",
"| chicken pox | 0.7142857142857143 | 0.7894736842105263 | 0.75 | 19.0 |\n",
"| common cold | 0.8 | 0.8888888888888888 | 0.842105 | 18.0 |\n",
"| dengue | 0.4642857142857143 | 0.8125 | 0.590909 | 16.0 |\n",
"| diabetes | 1.0 | 0.6842105263157895 | 0.8125 | 19.0 |\n",
"| dimorphic hemorrhoids | 0.9444444444444444 | 1.0 | 0.971429 | 17.0 |\n",
"| drug reaction | 0.8333333333333334 | 0.625 | 0.714286 | 16.0 |\n",
"| fungal infection | 0.8571428571428571 | 1.0 | 0.923077 | 18.0 |\n",
"| gastroesophageal reflux disease | 1.0 | 0.8181818181818182 | 0.9 | 22.0 |\n",
"| hypertension | 0.8181818181818182 | 1.0 | 0.9 | 18.0 |\n",
"| impetigo | 1.0 | 0.92 | 0.958333 | 25.0 |\n",
"| jaundice | 0.9444444444444444 | 1.0 | 0.971429 | 17.0 |\n",
"| malaria | 1.0 | 0.9130434782608695 | 0.954545 | 23.0 |\n",
"| migraine | 0.85 | 0.8947368421052632 | 0.871795 | 19.0 |\n",
"| peptic ulcer disease | 1.0 | 0.7272727272727273 | 0.842105 | 22.0 |\n",
"| pneumonia | 1.0 | 0.8333333333333334 | 0.909091 | 24.0 |\n",
"| psoriasis | 1.0 | 0.8181818181818182 | 0.9 | 22.0 |\n",
"| typhoid | 1.0 | 0.5416666666666666 | 0.702703 | 24.0 |\n",
"| urinary tract infection | 0.75 | 0.9375 | 0.833333 | 16.0 |\n",
"| varicose veins | 0.8947368421052632 | 1.0 | 0.944444 | 17.0 |\n",
"| accuracy | | | 0.867965 | |\n",
"| macro avg | 0.8818109035214299 | 0.8747276444845095 | 0.866752 | |\n",
"| weighted avg | 0.8942179048894702 | 0.8679653679653679 | 0.86929 | |\n",
"+---------------------------------+--------------------+--------------------+------------+-----------+\n"
]
}
],
"source": [
"evaluate_method(3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment