|
@@ -0,0 +1,267 @@
|
|
|
|
+---
|
|
|
|
+title: "Conditional Probability in R: Guided Project Solutions"
|
|
|
|
+output: html_document
|
|
|
|
+---
|
|
|
|
+
|
|
|
|
```{r, warning = FALSE, message = FALSE }
# Load the tidyverse (dplyr, stringr, purrr, tibble) used throughout.
library(tidyverse)
# Fix the RNG seed so the random train/CV/test split below is reproducible.
set.seed(1)
```
|
|
|
|
+
|
|
|
|
+# Introduction
|
|
|
|
+
|
|
|
|
+This analysis is an application of what we've learned in Dataquest's Conditional Probability course. Using a dataset of pre-labeled SMS messages, we'll create a spam filter using the Naive Bayes algorithm.
|
|
|
|
+
|
|
|
|
+# Data
|
|
|
|
+
|
|
|
|
```{r}
# Read the tab-separated SMS collection; the file has no header row.
# NOTE(review): read.csv's default quote handling (") can silently merge
# lines that contain quote characters in this dataset — consider quote = ""
# and confirm nrow(spam) matches the raw file's line count.
spam <- read.csv("./data/SMSSpamCollection", sep = "\t", header = FALSE)
colnames(spam) <- c("label", "sms")
```
|
|
|
|
+
|
|
|
|
The `spam` dataset has `r nrow(spam)` rows and `r ncol(spam)` columns. Of these messages, `r round(mean(spam$label == "ham") * 100, 1)`% of them are not classified as spam; the rest are spam.
|
|
|
|
+
|
|
|
|
+# Dividing Up Into Training and Test Sets
|
|
|
|
+
|
|
|
|
```{r}
n <- nrow(spam)
n.training <- 2547
n.cv <- 318
n.test <- 319

# Create the random indices for the training set
train.indices <- sample(seq_len(n), size = n.training, replace = FALSE)

# Indices not used by the training set (already in random order because
# sample() randomized the draw).
remaining.indices <- setdiff(seq_len(n), train.indices)

# Allocate the remainder using the declared sizes instead of hard-coded
# positions, so the constants above stay the single source of truth.
cv.indices <- remaining.indices[seq_len(n.cv)]
# NOTE(review): everything after the CV slice goes to the test set; this
# equals n.test (319) only if nrow(spam) == 3184 — verify after loading.
test.indices <- remaining.indices[(n.cv + 1):length(remaining.indices)]

# Use the indices to create each of the datasets
spam.train <- spam[train.indices, ]
spam.cv <- spam[cv.indices, ]
spam.test <- spam[test.indices, ]

# Sanity check: are the ratios of ham to spam relatively constant?
print(mean(spam.train$label == "ham"))
print(mean(spam.cv$label == "ham"))
print(mean(spam.test$label == "ham"))
```
|
|
|
|
+
|
|
|
|
+The number of ham messages in each dataset is relatively close to the original 87%. These datasets look good for future analysis.
|
|
|
|
+
|
|
|
|
+# Data Cleaning
|
|
|
|
+
|
|
|
|
```{r}
# Normalize the training messages: lowercase, strip punctuation, replace
# digits and stray control characters with spaces so str_split on " " works.
tidy.train <- spam.train %>%
  mutate(
    sms = tolower(sms),
    sms = str_replace_all(sms, "[[:punct:]]", ""),
    sms = str_replace_all(sms, "[[:digit:]]", " "),
    sms = str_replace_all(sms, "[\u0094\u0092\n\t]", " ")
  )

messages <- pull(tidy.train, sms)

# Build the vocabulary in one vectorized pass: split every message on spaces,
# flatten, drop empty tokens, and keep the first occurrence of each word.
# This produces exactly the same vector as growing it message-by-message with
# c() in a loop, but without the O(n^2) copy cost.
all.words <- unlist(str_split(messages, " "))
vocabulary <- unique(all.words[all.words != ""])
```
|
|
|
|
+
|
|
|
|
+# Calculating Constants and Parameters
|
|
|
|
+
|
|
|
|
```{r}
# Calculating Constants
# Mean of a vector of logicals is a percentage: the class priors.
p.spam <- mean(tidy.train$label == "spam")
p.ham <- mean(tidy.train$label == "ham")

# Isolate the spam and ham messages
spam.messages <- tidy.train %>%
  filter(label == "spam") %>%
  pull("sms")

ham.messages <- tidy.train %>%
  filter(label == "ham") %>%
  pull("sms")

# Flatten each class's messages into one word vector with a single vectorized
# call rather than growing a vector with c() inside a loop (which copies the
# whole vector on every iteration). Result is identical to the loop version.
spam.words <- unlist(str_split(spam.messages, " "))
ham.words <- unlist(str_split(ham.messages, " "))

# NOTE(review): these are counts of *unique* words per class (and include the
# empty string produced by splitting); textbook Naive Bayes uses the total
# word count per class in the denominator — confirm which the course intends.
n.spam <- length(unique(spam.words))
n.ham <- length(unique(ham.words))
n.vocabulary <- length(vocabulary)
alpha <- 1
```
|
|
|
|
+
|
|
|
|
```{r}
# Calculating Parameters
spam.counts <- list()
ham.counts <- list()
spam.probs <- list()
ham.probs <- list()

# Count every word once up front with table() instead of re-splitting every
# message for every vocabulary word — the naive double loop is
# O(|vocabulary| x |messages|) and extremely slow on this dataset. The
# resulting per-word counts are identical.
spam.word.freq <- table(unlist(str_split(spam.messages, " ")))
ham.word.freq <- table(unlist(str_split(ham.messages, " ")))

for (vocab in vocabulary) {

  # Look up the precomputed frequency; words absent from a class count as 0.
  sc <- if (vocab %in% names(spam.word.freq)) as.numeric(spam.word.freq[[vocab]]) else 0
  hc <- if (vocab %in% names(ham.word.freq)) as.numeric(ham.word.freq[[vocab]]) else 0
  spam.counts[[vocab]] <- sc
  ham.counts[[vocab]] <- hc

  # Laplace-smoothed conditional probabilities P(word | class)
  spam.probs[[vocab]] <- (sc + alpha) / (n.spam + alpha * n.vocabulary)
  ham.probs[[vocab]] <- (hc + alpha) / (n.ham + alpha * n.vocabulary)
}

```
|
|
|
|
+
|
|
|
|
+# Classifying New Messages
|
|
|
|
+
|
|
|
|
```{r}
# Classify one SMS message as "spam" or "ham" with Naive Bayes, using the
# globals computed above: priors p.spam / p.ham and the per-word probability
# lists spam.probs / ham.probs, plus the vocabulary.
classify <- function(message) {

  # Initialize the probability products with the class priors
  p.spam.given.message <- p.spam
  p.ham.given.message <- p.ham

  # Clean the new message exactly the way the training data was cleaned
  clean.message <- tolower(message)
  clean.message <- str_replace_all(clean.message, "[[:punct:]]", "")
  clean.message <- str_replace_all(clean.message, "[[:digit:]]", " ")
  clean.message <- str_replace_all(clean.message, "[\u0094\u0092\n\t]", " ")
  words <- str_split(clean.message, " ")[[1]]

  for (word in words) {

    # Scalar condition, so use if/else rather than ifelse() (which relied on
    # list[["missing"]] quietly returning NULL). Words outside the vocabulary
    # contribute a neutral factor of 1, same as before.
    wi.spam.prob <- if (word %in% vocabulary) spam.probs[[word]] else 1
    wi.ham.prob <- if (word %in% vocabulary) ham.probs[[word]] else 1

    p.spam.given.message <- p.spam.given.message * wi.spam.prob
    p.ham.given.message <- p.ham.given.message * wi.ham.prob
  }

  # Ties go to "spam" (>=), matching the original decision rule
  if (p.spam.given.message >= p.ham.given.message) "spam" else "ham"
}

final.train <- tidy.train %>%
  mutate(
    prediction = unlist(map(sms, classify))
  ) %>%
  select(label, prediction, sms)

# Results of classification on training
confusion <- table(final.train$label, final.train$prediction)
accuracy <- (confusion[1, 1] + confusion[2, 2]) / nrow(final.train)
```
|
|
|
|
+
|
|
|
|
+Roughly, the classifier achieves about 97% accuracy on the training set. We aren't interested in how well the classifier performs with training data though, the classifier has already "seen" all of these messages.
|
|
|
|
+
|
|
|
|
+# Hyperparameter Tuning
|
|
|
|
+
|
|
|
|
```{r}
alpha.grid <- seq(0.1, 1, by = 0.1)
# Preallocate the accuracy vector instead of growing it with c() in the loop
cv.accuracy <- numeric(length(alpha.grid))

for (i in seq_along(alpha.grid)) {
  a <- alpha.grid[i]

  # Recompute the smoothed probabilities for this alpha; classify() reads
  # spam.probs / ham.probs from the global environment.
  spam.probs <- list()
  ham.probs <- list()

  # This might take a while to run with so many words
  for (vocab in vocabulary) {
    spam.probs[[vocab]] <- (spam.counts[[vocab]] + a) / (n.spam + a * n.vocabulary)
    ham.probs[[vocab]] <- (ham.counts[[vocab]] + a) / (n.ham + a * n.vocabulary)
  }

  cv <- spam.cv %>%
    mutate(
      prediction = unlist(map(sms, classify))
    ) %>%
    select(label, prediction, sms)

  confusion <- table(cv$label, cv$prediction)
  cv.accuracy[i] <- (confusion[1, 1] + confusion[2, 2]) / nrow(cv)
}

cv.check <- tibble(
  alpha = alpha.grid,
  accuracy = cv.accuracy
)
cv.check
```
|
|
|
|
+
|
|
|
|
+Judging from the cross-validation set, higher $\alpha$ values cause the accuracy to decrease. We'll go with $\alpha = 0.1$ since it produces the highest cross-validation prediction accuracy.
|
|
|
|
+
|
|
|
|
+# Test Set Performance
|
|
|
|
+
|
|
|
|
```{r}
# Re-fit the word probabilities ONCE with the alpha chosen on the CV set.
# (The original wrapped this in `for (a in alpha.grid)` while always using
# optimal.alpha inside, repeating the identical computation ten times.)
optimal.alpha <- 0.1
spam.probs <- list()
ham.probs <- list()
for (vocab in vocabulary) {
  spam.probs[[vocab]] <- (spam.counts[[vocab]] + optimal.alpha) / (n.spam + optimal.alpha * n.vocabulary)
  ham.probs[[vocab]] <- (ham.counts[[vocab]] + optimal.alpha) / (n.ham + optimal.alpha * n.vocabulary)
}

spam.test <- spam.test %>%
  mutate(
    prediction = unlist(map(sms, classify))
  ) %>%
  select(label, prediction, sms)

confusion <- table(spam.test$label, spam.test$prediction)
# Bug fix: divide by the number of TEST rows. The original divided by
# nrow(cv) (the 318-row cross-validation set from the tuning chunk), which
# mis-reported the test accuracy.
test.accuracy <- (confusion[1, 1] + confusion[2, 2]) / nrow(spam.test)
test.accuracy
```
|
|
|
|
+
|
|
|
|
+We've achieved an accuracy of 93% in the test set. Not bad!
|