123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305 |
- ---
- title: "Conditional Probability in R: Guided Project Solutions"
- output: html_document
- ---
- ```{r, warning = FALSE, message = FALSE }
- library(tidyverse)
- set.seed(1)
- options(dplyr.summarise.inform = FALSE)
- ```
- # Introduction
- This analysis is an application of what we've learned in Dataquest's Conditional Probability course. Using a dataset of pre-labeled SMS messages, we'll create a spam filter using the Naive Bayes algorithm.
- ```{r}
# Load the pre-labeled SMS messages from the CSV file in the working directory
spam <- read_csv(file = "spam.csv")
- ```
The `spam` dataset has `r nrow(spam)` rows and `r ncol(spam)` columns. Of these messages, `r mean(spam$label == "ham") * 100`% are labeled "ham" (not spam); the rest are spam.
- # Training, Cross-validation and Test Sets
- ```{r}
# Split sizes: 80% of the rows go to training and the held-out rows are divided
# between cross-validation and test. The sizes are whole numbers so they can be
# used directly as counts and index offsets; fractional offsets silently drop
# the last held-out row when the held-out count is odd.
n <- nrow(spam)
n_training <- floor(0.8 * n)
n_cv <- floor((n - n_training) / 2)
n_test <- n - n_training - n_cv

# Draw the training rows at random, without replacement
train_indices <- sample(seq_len(n), size = n_training, replace = FALSE)

# Rows not used by the training set. setdiff() returns them in file order, so
# shuffle before splitting; sample.int() on the length avoids sample()'s
# special handling of a single number.
remaining_indices <- setdiff(seq_len(n), train_indices)
remaining_indices <- remaining_indices[sample.int(length(remaining_indices))]

# The first n_cv held-out rows form the cross-validation set, the rest the test
# set, so every held-out row lands in exactly one of the two
cv_indices <- remaining_indices[seq_len(n_cv)]
test_indices <- remaining_indices[n_cv + seq_len(n_test)]

# Use the indices to create each of the datasets
spam_train <- spam[train_indices, ]
spam_cv <- spam[cv_indices, ]
spam_test <- spam[test_indices, ]

# Sanity check: is the proportion of ham roughly the same in every split?
print(mean(spam_train$label == "ham"))
print(mean(spam_cv$label == "ham"))
print(mean(spam_test$label == "ham"))
- ```
The proportion of ham messages is roughly the same in each of the three datasets. This check just makes sure that no dataset consists entirely of "ham", which would defeat the point of spam detection.
- # Data Cleaning
- ```{r}
# Normalise the training messages: lowercase, collapse whitespace, then strip
# punctuation, stray control/Unicode characters and digits.
# NOTE: classify() applies these same steps to new messages; keep the two in
# sync so that training tokens and new tokens are comparable.
tidy_train <- spam_train %>%
  mutate(
    sms = str_to_lower(sms) %>%
      str_squish() %>%
      str_replace_all("[[:punct:]]", "") %>%
      str_replace_all("[\u0094\u0092\u0096\n\t]", "") %>% # Unicode characters
      str_replace_all("[[:digit:]]", "")
  )

# Creating the vocabulary
messages <- tidy_train %>% pull(sms)

# Split every message on single spaces in one vectorised call and flatten the
# result; this yields the same tokens, in the same order, as appending to a
# vector message by message, without re-copying the vector on every append.
# unique() then keeps the first occurrence of each token.
vocabulary <- messages %>%
  str_split(" ") %>%
  unlist() %>%
  unique()
- ```
- # Calculating Constants and Parameters
- ```{r}
# Isolate the cleaned text of the spam and ham training messages
spam_messages <- tidy_train %>%
  filter(label == "spam") %>%
  pull(sms)
ham_messages <- tidy_train %>%
  filter(label == "ham") %>%
  pull(sms)

# All tokens (duplicates kept) appearing in spam and in ham messages.
# Splitting the whole vector at once avoids growing a vector inside a loop.
spam_vocab <- unlist(str_split(spam_messages, " "))
ham_vocab <- unlist(str_split(ham_messages, " "))

# Peek at the first few tokens only; printing the full vectors would dump
# every token of the training set into the rendered document
head(spam_vocab)
head(ham_vocab)

# Parameters for the smoothed probabilities: total token counts per class
# (duplicates included) and the number of distinct words seen in training
n_spam <- length(spam_vocab)
n_ham <- length(ham_vocab)
n_vocabulary <- length(vocabulary)
- ```
- # Calculating Probability Parameters
- ```{r}
# Marginal probability of a training message being spam or ham
p_spam <- mean(tidy_train$label == "spam")
p_ham <- mean(tidy_train$label == "ham")

# Number of times each word occurs across all spam messages.
# spam_vocab holds every token of every spam message (duplicates included), so
# tallying it gives the per-word totals directly, with exactly one row per
# distinct word. Building the table from the raw token vector instead would
# repeat a word's row once per occurrence and make the join below fan out.
spam_counts <- tibble(word = as.character(spam_vocab)) %>%
  count(word, name = "spam_count")

# Same tally for the ham messages
ham_counts <- tibble(word = as.character(ham_vocab)) %>%
  count(word, name = "ham_count")

# One row per distinct training word with its count in each class; a word
# missing from one class gets a count of zero there
word_counts <- full_join(spam_counts, ham_counts, by = "word") %>%
  mutate(
    spam_count = coalesce(spam_count, 0L),
    ham_count = coalesce(ham_count, 0L)
  )
- ```
- # Classifying New Messages
- ```{r}
# Classify a single SMS message as "spam" or "ham" with Naive Bayes.
#
# Arguments:
#   message  A single character string (only the first element is used).
#   alpha    Additive smoothing constant applied to the word counts
#            (defaults to 1).
#
# Returns: the string "spam" or "ham"; ties go to "spam".
#
# Reads these objects from the enclosing environment: word_counts, n_spam,
# n_ham, n_vocabulary, p_spam and p_ham.
classify <- function(message, alpha = 1) {
  # Splitting and cleaning the new message
  # This is the same cleaning procedure used on the training messages
  clean_message <- str_to_lower(message) %>%
    str_squish() %>%
    str_replace_all("[[:punct:]]", "") %>%
    str_replace_all("[\u0094\u0092\u0096\n\t]", "") %>% # Unicode characters
    str_replace_all("[[:digit:]]", "")

  words <- str_split(clean_message, " ")[[1]]

  # Keep only the training words that occur in the message. Words that never
  # appeared in training carry no evidence for either class (a factor of 1),
  # so they are simply left out. A word contributes once per matching row of
  # word_counts, however many times it is repeated in the message.
  known_counts <- word_counts %>%
    filter(word %in% words)

  # Work with sums of logs rather than a product of probabilities: the product
  # of many small numbers underflows to zero for both classes, which would
  # turn every long message into a tie.
  spam_log_lik <- sum(log(
    (known_counts$spam_count + alpha) / (n_spam + alpha * n_vocabulary)
  ))
  ham_log_lik <- sum(log(
    (known_counts$ham_count + alpha) / (n_ham + alpha * n_vocabulary)
  ))

  # Unnormalised log posteriors: log prior plus log likelihood
  spam_score <- log(p_spam) + spam_log_lik
  ham_score <- log(p_ham) + ham_log_lik

  # Classify the message based on the larger score
  if (spam_score >= ham_score) "spam" else "ham"
}
# Label every training message by running classify() on it with the default
# alpha; map_chr() collects the per-message results into a character column
final_train <- tidy_train %>%
  mutate(prediction = map_chr(sms, classify))
- ```
- # Calculating Accuracy
- ```{r}
# Results of classification on training: true labels in rows, predictions in
# columns
confusion <- table(final_train$label, final_train$prediction)

# Share of training messages whose prediction matches the label. Comparing the
# two columns directly also works when only one class is ever predicted, in
# which case the confusion table has a single column and cannot be indexed
# at [2, 2].
accuracy <- sum(final_train$label == final_train$prediction, na.rm = TRUE) /
  nrow(final_train)
- ```
The Naive Bayes Classifier achieves an accuracy of about 89% on the training set. Pretty good! Let's see how well it works on messages that it has never seen before.
- # Hyperparameter Tuning
- ```{r}
# Candidate smoothing constants and a preallocated slot for each accuracy
alpha_grid <- seq(0.05, 1, by = 0.05)
cv_accuracy <- numeric(length(alpha_grid))

for (i in seq_along(alpha_grid)) {
  alpha <- alpha_grid[i]

  # Smoothed per-word probabilities for this alpha. Kept for inspection only:
  # classify() derives the same quantities from word_counts itself.
  cv_probs <- word_counts %>%
    mutate(
      spam_prob = (spam_count + alpha) / (n_spam + alpha * n_vocabulary),
      ham_prob = (ham_count + alpha) / (n_ham + alpha * n_vocabulary)
    )

  # Predict the classification of each message in cross validation
  cv <- spam_cv %>%
    mutate(prediction = map_chr(sms, classify, alpha = alpha))

  # Assess the accuracy of the classifier on the cross-validation set.
  # The labels are compared directly rather than through confusion[2, 2],
  # which does not exist when only one class is predicted.
  confusion <- table(cv$label, cv$prediction)
  acc <- sum(cv$label == cv$prediction, na.rm = TRUE) / nrow(cv)
  cv_accuracy[i] <- acc
}

# Check out what the best alpha value is
tibble(
  alpha = alpha_grid,
  accuracy = cv_accuracy
)
- ```
- Judging from the cross-validation set, higher $\alpha$ values cause the accuracy to decrease. We'll go with $\alpha = 0.1$ since it produces the highest cross-validation prediction accuracy.
- # Test Set Performance
- ```{r}
# Smoothing constant chosen from the cross-validation results
optimal_alpha <- 0.1

# Using the chosen alpha with the training parameters, predict the test set
spam_test <- spam_test %>%
  mutate(prediction = map_chr(sms, classify, alpha = optimal_alpha))

# True labels in rows, predictions in columns
confusion <- table(spam_test$label, spam_test$prediction)

# Share of test messages classified correctly. Comparing the columns directly
# stays valid even if only one class is predicted and the table has a single
# column.
test_accuracy <- sum(spam_test$label == spam_test$prediction, na.rm = TRUE) /
  nrow(spam_test)
test_accuracy
- ```
- We've achieved an accuracy of 93% in the test set. Not bad!
|