---
title: "Conditional Probability in R: Guided Project Solutions"
output: html_document
---
```{r, warning = FALSE, message = FALSE}
library(tidyverse)
set.seed(1)
options(dplyr.summarise.inform = FALSE)
```
This analysis is an application of what we've learned in Dataquest's Conditional Probability course. Using a dataset of pre-labeled SMS messages, we'll create a spam filter using the Naive Bayes algorithm.
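Briefly, the classifier assigns whichever label has the larger posterior score. Under the (naive) assumption that words are independent given the label,

$$P(\text{spam} \mid w_1, \dots, w_n) \propto P(\text{spam}) \prod_{i=1}^{n} P(w_i \mid \text{spam}),$$

and each word probability is estimated with Laplace smoothing,

$$P(w_i \mid \text{spam}) = \frac{N_{w_i,\text{spam}} + \alpha}{N_{\text{spam}} + \alpha \cdot N_{\text{vocabulary}}},$$

where $N_{w_i,\text{spam}}$ is how often $w_i$ appears in spam messages, $N_{\text{spam}}$ is the size of the spam vocabulary, and $N_{\text{vocabulary}}$ is the size of the full vocabulary. The ham score is computed analogously. These are exactly the quantities the code below estimates.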
```{r}
spam <- read_csv("spam.csv")
```
The `spam` dataset has `r nrow(spam)` rows and `r ncol(spam)` columns. Of these messages, `r round(mean(spam$label == "ham") * 100, 2)`% are not classified as spam ("ham"); the rest are spam.
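As a quick sanity check, we can also tabulate the labels directly (this chunk is purely exploratory):

```{r}
# Count the messages per label and express each as a percentage
spam %>%
  count(label) %>%
  mutate(pct = n / sum(n) * 100)
```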
```{r}
n <- nrow(spam)

# An 80/10/10 train/cross-validation/test split; floor() keeps the sizes
# integral, and the test set takes whatever rows remain
n_training <- floor(0.8 * n)
n_cv <- floor(0.1 * n)
n_test <- n - n_training - n_cv

train_indices <- sample(1:n, size = n_training, replace = FALSE)

remaining_indices <- setdiff(1:n, train_indices)
cv_indices <- remaining_indices[1:n_cv]
test_indices <- remaining_indices[(n_cv + 1):length(remaining_indices)]

spam_train <- spam[train_indices, ]
spam_cv <- spam[cv_indices, ]
spam_test <- spam[test_indices, ]

print(mean(spam_train$label == "ham"))
print(mean(spam_cv$label == "ham"))
print(mean(spam_test$label == "ham"))
```
The proportion of ham messages is similar across the three sets. This check makes sure that no set consists almost entirely of ham, which would defeat the point of spam detection.
```{r}
# Normalize the messages: lowercase, collapse whitespace, then strip
# punctuation, stray control/Windows-1252 characters, and digits
tidy_train <- spam_train %>%
  mutate(
    sms = str_to_lower(sms) %>%
      str_squish() %>%
      str_replace_all("[[:punct:]]", "") %>%
      str_replace_all("[\u0094\u0092\u0096\n\t]", "") %>% # Unicode characters
      str_replace_all("[[:digit:]]", "")
  )

# Build the vocabulary: every unique word that appears in the training set
vocabulary <- NULL
messages <- tidy_train %>% pull(sms)

for (m in messages) {
  words <- str_split(m, " ")[[1]]
  vocabulary <- c(vocabulary, words)
}

vocabulary <- vocabulary %>% unique()
```
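As an aside, the same vocabulary can be built without the explicit loop by splitting all messages at once and flattening the result; this sketch is equivalent to the loop above:

```{r, eval = FALSE}
# Vectorized equivalent of the vocabulary-building loop
vocabulary <- tidy_train %>%
  pull(sms) %>%
  str_split(" ") %>% # list of word vectors, one per message
  unlist() %>%       # flatten into a single character vector
  unique()
```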
```{r}
spam_messages <- tidy_train %>%
  filter(label == "spam") %>%
  pull(sms)

ham_messages <- tidy_train %>%
  filter(label == "ham") %>%
  pull(sms)

# Unique words that appear in spam messages
spam_vocab <- NULL
for (sm in spam_messages) {
  words <- str_split(sm, " ")[[1]]
  spam_vocab <- c(spam_vocab, words)
}
spam_vocab <- spam_vocab %>% unique()

# Unique words that appear in ham messages
ham_vocab <- NULL
for (hm in ham_messages) {
  words <- str_split(hm, " ")[[1]]
  ham_vocab <- c(ham_vocab, words)
}
ham_vocab <- ham_vocab %>% unique()

# Vocabulary sizes; note these count *unique* words per class and serve
# as the denominators in the smoothed probability estimates below
n_spam <- spam_vocab %>% length()
n_ham <- ham_vocab %>% length()
n_vocabulary <- vocabulary %>% length()
```
```{r}
# Class priors, estimated from the training labels
p_spam <- mean(tidy_train$label == "spam")
p_ham <- mean(tidy_train$label == "ham")

# For each word in the spam vocabulary, count its total occurrences
# across all spam messages
spam_counts <- tibble(
  word = spam_vocab
) %>%
  mutate(
    spam_count = map_int(word, function(w) {
      map_int(spam_messages, function(sm) {
        (str_split(sm, " ")[[1]] == w) %>% sum() # for a single message
      }) %>%
        sum() # then summing over all messages
    })
  )

# Same word counts, but over the ham messages
ham_counts <- tibble(
  word = ham_vocab
) %>%
  mutate(
    ham_count = map_int(word, function(w) {
      map_int(ham_messages, function(hm) {
        (str_split(hm, " ")[[1]] == w) %>% sum()
      }) %>%
        sum()
    })
  )

# Join the two count tables; a word missing from one class gets a count of 0
word_counts <- full_join(spam_counts, ham_counts, by = "word") %>%
  mutate(
    spam_count = ifelse(is.na(spam_count), 0, spam_count),
    ham_count = ifelse(is.na(ham_count), 0, ham_count)
  )
```
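Before building the classifier, it can be worth a quick look at the counts, for instance the words that show up most often in spam (purely exploratory):

```{r}
# The ten most common words in spam messages
word_counts %>%
  arrange(desc(spam_count)) %>%
  head(10)
```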
```{r}
classify <- function(message, alpha = 1) {

  # Clean the incoming message the same way the training data was cleaned
  clean_message <- str_to_lower(message) %>%
    str_squish() %>%
    str_replace_all("[[:punct:]]", "") %>%
    str_replace_all("[\u0094\u0092\u0096\n\t]", "") %>% # Unicode characters
    str_replace_all("[[:digit:]]", "")

  words <- str_split(clean_message, " ")[[1]]

  # Words never seen in training get probability 1 for both classes,
  # so they leave the products below unchanged
  new_words <- setdiff(words, vocabulary)
  new_word_probs <- tibble(
    word = new_words,
    spam_prob = 1,
    ham_prob = 1
  )

  # Laplace-smoothed conditional probabilities for the words we have counts for
  present_probs <- word_counts %>%
    filter(word %in% words) %>%
    mutate(
      spam_prob = (spam_count + alpha) / (n_spam + alpha * n_vocabulary),
      ham_prob = (ham_count + alpha) / (n_ham + alpha * n_vocabulary)
    ) %>%
    bind_rows(new_word_probs) %>%
    pivot_longer(
      cols = c("spam_prob", "ham_prob"),
      names_to = "label",
      values_to = "prob"
    ) %>%
    group_by(label) %>%
    summarize(
      wi_prob = prod(prob) # prod is like sum, but with multiplication
    )

  # Posterior scores: prior times the product of the word probabilities
  p_spam_given_message <- p_spam * (present_probs %>% filter(label == "spam_prob") %>% pull(wi_prob))
  p_ham_given_message <- p_ham * (present_probs %>% filter(label == "ham_prob") %>% pull(wi_prob))

  ifelse(p_spam_given_message >= p_ham_given_message, "spam", "ham")
}

final_train <- tidy_train %>%
  mutate(
    prediction = map_chr(sms, function(m) { classify(m) })
  )
```
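One caveat worth flagging: `prod()` over many small probabilities can underflow to zero for long messages. A common remedy is to sum log-probabilities instead. Below is a minimal sketch of the scoring step rewritten in log space; it is not run here, and `words` and `alpha` stand in for the corresponding variables inside `classify()`:

```{r, eval = FALSE}
# Log-space variant of the scoring step inside classify():
# sum the logs instead of multiplying the raw probabilities
log_scores <- word_counts %>%
  filter(word %in% words) %>%
  mutate(
    spam_log = log((spam_count + alpha) / (n_spam + alpha * n_vocabulary)),
    ham_log  = log((ham_count + alpha) / (n_ham + alpha * n_vocabulary))
  ) %>%
  summarize(
    spam = log(p_spam) + sum(spam_log),
    ham  = log(p_ham) + sum(ham_log)
  )

# The class with the larger log-score wins
ifelse(log_scores$spam >= log_scores$ham, "spam", "ham")
```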
```{r}
confusion <- table(final_train$label, final_train$prediction)
accuracy <- (confusion[1, 1] + confusion[2, 2]) / nrow(final_train)
accuracy
```
The Naive Bayes classifier achieves an accuracy of about 89% on the training set. Pretty good! Let's see how well it performs on messages it has never seen before.
```{r}
alpha_grid <- seq(0.05, 1, by = 0.05)
cv_accuracy <- NULL

for (alpha in alpha_grid) {

  # classify() recomputes the smoothed probabilities from word_counts
  # using whatever alpha it is given, so we only need to pass alpha along
  cv <- spam_cv %>%
    mutate(
      prediction = map_chr(sms, function(m) { classify(m, alpha = alpha) })
    )

  confusion <- table(cv$label, cv$prediction)
  acc <- (confusion[1, 1] + confusion[2, 2]) / nrow(cv)
  cv_accuracy <- c(cv_accuracy, acc)
}

tibble(
  alpha = alpha_grid,
  accuracy = cv_accuracy
)
```
Judging from the cross-validation set, accuracy tends to decrease as $\alpha$ grows. We'll go with $\alpha = 0.1$ since it produces the highest cross-validation accuracy.
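A quick plot of the grid results makes the trend easier to see:

```{r}
# Visualize cross-validation accuracy as a function of alpha
tibble(
  alpha = alpha_grid,
  accuracy = cv_accuracy
) %>%
  ggplot(aes(x = alpha, y = accuracy)) +
  geom_line() +
  geom_point() +
  labs(x = expression(alpha), y = "Cross-validation accuracy")
```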
```{r}
optimal_alpha <- 0.1
spam_test <- spam_test %>%
  mutate(
    prediction = map_chr(sms, function(m) { classify(m, alpha = optimal_alpha) })
  )

confusion <- table(spam_test$label, spam_test$prediction)
test_accuracy <- (confusion[1, 1] + confusion[2, 2]) / nrow(spam_test)
test_accuracy
```
We've achieved an accuracy of about 93% on the test set. Not bad!
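Since ham heavily outnumbers spam, accuracy alone can be flattering. The same confusion table also yields precision and recall for the spam class (treating spam as the positive label); a quick sketch:

```{r}
# Precision: of the messages we flagged as spam, how many really were spam?
# Recall: of the true spam messages, how many did we catch?
precision <- confusion["spam", "spam"] / sum(confusion[, "spam"])
recall <- confusion["spam", "spam"] / sum(confusion["spam", ])
c(precision = precision, recall = recall)
```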