@@ -15,7 +15,7 @@ This analysis is an application of what we've learned in Dataquest's Conditional

```{r}
# Bring in the dataset
-spam = read_csv("spam.csv")
+spam <- read_csv("spam.csv")
```

The `spam` dataset has `r nrow(spam)` rows and `r ncol(spam)` columns. Of these messages, `r mean(spam$label == "ham") * 100`% are not classified as spam; the rest are spam.
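For a direct tabulation of that class balance, a quick sketch (assuming the tidyverse is loaded, so `count()` and `mutate()` are available):

```{r}
# Tally each class and its share of the dataset
spam %>%
  count(label) %>%
  mutate(prop = n / sum(n))
```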

@@ -24,25 +24,25 @@ The `spam` dataset has `r nrow(spam)` rows and `r ncol(spam)` columns. Of these

```{r}
# Calculate some helper values to split the dataset
-n = nrow(spam)
-n_training = 0.8 * n
-n_cv = 0.1 * n
-n_test = 0.1 * n
+n <- nrow(spam)
+n_training <- 0.8 * n
+n_cv <- 0.1 * n
+n_test <- 0.1 * n

# Create the random indices for the training set
-train_indices = sample(1:n, size = n_training, replace = FALSE)
+train_indices <- sample(1:n, size = n_training, replace = FALSE)

# Get the indices not used by the training set
-remaining_indices = setdiff(1:n, train_indices)
+remaining_indices <- setdiff(1:n, train_indices)

# The remaining indices are already randomized, so just split them in half
-cv_indices = remaining_indices[1:(length(remaining_indices)/2)]
-test_indices = remaining_indices[((length(remaining_indices)/2) + 1):length(remaining_indices)]
+cv_indices <- remaining_indices[1:(length(remaining_indices)/2)]
+test_indices <- remaining_indices[((length(remaining_indices)/2) + 1):length(remaining_indices)]

# Use the indices to create each of the datasets
-spam_train = spam[train_indices,]
-spam_cv = spam[cv_indices,]
-spam_test = spam[test_indices,]
+spam_train <- spam[train_indices, ]
+spam_cv <- spam[cv_indices, ]
+spam_test <- spam[test_indices, ]

# Sanity check: are the ratios of ham to spam relatively constant?
print(mean(spam_train$label == "ham"))
@@ -56,7 +56,7 @@ The number of ham messages in each dataset is relatively close to each other in
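As an aside, the same sanity check can be run over all three splits at once; a sketch using purrr (assumed loaded, as the `map_chr()` calls later in this document require it):

```{r}
# Ham proportion in each split, returned as a named vector
map_dbl(
  list(train = spam_train, cv = spam_cv, test = spam_test),
  ~ mean(.x$label == "ham")
)
```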

```{r}
# Convert to lowercase; remove punctuation, weird characters, and digits
-tidy_train = spam_train %>%
+tidy_train <- spam_train %>%
  mutate(
    # Take the messages and remove unwanted characters
    sms = str_to_lower(sms) %>%
@@ -67,50 +67,50 @@ tidy_train = spam_train %>%
  )

# Creating the vocabulary
-vocabulary = NULL
-messages = tidy_train %>% pull(sms)
+vocabulary <- NULL
+messages <- tidy_train %>% pull(sms)

# Iterate through the messages and add to the vocabulary
for (m in messages) {
-  words = str_split(m, " ")[[1]]
-  vocabulary = c(vocabulary, words)
+  words <- str_split(m, " ")[[1]]
+  vocabulary <- c(vocabulary, words)
}

# Remove duplicates from the vocabulary
-vocabulary = vocabulary %>% unique
+vocabulary <- vocabulary %>% unique()
```
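The message-by-message loop above can also be collapsed into a vectorized pipeline that should build the same vocabulary; a sketch, assuming the same space-delimited tokens:

```{r}
# Vectorized equivalent: split every message, flatten, and deduplicate
vocabulary <- tidy_train %>%
  pull(sms) %>%
  str_split(" ") %>%
  unlist() %>%
  unique()
```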

# Calculating Constants and Parameters

```{r}
# Isolate the spam and ham messages
-spam_messages = tidy_train %>%
+spam_messages <- tidy_train %>%
  filter(label == "spam") %>%
  pull(sms)

-ham_messages = tidy_train %>%
+ham_messages <- tidy_train %>%
  filter(label == "ham") %>%
  pull(sms)

# Isolate the vocabulary in spam and ham messages
-spam_vocab = NULL
+spam_vocab <- NULL
for (sm in spam_messages) {
-  words = str_split(sm, " ")[[1]]
-  spam_vocab = c(spam_vocab, words)
+  words <- str_split(sm, " ")[[1]]
+  spam_vocab <- c(spam_vocab, words)
}
-spam_vocab = spam_vocab %>% unique
+spam_vocab <- spam_vocab %>% unique()

-ham_vocab = NULL
+ham_vocab <- NULL
for (hm in ham_messages) {
-  words = str_split(hm, " ")[[1]]
-  ham_vocab = c(ham_vocab, words)
+  words <- str_split(hm, " ")[[1]]
+  ham_vocab <- c(ham_vocab, words)
}
-ham_vocab = ham_vocab %>% unique
+ham_vocab <- ham_vocab %>% unique()

# Calculate some important parameters from the vocab
-n_spam = spam_vocab %>% length
-n_ham = ham_vocab %>% length
-n_vocabulary = vocabulary %>% length
+n_spam <- spam_vocab %>% length()
+n_ham <- ham_vocab %>% length()
+n_vocabulary <- vocabulary %>% length()
```

# Calculating Probability Parameters
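Concretely, the smoothed word probabilities computed in this section take the form

$$P(w_i \mid \text{Spam}) = \frac{N_{w_i \mid \text{Spam}} + \alpha}{N_{\text{Spam}} + \alpha \cdot N_{\text{Vocabulary}}}$$

(and analogously for ham), where $\alpha$ is the smoothing parameter tuned later and the $N$ terms in the denominator correspond to `n_spam` and `n_vocabulary` from the previous chunk.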

@@ -119,11 +119,11 @@ n_vocabulary = vocabulary %>% length

# New vectorized approach to calculating ham and spam probabilities

# Marginal probability of a training message being spam or ham
-p_spam = mean(tidy_train$label == "spam")
-p_ham = mean(tidy_train$label == "ham")
+p_spam <- mean(tidy_train$label == "spam")
+p_ham <- mean(tidy_train$label == "ham")

# Break up the spam and ham counting into their own tibbles
-spam_counts = tibble(
+spam_counts <- tibble(
  word = spam_vocab
) %>%
  mutate(
@@ -141,7 +141,7 @@ spam_counts = tibble(

# There are many words in the ham vocabulary, so this will take a while!
# Run this code and distract yourself while the counts are calculated
-ham_counts = tibble(
+ham_counts <- tibble(
  word = ham_vocab
) %>%
  mutate(
@@ -158,7 +158,7 @@ ham_counts = tibble(
)

# Join these tibbles together
-word_counts = full_join(spam_counts, ham_counts, by = "word") %>%
+word_counts <- full_join(spam_counts, ham_counts, by = "word") %>%
  mutate(
    # Fill in zeroes where there are missing values
    spam_count = ifelse(is.na(spam_count), 0, spam_count),
@@ -176,26 +176,26 @@ word_counts = full_join(spam_counts, ham_counts, by = "word") %>%
# Create a function that makes it easy to classify a tibble of messages
# We add an alpha argument to make it easy to recalculate probabilities
# based on this alpha (defaults to 1)
-classify = function(message, alpha = 1) {
+classify <- function(message, alpha = 1) {

  # Splitting and cleaning the new message
  # This is the same cleaning procedure used on the training messages
-  clean_message = str_to_lower(message) %>%
+  clean_message <- str_to_lower(message) %>%
    str_squish %>%
    str_replace_all("[[:punct:]]", "") %>%
    str_replace_all("[\u0094\u0092\u0096\n\t]", "") %>% # Unicode characters
    str_replace_all("[[:digit:]]", "")

-  words = str_split(clean_message, " ")[[1]]
+  words <- str_split(clean_message, " ")[[1]]

  # The message may contain words that don't appear in the
  # training vocabulary, so this must be accounted for

  # Find the words that aren't present in the training vocabulary
-  new_words = setdiff(vocabulary, words)
+  new_words <- setdiff(words, vocabulary)

  # Add them to the word_counts
-  new_word_probs = tibble(
+  new_word_probs <- tibble(
    word = new_words,
    spam_prob = 1,
    ham_prob = 1
@@ -203,7 +203,7 @@ classify = function(message, alpha = 1) {
  # Filter down the probabilities to the words present
  # and use group_by() to multiply everything together
-  present_probs = word_counts %>%
+  present_probs <- word_counts %>%
    filter(word %in% words) %>%
    mutate(
      # Calculate the probabilities from the counts
@@ -222,8 +222,8 @@ classify = function(message, alpha = 1) {
  )

  # Calculate the conditional probabilities
-  p_spam_given_message = p_spam * (present_probs %>% filter(label == "spam_prob") %>% pull(wi_prob))
-  p_ham_given_message = p_ham * (present_probs %>% filter(label == "ham_prob") %>% pull(wi_prob))
+  p_spam_given_message <- p_spam * (present_probs %>% filter(label == "spam_prob") %>% pull(wi_prob))
+  p_ham_given_message <- p_ham * (present_probs %>% filter(label == "ham_prob") %>% pull(wi_prob))

  # Classify the message based on the probability
  ifelse(p_spam_given_message >= p_ham_given_message, "spam", "ham")
@@ -231,7 +231,7 @@ classify = function(message, alpha = 1) {
# Use the classify function to classify the messages in the training set
# map_chr() applies classify() to each message and returns a character vector
-final_train = tidy_train %>%
+final_train <- tidy_train %>%
  mutate(
    prediction = map_chr(sms, function(m) { classify(m) })
  )
@@ -241,8 +241,8 @@ final_train = tidy_train %>%

```{r}
# Results of classification on training
-confusion = table(final_train$label, final_train$prediction)
-accuracy = (confusion[1,1] + confusion[2,2]) / nrow(final_train)
+confusion <- table(final_train$label, final_train$prediction)
+accuracy <- (confusion[1,1] + confusion[2,2]) / nrow(final_train)
```
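As a side note, the `confusion[1,1] + confusion[2,2]` indexing assumes a full 2x2 table with matching row and column order; summing the diagonal is an equivalent and slightly more robust way to write the same accuracy:

```{r}
# Equivalent accuracy via the diagonal of the confusion matrix
accuracy <- sum(diag(confusion)) / sum(confusion)
```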

@@ -251,13 +251,13 @@ The Naive Bayes Classifier achieves an accuracy of about 89%. Pretty good! Let's

# Hyperparameter Tuning

```{r}
-alpha_grid = seq(0.05, 1, by = 0.05)
-cv_accuracy = NULL
+alpha_grid <- seq(0.05, 1, by = 0.05)
+cv_accuracy <- NULL

for (alpha in alpha_grid) {

  # Recalculate probabilities based on the new alpha
-  cv_probs = word_counts %>%
+  cv_probs <- word_counts %>%
    mutate(
      # Calculate the probabilities from the counts based on the new alpha
      spam_prob = (spam_count + alpha) / (n_spam + alpha * n_vocabulary),
@@ -265,15 +265,15 @@ for (alpha in alpha_grid) {
  )

  # Predict the classification of each message in the cross-validation set
-  cv = spam_cv %>%
+  cv <- spam_cv %>%
    mutate(
      prediction = map_chr(sms, function(m) { classify(m, alpha = alpha) })
    )

  # Assess the accuracy of the classifier on the cross-validation set
-  confusion = table(cv$label, cv$prediction)
-  acc = (confusion[1,1] + confusion[2,2]) / nrow(cv)
-  cv_accuracy = c(cv_accuracy, acc)
+  confusion <- table(cv$label, cv$prediction)
+  acc <- (confusion[1,1] + confusion[2,2]) / nrow(cv)
+  cv_accuracy <- c(cv_accuracy, acc)
}

# Check out what the best alpha value is
@@ -289,16 +289,16 @@ Judging from the cross-validation set, higher $\alpha$ values cause the accuracy
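For reference, the winning value can be pulled straight from the sweep; a one-line sketch over `alpha_grid` and `cv_accuracy` as defined above:

```{r}
# Alpha value with the highest cross-validation accuracy
alpha_grid[which.max(cv_accuracy)]
```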

```{r}
# Reestablishing the proper parameters
-optimal_alpha = 0.1
+optimal_alpha <- 0.1

# Using optimal alpha with training parameters, perform final predictions
-spam_test = spam_test %>%
+spam_test <- spam_test %>%
  mutate(
    prediction = map_chr(sms, function(m) { classify(m, alpha = optimal_alpha) })
  )

-confusion = table(spam_test$label, spam_test$prediction)
-test_accuracy = (confusion[1,1] + confusion[2,2]) / nrow(spam_test)
+confusion <- table(spam_test$label, spam_test$prediction)
+test_accuracy <- (confusion[1,1] + confusion[2,2]) / nrow(spam_test)
test_accuracy
```
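Accuracy alone can hide how the errors split between the two classes. A short sketch for spam precision and recall from the same confusion matrix, assuming `table()` placed the true labels on rows and the predictions on columns, as above:

```{r}
# Precision and recall for the spam class, indexed by dimension names
precision <- confusion["spam", "spam"] / sum(confusion[, "spam"])
recall <- confusion["spam", "spam"] / sum(confusion["spam", ])
precision
recall
```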