Queer European MD passionate about IT

Mission475Solutions.Rmd 9.0 KB


  1. ---
  2. title: "Conditional Probability in R: Guided Project Solutions"
  3. output: html_document
  4. ---
  5. ```{r, warning = FALSE, message = FALSE }
  6. library(tidyverse)
  7. set.seed(1)
  8. options(dplyr.summarise.inform = FALSE)
  9. ```
  10. # Introduction
  11. This analysis is an application of what we've learned in Dataquest's Conditional Probability course. Using a dataset of pre-labeled SMS messages, we'll create a spam filter using the Naive Bayes algorithm.
  12. ```{r}
  # Load the labelled SMS messages from the CSV file in the working directory
  spam <- read_csv(file = "spam.csv")
  15. ```
  16. The `spam` dataset has `r nrow(spam)` rows and `r ncol(spam)` columns. Of these messages, `r mean(spam$label == "ham") * 100`% are labeled "ham" (not spam); the rest are labeled "spam".
  17. # Training, Cross-validation and Test Sets
  18. ```{r}
  # Helper values for an 80/10/10 split. floor() keeps every size a whole
  # number, and the test set absorbs whatever rounding leaves over, so the
  # three sets together cover every row exactly once.
  n <- nrow(spam)
  n_training <- floor(0.8 * n)
  n_cv <- floor((n - n_training) / 2)
  n_test <- n - n_training - n_cv

  # Randomly choose the rows that form the training set
  train_indices <- sample.int(n, size = n_training, replace = FALSE)

  # setdiff() hands back the unused indices in ascending order, so shuffle
  # them before splitting; otherwise the cross-validation set would always
  # hold the earlier rows of the file and the test set the later ones.
  remaining_indices <- setdiff(seq_len(n), train_indices)
  remaining_indices <- remaining_indices[sample.int(length(remaining_indices))]
  cv_indices <- head(remaining_indices, n_cv)
  test_indices <- tail(remaining_indices, n_test)

  # Use the indices to create each of the datasets
  spam_train <- spam[train_indices, ]
  spam_cv <- spam[cv_indices, ]
  spam_test <- spam[test_indices, ]

  # Sanity check: is the share of ham roughly the same in every set?
  print(mean(spam_train$label == "ham"))
  print(mean(spam_cv$label == "ham"))
  print(mean(spam_test$label == "ham"))
  39. ```
  40. The proportion of ham messages is similar across the training, cross-validation, and test sets. This check simply confirms that no set consists almost entirely of "ham", which would defeat the purpose of spam detection.
  41. # Data Cleaning
  42. ```{r}
  # Normalise the training messages: lower-case them, then strip punctuation,
  # stray control/Unicode characters and digits.
  tidy_train <- spam_train %>%
    mutate(
      sms = str_to_lower(sms) %>%
        str_squish() %>%
        str_replace_all("[[:punct:]]", "") %>%
        str_replace_all("[\u0094\u0092\u0096\n\t]", "") %>% # Unicode characters
        str_replace_all("[[:digit:]]", "") %>%
        # Deleting a stand-alone symbol or number leaves two spaces in a row;
        # squish again so that splitting on " " cannot produce empty "words".
        str_squish()
    )

  # Build the vocabulary: every distinct word seen in the training messages.
  # Splitting all messages in one call avoids re-copying a growing vector on
  # each loop iteration; unique() keeps words in order of first appearance.
  messages <- tidy_train %>% pull(sms)
  vocabulary <- messages %>%
    str_split(" ") %>%
    unlist() %>%
    unique()
  # A message that is empty after cleaning still splits into "", which is not
  # a word, so leave it out of the vocabulary.
  vocabulary <- vocabulary[nzchar(vocabulary)]
  63. ```
  64. # Calculating Constants and Parameters
  65. ```{r}
  # Cleaned training messages for each class
  spam_messages <- tidy_train %>%
    filter(label == "spam") %>%
    pull(sms)
  ham_messages <- tidy_train %>%
    filter(label == "ham") %>%
    pull(sms)

  # Every word occurrence in each class, duplicates kept. Empty strings
  # (left behind by double spaces or empty messages) are not words.
  spam_tokens <- unlist(str_split(spam_messages, " "))
  spam_tokens <- spam_tokens[nzchar(spam_tokens)]
  ham_tokens <- unlist(str_split(ham_messages, " "))
  ham_tokens <- ham_tokens[nzchar(ham_tokens)]

  # Distinct words used by each class, in order of first appearance
  spam_vocab <- unique(spam_tokens)
  ham_vocab <- unique(ham_tokens)

  # Constants for the smoothed estimate
  #   P(word | class) = (count + alpha) / (n_class + alpha * n_vocabulary)
  # n_spam and n_ham must be the TOTAL number of word occurrences in each
  # class, not the number of distinct words; with distinct counts the
  # per-class word probabilities no longer sum to 1.
  n_spam <- length(spam_tokens)
  n_ham <- length(ham_tokens)
  n_vocabulary <- length(vocabulary)
  90. ```
  91. # Calculating Probability Parameters
  92. ```{r}
  # Marginal probability of a training message being spam or ham
  p_spam <- mean(tidy_train$label == "spam")
  p_ham <- mean(tidy_train$label == "ham")

  # Count how often each word of `vocab` occurs across all of `messages`.
  # Returns an integer vector aligned with `vocab`. Every message is split
  # once and the tokens are tallied in a single pass, instead of re-splitting
  # each message for every vocabulary word.
  count_occurrences <- function(messages, vocab) {
    tokens <- unlist(str_split(messages, " "))
    # match() gives each token's position in vocab (NA if absent);
    # tabulate() counts the positions and ignores the NAs.
    tabulate(match(tokens, vocab), nbins = length(vocab))
  }

  # Per-class counts, kept in separate tibbles
  spam_counts <- tibble(
    word = spam_vocab,
    spam_count = count_occurrences(spam_messages, spam_vocab)
  )
  ham_counts <- tibble(
    word = ham_vocab,
    ham_count = count_occurrences(ham_messages, ham_vocab)
  )

  # Join the tibbles; a word that never occurs in one class gets a count of 0
  word_counts <- full_join(spam_counts, ham_counts, by = "word") %>%
    mutate(
      spam_count = replace_na(spam_count, 0L),
      ham_count = replace_na(ham_count, 0L)
    )
  133. ```
  134. # Classifying New Messages
  135. ```{r}
  # Classify a single SMS message as "spam" or "ham" using Naive Bayes.
  #
  # message: one character string, raw or already cleaned.
  # alpha:   additive (Laplace) smoothing constant; defaults to 1 so the
  #          classifier can easily be re-run with other values.
  # Returns the string "spam" or "ham"; ties go to "spam".
  #
  # Relies on objects built in the earlier chunks: word_counts, n_spam,
  # n_ham, n_vocabulary, p_spam and p_ham.
  classify <- function(message, alpha = 1) {
    # Same cleaning procedure as the one applied to the training messages
    clean_message <- str_to_lower(message) %>%
      str_squish() %>%
      str_replace_all("[[:punct:]]", "") %>%
      str_replace_all("[\u0094\u0092\u0096\n\t]", "") %>% # Unicode characters
      str_replace_all("[[:digit:]]", "")
    words <- str_split(clean_message, " ")[[1]]
    # Removed punctuation/digits can leave empty strings behind; not words
    words <- words[nzchar(words)]

    # Keep only the words that were seen in training. Words outside the
    # training vocabulary carry no evidence for either class, so they are
    # skipped (the same as multiplying both classes by 1).
    # NOTE(review): each distinct word contributes once, even when it is
    # repeated within the message.
    known <- word_counts %>%
      filter(word %in% words)

    # Smoothed conditional probabilities P(word | class)
    spam_probs <- (known$spam_count + alpha) / (n_spam + alpha * n_vocabulary)
    ham_probs <- (known$ham_count + alpha) / (n_ham + alpha * n_vocabulary)

    # Work with sums of logs rather than products of probabilities: a long
    # message multiplies so many small numbers that the product underflows
    # to 0 for both classes, which would turn every such message into a tie.
    log_score_spam <- log(p_spam) + sum(log(spam_probs))
    log_score_ham <- log(p_ham) + sum(log(ham_probs))

    # Pick the class with the larger score
    if (log_score_spam >= log_score_ham) "spam" else "ham"
  }
  # Label every training message with the classifier's prediction.
  # classify() handles one message at a time, so map_chr() calls it once per
  # row (with the default alpha) and collects the results as a character column.
  final_train <- tidy_train %>%
    mutate(prediction = map_chr(sms, classify))
  191. ```
  192. # Calculating Accuracy
  193. ```{r}
  # Cross-tabulate the true labels (rows) against the predictions (columns)
  confusion <- table(final_train$label, final_train$prediction)
  # Accuracy is the share of messages whose prediction equals the label.
  # Comparing the two columns directly also works when the classifier only
  # ever predicts one class, where `confusion` has a single column and
  # indexing confusion[2, 2] would fail.
  accuracy <- mean(final_train$label == final_train$prediction)
  197. ```
  198. The Naive Bayes classifier achieves an accuracy of about 89% on the training set. Pretty good! Let's see how well it works on messages that it has never seen before.
  199. # Hyperparameter Tuning
  200. ```{r}
  # Candidate smoothing constants to compare
  alpha_grid <- seq(0.05, 1, by = 0.05)

  # Cross-validation accuracy for each candidate alpha.
  # classify() derives the smoothed probabilities from word_counts itself, so
  # nothing needs to be recomputed here beyond passing alpha along.
  cv_accuracy <- map_dbl(alpha_grid, function(a) {
    # Predict the class of every message in the cross-validation set
    predictions <- map_chr(spam_cv$sms, function(m) classify(m, alpha = a))
    # Share of correct predictions; safe even if only one class is predicted
    mean(spam_cv$label == predictions)
  })

  # Check out what the best alpha value is
  tibble(
    alpha = alpha_grid,
    accuracy = cv_accuracy
  )
  226. ```
  227. Judging from the cross-validation set, higher $\alpha$ values cause the accuracy to decrease. We'll go with $\alpha = 0.1$ since it produces the highest cross-validation prediction accuracy.
  228. # Test Set Performance
  229. ```{r}
  # Smoothing constant chosen from the cross-validation results
  optimal_alpha <- 0.1

  # Classify the held-out test messages with the chosen alpha
  spam_test <- spam_test %>%
    mutate(
      prediction = map_chr(sms, function(m) classify(m, alpha = optimal_alpha))
    )

  # Confusion matrix: true labels in rows, predictions in columns
  confusion <- table(spam_test$label, spam_test$prediction)
  # Share of correct predictions. Comparing the columns directly avoids
  # indexing confusion[2, 2], which fails when only one class is predicted.
  test_accuracy <- mean(spam_test$label == spam_test$prediction)
  test_accuracy
  240. ```
  241. We've achieved an accuracy of 93% on the test set. Not bad!