|
@@ -6,230 +6,281 @@ output: html_document
|
|
```{r, warning = FALSE, message = FALSE }
|
|
```{r, warning = FALSE, message = FALSE }
|
|
library(tidyverse)
|
|
library(tidyverse)
|
|
set.seed(1)
|
|
set.seed(1)
|
|
|
|
+options(dplyr.summarise.inform = FALSE)
|
|
```
|
|
```
|
|
|
|
|
|
# Introduction
|
|
# Introduction
|
|
|
|
|
|
This analysis is an application of what we've learned in Dataquest's Conditional Probability course. Using a dataset of pre-labeled SMS messages, we'll create a spam filter using the Naive Bayes algorithm.
|
|
This analysis is an application of what we've learned in Dataquest's Conditional Probability course. Using a dataset of pre-labeled SMS messages, we'll create a spam filter using the Naive Bayes algorithm.
|
|
|
|
|
|
-# Data
|
|
|
|
-
|
|
|
|
```{r}
|
|
```{r}
|
|
-spam = read.csv("./data/SMSSpamCollection", sep = "\t", header = F)
|
|
|
|
-colnames(spam) = c("label", "sms")
|
|
|
|
|
|
+# Bring in the dataset
|
|
|
|
+spam <- read_csv("spam.csv")
|
|
```
|
|
```
|
|
|
|
|
|
The `spam` dataset has `r nrow(spam)` rows and `r ncol(spam)` columns. Of these messages, `r mean(spam$label == "ham") * 100`% of them are not classified as spam, the rest are spam.
|
|
The `spam` dataset has `r nrow(spam)` rows and `r ncol(spam)` columns. Of these messages, `r mean(spam$label == "ham") * 100`% of them are not classified as spam, the rest are spam.
|
|
|
|
|
|
-# Dividing Up Into Training and Test Sets
|
|
|
|
|
|
+# Training, Cross-validation and Test Sets
|
|
|
|
|
|
```{r}
|
|
```{r}
|
|
-n = nrow(spam)
|
|
|
|
-n.training = 2547
|
|
|
|
-n.cv = 318
|
|
|
|
-n.test = 319
|
|
|
|
|
|
+# Calculate some helper values to split the dataset
|
|
|
|
+n <- nrow(spam)
|
|
|
|
+n_training <- 0.8 * n
|
|
|
|
+n_cv <- 0.1 * n
|
|
|
|
+n_test <- 0.1 * n
|
|
|
|
|
|
# Create the random indices for training set
|
|
# Create the random indices for training set
|
|
-train.indices = sample(1:n, size = n.training, replace = FALSE)
|
|
|
|
|
|
+train_indices <- sample(1:n, size = n_training, replace = FALSE)
|
|
|
|
|
|
# Get indices not used by the training set
|
|
# Get indices not used by the training set
|
|
-remaining.indices = setdiff(1:n, train.indices)
|
|
|
|
|
|
+remaining_indices <- setdiff(1:n, train_indices)
|
|
|
|
|
|
# Remaining indices are already randomized, just allocate correctly
|
|
# Remaining indices are already randomized, just allocate correctly
|
|
-cv.indices = remaining.indices[1:318]
|
|
|
|
-test.indices = remaining.indices[319:length(remaining.indices)]
|
|
|
|
|
|
+cv_indices <- remaining_indices[seq_len(length(remaining_indices) %/% 2)] # integer half; safe for odd or zero lengths
|
|
|
|
+test_indices <- setdiff(remaining_indices, cv_indices) # everything left over, so no index is dropped
|
|
|
|
|
|
# Use the indices to create each of the datasets
|
|
# Use the indices to create each of the datasets
|
|
-spam.train = spam[train.indices,]
|
|
|
|
-spam.cv = spam[cv.indices,]
|
|
|
|
-spam.test = spam[test.indices,]
|
|
|
|
|
|
+spam_train <- spam[train_indices,]
|
|
|
|
+spam_cv <- spam[cv_indices,]
|
|
|
|
+spam_test <- spam[test_indices,]
|
|
|
|
|
|
# Sanity check: are the ratios of ham to spam relatively constant?
|
|
# Sanity check: are the ratios of ham to spam relatively constant?
|
|
-print(mean(spam.train$label == "ham"))
|
|
|
|
-print(mean(spam.cv$label == "ham"))
|
|
|
|
-print(mean(spam.test$label == "ham"))
|
|
|
|
|
|
+print(mean(spam_train$label == "ham"))
|
|
|
|
+print(mean(spam_cv$label == "ham"))
|
|
|
|
+print(mean(spam_test$label == "ham"))
|
|
```
|
|
```
|
|
|
|
|
|
-The number of ham messages in each dataset is relatively close to the original 87%. These datasets look good for future analysis.
|
|
|
|
|
|
+The proportion of ham messages is relatively similar across the three datasets. This check is just to make sure that no dataset consists entirely of "ham", which would defeat the point of spam detection.
|
|
|
|
|
|
# Data Cleaning
|
|
# Data Cleaning
|
|
|
|
|
|
```{r}
|
|
```{r}
|
|
-# To lowercase, removal of punctuation
|
|
|
|
-tidy.train = spam.train %>%
|
|
|
|
|
|
+# To lowercase, removal of punctuation, weird characters, digits
|
|
|
|
+tidy_train <- spam_train %>%
|
|
mutate(
|
|
mutate(
|
|
- sms = tolower(sms),
|
|
|
|
- sms = str_replace_all(sms, "[[:punct:]]", ""),
|
|
|
|
- sms = str_replace_all(sms, "[[:digit:]]", " "),
|
|
|
|
- sms = str_replace_all(sms, "[\u0094\u0092\n\t]", " ")
|
|
|
|
|
|
+ # Take the messages and remove unwanted characters
|
|
|
|
+ sms = str_to_lower(sms) %>%
|
|
|
|
+ str_squish %>%
|
|
|
|
+ str_replace_all("[[:punct:]]", "") %>%
|
|
|
|
+ str_replace_all("[\u0094\u0092\u0096\n\t]", "") %>% # Unicode characters
|
|
|
|
+ str_replace_all("[[:digit:]]", "")
|
|
)
|
|
)
|
|
|
|
|
|
# Creating the vocabulary
|
|
# Creating the vocabulary
|
|
-vocabulary = NULL
|
|
|
|
-messages = pull(tidy.train, sms)
|
|
|
|
|
|
+vocabulary <- NULL
|
|
|
|
+messages <- tidy_train %>% pull(sms)
|
|
|
|
|
|
# Iterate through the messages and add to the vocabulary
|
|
# Iterate through the messages and add to the vocabulary
|
|
for (m in messages) {
|
|
for (m in messages) {
|
|
- words = str_split(m, " ")[[1]]
|
|
|
|
- words = words[!words %in% ""]
|
|
|
|
- vocabulary = c(vocabulary, words)
|
|
|
|
|
|
+ words <- str_split(m, " ")[[1]]
|
|
|
|
+ vocabulary <- c(vocabulary, words)
|
|
}
|
|
}
|
|
-vocabulary = unique(vocabulary)
|
|
|
|
|
|
+
|
|
|
|
+# Remove duplicates from the vocabulary
|
|
|
|
+vocabulary <- vocabulary %>% unique()
|
|
```
|
|
```
|
|
|
|
|
|
# Calculating Constants and Parameters
|
|
# Calculating Constants and Parameters
|
|
|
|
|
|
```{r}
|
|
```{r}
|
|
-# Calculating Constants
|
|
|
|
-# Mean of a vector of logicals is a percentage
|
|
|
|
-p.spam = mean(tidy.train$label == "spam")
|
|
|
|
-p.ham = mean(tidy.train$label == "ham")
|
|
|
|
-
|
|
|
|
# Isolate the spam and ham messages
|
|
# Isolate the spam and ham messages
|
|
-spam.messages = tidy.train %>%
|
|
|
|
|
|
+spam_messages <- tidy_train %>%
|
|
filter(label == "spam") %>%
|
|
filter(label == "spam") %>%
|
|
- pull("sms")
|
|
|
|
|
|
+ pull(sms)
|
|
|
|
|
|
-ham.messages = tidy.train %>%
|
|
|
|
|
|
+ham_messages <- tidy_train %>%
|
|
filter(label == "ham") %>%
|
|
filter(label == "ham") %>%
|
|
- pull("sms")
|
|
|
|
|
|
+ pull(sms)
|
|
|
|
|
|
-spam.words = NULL
|
|
|
|
-for (sm in spam.messages) {
|
|
|
|
- words = str_split(sm, " ")[[1]]
|
|
|
|
- spam.words = c(spam.words, words)
|
|
|
|
|
|
+# Isolate the vocabulary in spam and ham messages
|
|
|
|
+spam_vocab <- NULL
|
|
|
|
+for (sm in spam_messages) {
|
|
|
|
+ words <- str_split(sm, " ")[[1]]
|
|
|
|
+ spam_vocab <- c(spam_vocab, words)
|
|
}
|
|
}
|
|
|
|
+spam_vocab <- spam_vocab %>% unique
|
|
|
|
|
|
-ham.words = NULL
|
|
|
|
-for (hm in ham.messages) {
|
|
|
|
- words = str_split(hm, " ")[[1]]
|
|
|
|
- ham.words = c(ham.words, words)
|
|
|
|
|
|
+ham_vocab <- NULL
|
|
|
|
+for (hm in ham_messages) {
|
|
|
|
+ words <- str_split(hm, " ")[[1]]
|
|
|
|
+ ham_vocab <- c(ham_vocab, words)
|
|
}
|
|
}
|
|
|
|
+ham_vocab <- ham_vocab %>% unique()
|
|
|
|
|
|
-n.spam = length(unique(spam.words))
|
|
|
|
-n.ham = length(unique(ham.words))
|
|
|
|
-n.vocabulary = length(vocabulary)
|
|
|
|
-alpha = 1
|
|
|
|
|
|
+# Calculate some important parameters from the vocab
|
|
|
|
+n_spam <- spam_vocab %>% length()
|
|
|
|
+n_ham <- ham_vocab %>% length()
|
|
|
|
+n_vocabulary <- vocabulary %>% length()
|
|
```
|
|
```
|
|
|
|
|
|
|
|
+# Calculating Probability Parameters
|
|
|
|
+
|
|
```{r}
|
|
```{r}
|
|
-# Calculating Parameters
|
|
|
|
-spam.counts = list()
|
|
|
|
-ham.counts = list()
|
|
|
|
-spam.probs = list()
|
|
|
|
-ham.probs = list()
|
|
|
|
-
|
|
|
|
-# This might take a while to run with so many words
|
|
|
|
-for (vocab in vocabulary) {
|
|
|
|
-
|
|
|
|
- # Initialize the counts for the word
|
|
|
|
- spam.counts[[vocab]] = 0
|
|
|
|
- ham.counts[[vocab]] = 0
|
|
|
|
-
|
|
|
|
- # Break up the message and count how many times word appears
|
|
|
|
- for (sm in spam.messages) {
|
|
|
|
- words = str_split(sm, " ")[[1]]
|
|
|
|
- spam.counts[[vocab]] = spam.counts[[vocab]] + sum(words == vocab)
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- for (hm in ham.messages) {
|
|
|
|
- words = str_split(hm, " ")[[1]]
|
|
|
|
- ham.counts[[vocab]] = ham.counts[[vocab]] + sum(words == vocab)
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- # Use the counts to calculate the probability
|
|
|
|
- spam.probs[[vocab]] = (spam.counts[[vocab]] + alpha) / (n.spam + alpha * n.vocabulary)
|
|
|
|
- ham.probs[[vocab]] = (ham.counts[[vocab]] + alpha) / (n.ham + alpha * n.vocabulary)
|
|
|
|
-}
|
|
|
|
|
|
+# New vectorized approach to calculating ham and spam probabilities
|
|
|
|
|
|
|
|
+# Marginal probability of a training message being spam or ham
|
|
|
|
+p_spam <- mean(tidy_train$label == "spam")
|
|
|
|
+p_ham <- mean(tidy_train$label == "ham")
|
|
|
|
+
|
|
|
|
+# Break up the spam and ham counting into their own tibbles
|
|
|
|
+spam_counts <- tibble(
|
|
|
|
+ word = spam_vocab
|
|
|
|
+) %>%
|
|
|
|
+ mutate(
|
|
|
|
+ # Calculate the number of times a word appears in spam
|
|
|
|
+ spam_count = map_int(word, function(w) {
|
|
|
|
+
|
|
|
|
+ # Count how many times each word appears in all spam messages, then sum
|
|
|
|
+ map_int(spam_messages, function(sm) {
|
|
|
|
+ (str_split(sm, " ")[[1]] == w) %>% sum # for a single message
|
|
|
|
+ }) %>%
|
|
|
|
+ sum # then summing over all messages
|
|
|
|
+
|
|
|
|
+ })
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+# There are many words in the ham vocabulary so this will take a while!
|
|
|
|
+# Run this code and distract yourself while the counts are calculated
|
|
|
|
+ham_counts <- tibble(
|
|
|
|
+ word = ham_vocab
|
|
|
|
+) %>%
|
|
|
|
+ mutate(
|
|
|
|
+ # Calculate the number of times a word appears in ham
|
|
|
|
+ ham_count = map_int(word, function(w) {
|
|
|
|
+
|
|
|
|
+ # Count how many times each word appears in all ham messages, then sum
|
|
|
|
+ map_int(ham_messages, function(hm) {
|
|
|
|
+ (str_split(hm, " ")[[1]] == w) %>% sum
|
|
|
|
+ }) %>%
|
|
|
|
+ sum
|
|
|
|
+
|
|
|
|
+ })
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+# Join these tibbles together
|
|
|
|
+word_counts <- full_join(spam_counts, ham_counts, by = "word") %>%
|
|
|
|
+ mutate(
|
|
|
|
+ # Fill in zeroes where there are missing values
|
|
|
|
+ spam_count = ifelse(is.na(spam_count), 0, spam_count),
|
|
|
|
+ ham_count = ifelse(is.na(ham_count), 0, ham_count)
|
|
|
|
+ )
|
|
```
|
|
```
|
|
|
|
|
|
|
|
+
|
|
# Classifying New Messages
|
|
# Classifying New Messages
|
|
|
|
|
|
```{r}
|
|
```{r}
|
|
-classify = function(message) {
|
|
|
|
-
|
|
|
|
- # Initializing the probability product
|
|
|
|
- p.spam.given.message = p.spam
|
|
|
|
- p.ham.given.message = p.ham
|
|
|
|
|
|
+# This is the updated function using the vectorized approach to calculating
|
|
|
|
+# the spam and ham probabilities
|
|
|
|
+
|
|
|
|
+# Create a function that makes it easy to classify a tibble of messages
|
|
|
|
+# we add an alpha argument to make it easy to recalculate probabilities
|
|
|
|
+# based on this alpha (default to 1)
|
|
|
|
+classify <- function(message, alpha = 1) {
|
|
|
|
|
|
# Splitting and cleaning the new message
|
|
# Splitting and cleaning the new message
|
|
- clean.message = tolower(message)
|
|
|
|
- clean.message = str_replace_all(clean.message, "[[:punct:]]", "")
|
|
|
|
- clean.message = str_replace_all(clean.message, "[[:digit:]]", " ")
|
|
|
|
- clean.message = str_replace_all(clean.message, "[\u0094\u0092\n\t]", " ")
|
|
|
|
- words = str_split(clean.message, " ")[[1]]
|
|
|
|
|
|
+ # This is the same cleaning procedure used on the training messages
|
|
|
|
+ clean_message <- str_to_lower(message) %>%
|
|
|
|
+ str_squish %>%
|
|
|
|
+ str_replace_all("[[:punct:]]", "") %>%
|
|
|
|
+ str_replace_all("[\u0094\u0092\u0096\n\t]", "") %>% # Unicode characters
|
|
|
|
+ str_replace_all("[[:digit:]]", "")
|
|
|
|
+
|
|
|
|
+ words <- str_split(clean_message, " ")[[1]]
|
|
|
|
+
|
|
|
|
+ # There is a possibility that there will be words that don't appear
|
|
|
|
+ # in the training vocabulary, so this must be accounted for
|
|
|
|
|
|
- for (word in words) {
|
|
|
|
-
|
|
|
|
- # Extra check if word is not in vocabulary
|
|
|
|
- wi.spam.prob = ifelse(word %in% vocabulary,
|
|
|
|
- spam.probs[[word]],
|
|
|
|
- 1)
|
|
|
|
- wi.ham.prob = ifelse(word %in% vocabulary,
|
|
|
|
- ham.probs[[word]],
|
|
|
|
- 1)
|
|
|
|
-
|
|
|
|
- p.spam.given.message = p.spam.given.message * wi.spam.prob
|
|
|
|
- p.ham.given.message = p.ham.given.message * wi.ham.prob
|
|
|
|
- }
|
|
|
|
|
|
+ # Find the message words that never appeared in the training vocabulary
|
|
|
|
+ new_words <- setdiff(words, vocabulary)
|
|
|
|
|
|
- result = case_when(
|
|
|
|
- p.spam.given.message >= p.ham.given.message ~ "spam",
|
|
|
|
- p.spam.given.message < p.ham.given.message ~ "ham")
|
|
|
|
|
|
+ # Add them to the word_counts
|
|
|
|
+ new_word_probs <- tibble(
|
|
|
|
+ word = new_words,
|
|
|
|
+ spam_prob = 1,
|
|
|
|
+ ham_prob = 1
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ # Filter down the probabilities to the words present
|
|
|
|
+ # use group by to multiply everything together
|
|
|
|
+ present_probs <- word_counts %>%
|
|
|
|
+ filter(word %in% words) %>%
|
|
|
|
+ mutate(
|
|
|
|
+ # Calculate the probabilities from the counts
|
|
|
|
+ spam_prob = (spam_count + alpha) / (n_spam + alpha * n_vocabulary),
|
|
|
|
+ ham_prob = (ham_count + alpha) / (n_ham + alpha * n_vocabulary)
|
|
|
|
+ ) %>%
|
|
|
|
+ bind_rows(new_word_probs) %>%
|
|
|
|
+ pivot_longer(
|
|
|
|
+ cols = c("spam_prob", "ham_prob"),
|
|
|
|
+ names_to = "label",
|
|
|
|
+ values_to = "prob"
|
|
|
|
+ ) %>%
|
|
|
|
+ group_by(label) %>%
|
|
|
|
+ summarize(
|
|
|
|
+ wi_prob = prod(prob) # prod is like sum, but with multiplication
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ # Calculate the conditional probabilities
|
|
|
|
+ p_spam_given_message <- p_spam * (present_probs %>% filter(label == "spam_prob") %>% pull(wi_prob))
|
|
|
|
+ p_ham_given_message <- p_ham * (present_probs %>% filter(label == "ham_prob") %>% pull(wi_prob))
|
|
|
|
|
|
- return(result)
|
|
|
|
|
|
+ # Classify the message based on the probability
|
|
|
|
+ ifelse(p_spam_given_message >= p_ham_given_message, "spam", "ham")
|
|
}
|
|
}
|
|
|
|
|
|
-final.train = tidy.train %>%
|
|
|
|
|
|
+# Use the classify function to classify the messages in the training set
|
|
|
|
+# This takes advantage of vectorization
|
|
|
|
+final_train <- tidy_train %>%
|
|
mutate(
|
|
mutate(
|
|
- prediction = unlist(map(sms, classify))
|
|
|
|
- ) %>%
|
|
|
|
- select(label, prediction, sms)
|
|
|
|
|
|
+ prediction = map_chr(sms, function(m) { classify(m) })
|
|
|
|
+ )
|
|
|
|
+```
|
|
|
|
|
|
|
|
+# Calculating Accuracy
|
|
|
|
|
|
|
|
+```{r}
|
|
# Results of classification on training
|
|
# Results of classification on training
|
|
-confusion = table(final.train$label, final.train$prediction)
|
|
|
|
-accuracy = (confusion[1,1] + confusion[2,2]) / nrow(final.train)
|
|
|
|
|
|
+confusion <- table(final_train$label, final_train$prediction)
|
|
|
|
+accuracy <- (confusion[1,1] + confusion[2,2]) / nrow(final_train)
|
|
```
|
|
```
|
|
|
|
|
|
-Roughly, the classifier achieves about 97% accuracy on the training set. We aren't interested in how well the classifier performs with training data though, the classifier has already "seen" all of these messages.
|
|
|
|
|
|
+
|
|
|
|
+The Naive Bayes Classifier achieves an accuracy of about 89% on the training set. Pretty good! Let's see how well it works on messages that it has never seen before.
|
|
|
|
|
|
# Hyperparameter Tuning
|
|
# Hyperparameter Tuning
|
|
|
|
|
|
```{r}
|
|
```{r}
|
|
-alpha.grid = seq(0.1, 1, by = 0.1)
|
|
|
|
-cv.accuracy = NULL
|
|
|
|
|
|
+alpha_grid <- seq(0.05, 1, by = 0.05)
|
|
|
|
+cv_accuracy <- NULL
|
|
|
|
|
|
-for (a in alpha.grid) {
|
|
|
|
|
|
+for (alpha in alpha_grid) {
|
|
|
|
|
|
- spam.probs = list()
|
|
|
|
- ham.probs = list()
|
|
|
|
-
|
|
|
|
- # This might take a while to run with so many words
|
|
|
|
- for (vocab in vocabulary) {
|
|
|
|
-
|
|
|
|
- # Use the counts to calculate the probability
|
|
|
|
- spam.probs[[vocab]] = (spam.counts[[vocab]] + a) / (n.spam + a * n.vocabulary)
|
|
|
|
- ham.probs[[vocab]] = (ham.counts[[vocab]] + a) / (n.ham + a * n.vocabulary)
|
|
|
|
- }
|
|
|
|
|
|
+ # Recalculate probabilities based on new alpha
|
|
|
|
+ cv_probs <- word_counts %>%
|
|
|
|
+ mutate(
|
|
|
|
+ # Calculate the probabilities from the counts based on new alpha
|
|
|
|
+ spam_prob = (spam_count + alpha) / (n_spam + alpha * n_vocabulary),
|
|
|
|
+ ham_prob = (ham_count + alpha) / (n_ham + alpha * n_vocabulary)
|
|
|
|
+ )
|
|
|
|
|
|
- cv = spam.cv %>%
|
|
|
|
|
|
+ # Predict the classification of each message in cross validation
|
|
|
|
+ cv <- spam_cv %>%
|
|
mutate(
|
|
mutate(
|
|
- prediction = unlist(map(sms, classify))
|
|
|
|
- ) %>%
|
|
|
|
- select(label, prediction, sms)
|
|
|
|
|
|
+ prediction = map_chr(sms, function(m) { classify(m, alpha = alpha) })
|
|
|
|
+ )
|
|
|
|
|
|
- confusion = table(cv$label, cv$prediction)
|
|
|
|
- acc = (confusion[1,1] + confusion[2,2]) / nrow(cv)
|
|
|
|
- cv.accuracy = c(cv.accuracy, acc)
|
|
|
|
|
|
+ # Assess the accuracy of the classifier on cross-validation set
|
|
|
|
+ confusion <- table(cv$label, cv$prediction)
|
|
|
|
+ acc <- (confusion[1,1] + confusion[2,2]) / nrow(cv)
|
|
|
|
+ cv_accuracy <- c(cv_accuracy, acc)
|
|
}
|
|
}
|
|
|
|
|
|
-cv.check = tibble(
|
|
|
|
- alpha = alpha.grid,
|
|
|
|
- accuracy = cv.accuracy
|
|
|
|
|
|
+# Check out what the best alpha value is
|
|
|
|
+tibble(
|
|
|
|
+ alpha = alpha_grid,
|
|
|
|
+ accuracy = cv_accuracy
|
|
)
|
|
)
|
|
-cv.check
|
|
|
|
```
|
|
```
|
|
|
|
|
|
Judging from the cross-validation set, higher $\alpha$ values cause the accuracy to decrease. We'll go with $\alpha = 0.1$ since it produces the highest cross-validation prediction accuracy.
|
|
Judging from the cross-validation set, higher $\alpha$ values cause the accuracy to decrease. We'll go with $\alpha = 0.1$ since it produces the highest cross-validation prediction accuracy.
|
|
@@ -237,31 +288,18 @@ Judging from the cross-validation set, higher $\alpha$ values cause the accuracy
|
|
# Test Set Performance
|
|
# Test Set Performance
|
|
|
|
|
|
```{r}
|
|
```{r}
|
|
-# Reestablishing the proper parameters
|
|
|
|
-optimal.alpha = 0.1
|
|
|
|
-for (a in alpha.grid) {
|
|
|
|
-
|
|
|
|
- spam.probs = list()
|
|
|
|
- ham.probs = list()
|
|
|
|
-
|
|
|
|
- # This might take a while to run with so many words
|
|
|
|
- for (vocab in vocabulary) {
|
|
|
|
-
|
|
|
|
- # Use the counts to calculate the probability
|
|
|
|
- spam.probs[[vocab]] = (spam.counts[[vocab]] + optimal.alpha) / (n.spam + optimal.alpha * n.vocabulary)
|
|
|
|
- ham.probs[[vocab]] = (ham.counts[[vocab]] + optimal.alpha) / (n.ham + optimal.alpha * n.vocabulary)
|
|
|
|
- }
|
|
|
|
-}
|
|
|
|
|
|
+# Reestablishing the proper parameters
|
|
|
|
+optimal_alpha <- 0.1
|
|
|
|
|
|
-spam.test = spam.test %>%
|
|
|
|
|
|
+# Using optimal alpha with training parameters, perform final predictions
|
|
|
|
+spam_test <- spam_test %>%
|
|
mutate(
|
|
mutate(
|
|
- prediction = unlist(map(sms, classify))
|
|
|
|
- ) %>%
|
|
|
|
- select(label, prediction, sms)
|
|
|
|
|
|
+ prediction = map_chr(sms, function(m) { classify(m, alpha = optimal_alpha)} )
|
|
|
|
+ )
|
|
|
|
|
|
-confusion = table(spam.test$label, spam.test$prediction)
|
|
|
|
-test.accuracy = (confusion[1,1] + confusion[2,2]) / nrow(cv)
|
|
|
|
-test.accuracy
|
|
|
|
|
|
+confusion <- table(spam_test$label, spam_test$prediction)
|
|
|
|
+test_accuracy <- (confusion[1,1] + confusion[2,2]) / nrow(spam_test)
|
|
|
|
+test_accuracy
|
|
```
|
|
```
|
|
|
|
|
|
We've achieved an accuracy of 93% in the test set. Not bad!
|
|
We've achieved an accuracy of 93% in the test set. Not bad!
|