Commit b8e522e7 authored by Almouhannad Hafez's avatar Almouhannad Hafez

(3) Add POS Tagging filter

parent 7f83dc75
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ***Setup***"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from nltk.stem import WordNetLemmatizer\n",
"from nltk.corpus import wordnet\n",
"from nltk.tokenize import word_tokenize\n",
"from nltk import pos_tag\n",
"from nltk.corpus import stopwords\n",
"\n",
"\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.naive_bayes import MultinomialNB\n",
"from sklearn.metrics import accuracy_score, classification_report\n",
"\n",
"from tabulate import tabulate\n",
"\n",
"import pandas as pd\n",
"\n",
"from constants import CONSTANTS"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***Some helper functions***"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Lemmatizer instance shared by the lemmatization helper below\n",
"lemmatizer = WordNetLemmatizer()\n",
"\n",
"def get_wordnet_pos(tag):\n",
"    \"\"\"Translate a POS tag into the matching WordNet POS constant.\n",
"\n",
"    Only the first character of `tag` is examined:\n",
"    'J' -> wordnet.ADJ, 'V' -> wordnet.VERB, 'N' -> wordnet.NOUN, 'R' -> wordnet.ADV.\n",
"    Any other tag (including an empty one) gives None.\n",
"    \"\"\"\n",
"    prefix_to_pos = {\n",
"        'J': wordnet.ADJ,\n",
"        'V': wordnet.VERB,\n",
"        'N': wordnet.NOUN,\n",
"        'R': wordnet.ADV,\n",
"    }\n",
"    # dict.get falls back to None, mirroring the original final else branch\n",
"    return prefix_to_pos.get(tag[:1])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def lemmatize_text_with_filter(text, filter = None):\n",
"    \"\"\"Tokenize, POS-tag and lemmatize `text`, optionally keeping one POS only.\n",
"\n",
"    Parameters:\n",
"        text: the string to process.\n",
"        filter: a WordNet POS constant (e.g. wordnet.VERB). When given, only tokens\n",
"            whose tag maps to exactly this POS are kept. When falsy, every token is kept.\n",
"\n",
"    Returns:\n",
"        The kept lemmas joined by single spaces.\n",
"    \"\"\"\n",
"    lemmatized_tokens = []\n",
"    for word, tag in pos_tag(word_tokenize(text)):\n",
"        tagged_pos = get_wordnet_pos(tag)  # None when the tag is not J/V/N/R\n",
"        # Filter on the real tag. Previously the NOUN fallback was applied first,\n",
"        # so a NOUN filter also kept every token with an unmapped tag.\n",
"        if filter and tagged_pos != filter:\n",
"            continue\n",
"        # The NOUN fallback is only needed to give the lemmatizer a valid POS\n",
"        lemma = lemmatizer.lemmatize(word, pos=tagged_pos or wordnet.NOUN)\n",
"        lemmatized_tokens.append(lemma)\n",
"    return \" \".join(lemmatized_tokens)\n",
"\n",
"def create_lemmatizer_with_filter(filter):\n",
"    \"\"\"Return a one-argument function that lemmatizes a text with `filter` applied.\"\"\"\n",
"    def lemmatizer_with_filter(text):\n",
"        return lemmatize_text_with_filter(text, filter)\n",
"    return lemmatizer_with_filter"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"stop_words = set(stopwords.words('english'))\n",
"\n",
"def remove_stopwords(text):\n",
"    \"\"\"Return `text` with English stop words removed.\n",
"\n",
"    The text is tokenized, tokens found in `stop_words` are dropped, and the\n",
"    remaining tokens are joined by single spaces (original casing is kept).\n",
"    \"\"\"\n",
"    tokens = word_tokenize(text)\n",
"    # Compare in lower case so capitalised stop words such as 'The' are removed too;\n",
"    # a case-sensitive test would let them through to the vectorizer.\n",
"    filtered_tokens = [word for word in tokens if word.lower() not in stop_words]\n",
"    return ' '.join(filtered_tokens)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def print_clf_report_as_table(report):\n",
"    \"\"\"Print a classification-report dictionary as a psql-style table.\n",
"\n",
"    `report` maps each class label to a dict with 'precision', 'recall',\n",
"    'f1-score' and 'support', and also holds the 'accuracy', 'macro avg' and\n",
"    'weighted avg' entries. Class rows come first, followed by the accuracy row\n",
"    and the two average rows (whose Support column is left blank).\n",
"    \"\"\"\n",
"    summary_keys = ('accuracy', 'macro avg', 'weighted avg')\n",
"\n",
"    # One row per class, in the order the report lists them\n",
"    rows = [\n",
"        [label, scores['precision'], scores['recall'], scores['f1-score'], scores['support']]\n",
"        for label, scores in report.items()\n",
"        if label not in summary_keys\n",
"    ]\n",
"\n",
"    rows.append(['accuracy', '', '', report['accuracy'], ''])\n",
"\n",
"    for avg_name in ('macro avg', 'weighted avg'):\n",
"        avg_scores = report[avg_name]\n",
"        rows.append([avg_name, avg_scores['precision'], avg_scores['recall'], avg_scores['f1-score'], ''])\n",
"\n",
"    column_names = ['Class', 'Precision', 'Recall', 'F1-score', 'Support']\n",
"    print(tabulate(rows, headers=column_names, tablefmt='psql'))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def get_prepared_data():\n",
"    \"\"\"Load the train and test CSV files and strip stop words from their texts.\n",
"\n",
"    Reads CONSTANTS.TRAIN_SET_PATH and CONSTANTS.TEST_SET_PATH, takes the 'text'\n",
"    and 'label' columns of each, and applies `remove_stopwords` to both text columns.\n",
"\n",
"    Returns:\n",
"        (X_train, X_test, y_train, y_test)\n",
"    \"\"\"\n",
"    train_df = pd.read_csv(CONSTANTS.TRAIN_SET_PATH)\n",
"    train_texts, train_labels = train_df['text'], train_df['label']\n",
"\n",
"    test_df = pd.read_csv(CONSTANTS.TEST_SET_PATH)\n",
"    test_texts, test_labels = test_df['text'], test_df['label']\n",
"\n",
"    # Stop-word removal is applied only after both files were read successfully\n",
"    train_texts = train_texts.apply(remove_stopwords)\n",
"    test_texts = test_texts.apply(remove_stopwords)\n",
"    return train_texts, test_texts, train_labels, test_labels"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def evaluate_pos_filter(filter):\n",
"    \"\"\"Train and evaluate a TF-IDF + Multinomial Naive Bayes model on POS-filtered text.\n",
"\n",
"    Steps: load the prepared data, lemmatize both splits with the given POS\n",
"    `filter`, vectorize with TF-IDF (fitted on the training split only), fit the\n",
"    classifier, then print train/test accuracy and the test classification report.\n",
"    Nothing is returned.\n",
"    \"\"\"\n",
"    train_texts, test_texts, train_labels, test_labels = get_prepared_data()\n",
"\n",
"    # Lemmatize and keep only tokens matching the requested POS\n",
"    pos_lemmatizer = create_lemmatizer_with_filter(filter)\n",
"    train_texts = train_texts.apply(pos_lemmatizer)\n",
"    test_texts = test_texts.apply(pos_lemmatizer)\n",
"\n",
"    # TF-IDF features: vocabulary comes from the training split\n",
"    tfidf = TfidfVectorizer()\n",
"    train_features = tfidf.fit_transform(train_texts)\n",
"    test_features = tfidf.transform(test_texts)\n",
"\n",
"    # Naive Bayes classifier\n",
"    model = MultinomialNB()\n",
"    model.fit(train_features, train_labels)\n",
"\n",
"    # Accuracy on both splits\n",
"    train_predictions = model.predict(train_features)\n",
"    train_accuracy = accuracy_score(train_labels, train_predictions)\n",
"    test_predictions = model.predict(test_features)\n",
"    test_accuracy = accuracy_score(test_labels, test_predictions)\n",
"\n",
"    print(f'Train Accuracy: {train_accuracy}')\n",
"    print(f'Test Accuracy: {test_accuracy}')\n",
"\n",
"    # Per-class metrics on the test split\n",
"    test_report = classification_report(test_labels, test_predictions, output_dict=True)\n",
"    print_clf_report_as_table(test_report)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ***1- Verbs only***"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Accuracy: 0.7814761215629522\n",
"Test Accuracy: 0.6082251082251082\n",
"+---------------------------------+---------------------+---------------------+------------+-----------+\n",
"| Class | Precision | Recall | F1-score | Support |\n",
"|---------------------------------+---------------------+---------------------+------------+-----------|\n",
"| acne | 0.8125 | 0.6842105263157895 | 0.742857 | 19.0 |\n",
"| allergy | 0.6666666666666666 | 0.5263157894736842 | 0.588235 | 19.0 |\n",
"| arthritis | 0.6666666666666666 | 1.0 | 0.8 | 14.0 |\n",
"| bronchial asthma | 0.4594594594594595 | 1.0 | 0.62963 | 17.0 |\n",
"| cervical spondylosis | 0.6190476190476191 | 0.6190476190476191 | 0.619048 | 21.0 |\n",
"| chicken pox | 0.5333333333333333 | 0.42105263157894735 | 0.470588 | 19.0 |\n",
"| common cold | 0.8235294117647058 | 0.7777777777777778 | 0.8 | 18.0 |\n",
"| dengue | 0.20408163265306123 | 0.625 | 0.307692 | 16.0 |\n",
"| diabetes | 0.6153846153846154 | 0.42105263157894735 | 0.5 | 19.0 |\n",
"| dimorphic hemorrhoids | 0.5833333333333334 | 0.8235294117647058 | 0.682927 | 17.0 |\n",
"| drug reaction | 0.4444444444444444 | 0.5 | 0.470588 | 16.0 |\n",
"| fungal infection | 0.7647058823529411 | 0.7222222222222222 | 0.742857 | 18.0 |\n",
"| gastroesophageal reflux disease | 0.7058823529411765 | 0.5454545454545454 | 0.615385 | 22.0 |\n",
"| hypertension | 0.6428571428571429 | 0.5 | 0.5625 | 18.0 |\n",
"| impetigo | 0.75 | 0.36 | 0.486486 | 25.0 |\n",
"| jaundice | 0.8666666666666667 | 0.7647058823529411 | 0.8125 | 17.0 |\n",
"| malaria | 0.8571428571428571 | 0.782608695652174 | 0.818182 | 23.0 |\n",
"| migraine | 0.7368421052631579 | 0.7368421052631579 | 0.736842 | 19.0 |\n",
"| peptic ulcer disease | 0.6666666666666666 | 0.45454545454545453 | 0.540541 | 22.0 |\n",
"| pneumonia | 0.8888888888888888 | 0.6666666666666666 | 0.761905 | 24.0 |\n",
"| psoriasis | 0.5882352941176471 | 0.45454545454545453 | 0.512821 | 22.0 |\n",
"| typhoid | 0.8333333333333334 | 0.20833333333333334 | 0.333333 | 24.0 |\n",
"| urinary tract infection | 0.43478260869565216 | 0.625 | 0.512821 | 16.0 |\n",
"| varicose veins | 0.5909090909090909 | 0.7647058823529411 | 0.666667 | 17.0 |\n",
"| accuracy | | | 0.608225 | |\n",
"| macro avg | 0.6564733363578803 | 0.624317359580265 | 0.6131 | |\n",
"| weighted avg | 0.6692573228781883 | 0.6082251082251082 | 0.610225 | |\n",
"+---------------------------------+---------------------+---------------------+------------+-----------+\n"
]
}
],
"source": [
"evaluate_pos_filter(wordnet.VERB)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ***2- Adjectives only***"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Accuracy: 0.8683068017366136\n",
"Test Accuracy: 0.6060606060606061\n",
"+---------------------------------+---------------------+---------------------+------------+-----------+\n",
"| Class | Precision | Recall | F1-score | Support |\n",
"|---------------------------------+---------------------+---------------------+------------+-----------|\n",
"| acne | 0.7391304347826086 | 0.8947368421052632 | 0.809524 | 19.0 |\n",
"| allergy | 0.75 | 0.47368421052631576 | 0.580645 | 19.0 |\n",
"| arthritis | 0.45161290322580644 | 1.0 | 0.622222 | 14.0 |\n",
"| bronchial asthma | 0.5357142857142857 | 0.8823529411764706 | 0.666667 | 17.0 |\n",
"| cervical spondylosis | 0.9333333333333333 | 0.6666666666666666 | 0.777778 | 21.0 |\n",
"| chicken pox | 0.35714285714285715 | 0.2631578947368421 | 0.30303 | 19.0 |\n",
"| common cold | 0.46153846153846156 | 0.3333333333333333 | 0.387097 | 18.0 |\n",
"| dengue | 0.15492957746478872 | 0.6875 | 0.252874 | 16.0 |\n",
"| diabetes | 0.4375 | 0.3684210526315789 | 0.4 | 19.0 |\n",
"| dimorphic hemorrhoids | 0.7857142857142857 | 0.6470588235294118 | 0.709677 | 17.0 |\n",
"| drug reaction | 0.8888888888888888 | 0.5 | 0.64 | 16.0 |\n",
"| fungal infection | 0.5925925925925926 | 0.8888888888888888 | 0.711111 | 18.0 |\n",
"| gastroesophageal reflux disease | 0.65 | 0.5909090909090909 | 0.619048 | 22.0 |\n",
"| hypertension | 0.8 | 0.4444444444444444 | 0.571429 | 18.0 |\n",
"| impetigo | 0.7857142857142857 | 0.44 | 0.564103 | 25.0 |\n",
"| jaundice | 0.9444444444444444 | 1.0 | 0.971429 | 17.0 |\n",
"| malaria | 0.9545454545454546 | 0.9130434782608695 | 0.933333 | 23.0 |\n",
"| migraine | 0.7727272727272727 | 0.8947368421052632 | 0.829268 | 19.0 |\n",
"| peptic ulcer disease | 0.8888888888888888 | 0.36363636363636365 | 0.516129 | 22.0 |\n",
"| pneumonia | 1.0 | 0.6666666666666666 | 0.8 | 24.0 |\n",
"| psoriasis | 0.5555555555555556 | 0.22727272727272727 | 0.322581 | 22.0 |\n",
"| typhoid | 0.6153846153846154 | 0.3333333333333333 | 0.432432 | 24.0 |\n",
"| urinary tract infection | 0.6086956521739131 | 0.875 | 0.717949 | 16.0 |\n",
"| varicose veins | 0.6923076923076923 | 0.5294117647058824 | 0.6 | 17.0 |\n",
"| accuracy | | | 0.606061 | |\n",
"| macro avg | 0.6815150617558347 | 0.6201773068720589 | 0.614097 | |\n",
"| weighted avg | 0.6938185269001016 | 0.6060606060606061 | 0.614131 | |\n",
"+---------------------------------+---------------------+---------------------+------------+-----------+\n"
]
}
],
"source": [
"evaluate_pos_filter(wordnet.ADJ)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ***3- Nouns only***"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Accuracy: 0.9782923299565847\n",
"Test Accuracy: 0.8766233766233766\n",
"+---------------------------------+--------------------+--------------------+------------+-----------+\n",
"| Class | Precision | Recall | F1-score | Support |\n",
"|---------------------------------+--------------------+--------------------+------------+-----------|\n",
"| acne | 0.95 | 1.0 | 0.974359 | 19.0 |\n",
"| allergy | 0.9230769230769231 | 0.631578947368421 | 0.75 | 19.0 |\n",
"| arthritis | 0.875 | 1.0 | 0.933333 | 14.0 |\n",
"| bronchial asthma | 0.7391304347826086 | 1.0 | 0.85 | 17.0 |\n",
"| cervical spondylosis | 0.9545454545454546 | 1.0 | 0.976744 | 21.0 |\n",
"| chicken pox | 0.8823529411764706 | 0.7894736842105263 | 0.833333 | 19.0 |\n",
"| common cold | 0.75 | 0.8333333333333334 | 0.789474 | 18.0 |\n",
"| dengue | 0.4827586206896552 | 0.875 | 0.622222 | 16.0 |\n",
"| diabetes | 0.9285714285714286 | 0.6842105263157895 | 0.787879 | 19.0 |\n",
"| dimorphic hemorrhoids | 1.0 | 1.0 | 1 | 17.0 |\n",
"| drug reaction | 0.8181818181818182 | 0.5625 | 0.666667 | 16.0 |\n",
"| fungal infection | 0.8571428571428571 | 1.0 | 0.923077 | 18.0 |\n",
"| gastroesophageal reflux disease | 0.8 | 0.9090909090909091 | 0.851064 | 22.0 |\n",
"| hypertension | 0.8947368421052632 | 0.9444444444444444 | 0.918919 | 18.0 |\n",
"| impetigo | 1.0 | 0.92 | 0.958333 | 25.0 |\n",
"| jaundice | 1.0 | 1.0 | 1 | 17.0 |\n",
"| malaria | 0.9583333333333334 | 1.0 | 0.978723 | 23.0 |\n",
"| migraine | 1.0 | 0.8947368421052632 | 0.944444 | 19.0 |\n",
"| peptic ulcer disease | 1.0 | 0.7272727272727273 | 0.842105 | 22.0 |\n",
"| pneumonia | 0.92 | 0.9583333333333334 | 0.938776 | 24.0 |\n",
"| psoriasis | 1.0 | 0.8636363636363636 | 0.926829 | 22.0 |\n",
"| typhoid | 0.8125 | 0.5416666666666666 | 0.65 | 24.0 |\n",
"| urinary tract infection | 0.8421052631578947 | 1.0 | 0.914286 | 16.0 |\n",
"| varicose veins | 0.8947368421052632 | 1.0 | 0.944444 | 17.0 |\n",
"| accuracy | | | 0.876623 | |\n",
"| macro avg | 0.8867988649528739 | 0.8806365740740741 | 0.873959 | |\n",
"| weighted avg | 0.8930983880991834 | 0.8766233766233766 | 0.875435 | |\n",
"+---------------------------------+--------------------+--------------------+------------+-----------+\n"
]
}
],
"source": [
"evaluate_pos_filter(wordnet.NOUN)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "NLP",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.20"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment