Queer European MD passionate about IT

Mission475Solutions.Rmd 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. ---
  2. title: "Conditional Probability in R: Guided Project Solutions"
  3. output: html_document
  4. ---
  5. ```{r, warning = FALSE, message = FALSE }
  6. library(tidyverse)
  7. set.seed(1)
  8. options(dplyr.summarise.inform = FALSE)
  9. ```
  10. # Introduction
  11. This analysis is an application of what we've learned in Dataquest's Conditional Probability course. Using a dataset of pre-labeled SMS messages, we'll create a spam filter using the Naive Bayes algorithm.
  12. ```{r}
  13. # Bring in the dataset
  14. spam = read.csv("spam.csv")
  15. ```
The `spam` dataset has `r nrow(spam)` rows and `r ncol(spam)` columns. Of these messages, `r round(mean(spam$label == "ham") * 100, 1)`% are labelled "ham" (not spam); the rest are spam.
  17. # Dividing Up Into Training and Test Sets
  18. ```{r}
  19. # Calculate some helper values to split the dataset
  20. n = nrow(spam)
  21. n_training = 0.8 * n
  22. n_cv = 0.1 * n
  23. n_test = 0.1 * n
  24. # Create the random indices for training set
  25. train_indices = sample(1:n, size = n_training, replace = FALSE)
  26. # Get indices not used by the training set
  27. remaining_indices = setdiff(1:n, train_indices)
  28. # Remaining indices are already randomized, just allocate correctly
  29. cv_indices = remaining_indices[1:(length(remaining_indices)/2)]
  30. test_indices = remaining_indices[((length(remaining_indices)/2) + 1):length(remaining_indices)]
  31. # Use the indices to create each of the datasets
  32. spam_train = spam[train_indices,]
  33. spam_cv = spam[cv_indices,]
  34. spam_test = spam[test_indices,]
  35. # Sanity check: are the ratios of ham to spam relatively constant?
  36. print(mean(spam_train$label == "ham"))
  37. print(mean(spam_cv$label == "ham"))
  38. print(mean(spam_test$label == "ham"))
  39. ```
The proportion of ham messages is similar across the training, cross-validation and test sets. This check makes sure that no subset consists entirely of ham, which would defeat the purpose of spam detection.
  41. # Data Cleaning
  42. ```{r}
  43. # To lowercase, removal of punctuation, weird characters, digits
  44. tidy_train = spam_train %>%
  45. mutate(
  46. # Take the messages and remove unwanted characters
  47. sms = str_to_lower(sms) %>%
  48. str_squish %>%
  49. str_replace_all("[[:punct:]]", "") %>%
  50. str_replace_all("[\u0094\u0092\u0096\n\t]", "") %>% # Unicode characters
  51. str_replace_all("[[:digit:]]", "")
  52. )
  53. # Creating the vocabulary
  54. vocabulary = NULL
  55. messages = tidy_train %>% pull(sms)
  56. # Iterate through the messages and add to the vocabulary
  57. for (m in messages) {
  58. words = str_split(m, " ")[[1]]
  59. vocabulary = c(vocabulary, words)
  60. }
  61. # Remove duplicates from the vocabulary
  62. vocabulary = vocabulary %>% unique
  63. ```
  64. # Calculating Constants and Parameters
  65. ```{r}
  66. # Isolate the spam and ham messages
  67. spam_messages = tidy_train %>%
  68. filter(label == "spam") %>%
  69. pull(sms)
  70. ham_messages = tidy_train %>%
  71. filter(label == "ham") %>%
  72. pull(sms)
  73. # Isolate the vocabulary in spam and ham messages
  74. spam_vocab = NULL
  75. for (sm in spam_messages) {
  76. words = str_split(sm, " ")[[1]]
  77. spam_vocab = c(spam_vocab, words)
  78. }
  79. spam_vocab = spam_vocab %>% unique
  80. ham_vocab = NULL
  81. for (hm in ham_messages) {
  82. words = str_split(hm, " ")[[1]]
  83. ham_vocab = c(ham_vocab, words)
  84. }
  85. ham_vocab = ham_vocab %>% unique
  86. # Calculate some important parameters from the vocab
  87. n_spam = spam_vocab %>% length
  88. n_ham = ham_vocab %>% length
  89. n_vocabulary = vocabulary %>% length
  90. ```
  91. ```{r}
  92. # New vectorized approach to a calculating ham and spam probabilities
  93. # Break up the spam and ham counting into their own tibbles
  94. spam_counts = tibble(
  95. word = spam_vocab
  96. ) %>%
  97. mutate(
  98. # Calculate the number of times a word appears in spam
  99. spam_count = map_int(word, function(w) {
  100. # Count how many times each word appears in all spam messsages, then sum
  101. map_int(spam_messages, function(sm) {
  102. (str_split(sm, " ")[[1]] == w) %>% sum # for a single message
  103. }) %>%
  104. sum # then summing over all messages
  105. })
  106. )
  107. # There are many words in the ham vocabulary so this will take a while!
  108. # Run this code and distract yourself while the counts are calculated
  109. ham_counts = tibble(
  110. word = ham_vocab
  111. ) %>%
  112. mutate(
  113. # Calculate the number of times a word appears in ham
  114. ham_count = map_int(word, function(w) {
  115. # Count how many times each word appears in all ham messsages, then sum
  116. map_int(ham_messages, function(hm) {
  117. (str_split(hm, " ")[[1]] == w) %>% sum
  118. }) %>%
  119. sum
  120. })
  121. )
  122. # Join these tibbles together
  123. word_counts = full_join(spam_counts, ham_counts, by = "word") %>%
  124. mutate(
  125. # Fill in zeroes where there are missing values
  126. spam_count = ifelse(is.na(spam_count), 0, spam_count),
  127. ham_count = ifelse(is.na(ham_count), 0, ham_count)
  128. )
  129. ```
  130. # Classifying New Messages
  131. ```{r}
  132. # This is the updated function using the vectorized approach to calculating
  133. # the spam and ham probabilities
  134. # Create a function that makes it easy to classify a tibble of messages
  135. # we add an alpha argument to make it easy to recalculate probabilities
  136. # based on this alpha (default to 1)
  137. classify = function(message, alpha = 1) {
  138. # Initializing the probability product
  139. p_spam = mean(tidy_train$label == "spam")
  140. p_ham = mean(tidy_train$label == "ham")
  141. # Splitting and cleaning the new message
  142. # This is the same cleaning procedure used on the training messages
  143. clean_message = str_to_lower(message) %>%
  144. str_squish %>%
  145. str_replace_all("[[:punct:]]", "") %>%
  146. str_replace_all("[\u0094\u0092\u0096\n\t]", "") %>% # Unicode characters
  147. str_replace_all("[[:digit:]]", "")
  148. words = str_split(clean_message, " ")[[1]]
  149. # There is a possibility that there will be words that don't appear
  150. # in the training vocabulary, so this must be accounted for
  151. # Find the words that aren't present in the training
  152. new_words = setdiff(vocabulary, words)
  153. # Add them to the word_counts
  154. new_word_probs = tibble(
  155. word = new_words,
  156. spam_prob = 1,
  157. ham_prob = 1
  158. )
  159. # Filter down the probabilities to the words present
  160. # use group by to multiply everything together
  161. present_probs = word_counts %>%
  162. filter(word %in% words) %>%
  163. mutate(
  164. # Calculate the probabilities from the counts
  165. spam_prob = (spam_count + alpha) / (n_spam + alpha * n_vocabulary),
  166. ham_prob = (ham_count + alpha) / (n_ham + alpha * n_vocabulary)
  167. ) %>%
  168. bind_rows(new_word_probs) %>%
  169. pivot_longer(
  170. cols = c("spam_prob", "ham_prob"),
  171. names_to = "label",
  172. values_to = "prob"
  173. ) %>%
  174. group_by(label) %>%
  175. summarize(
  176. wi_prob = prod(prob) # prod is like sum, but with multiplication
  177. )
  178. # Calculate the conditional probabilities
  179. p_spam_given_message = p_spam * (present_probs %>% filter(label == "spam_prob") %>% pull(wi_prob))
  180. p_ham_given_message = p_ham * (present_probs %>% filter(label == "ham_prob") %>% pull(wi_prob))
  181. # Classify the message based on the probability
  182. ifelse(p_spam_given_message >= p_ham_given_message, "spam", "ham")
  183. }
  184. # Use the classify function to classify the messages in the training set
  185. # This takes advantage of vectorization
  186. final_train = tidy_train %>%
  187. mutate(
  188. prediction = map_chr(sms, function(m) { classify(m) })
  189. )
  190. # Results of classification on training
  191. confusion = table(final_train$label, final_train$prediction)
  192. accuracy = (confusion[1,1] + confusion[2,2]) / nrow(final_train)
  193. ```
On the training set, the Naive Bayes classifier achieves an accuracy of about `r round(accuracy * 100)`%. Pretty good! Let's see how well it works on messages that it has never seen before.
  195. # Hyperparameter Tuning
  196. ```{r}
  197. alpha_grid = seq(0.05, 1, by = 0.05)
  198. cv_accuracy = NULL
  199. for (alpha in alpha_grid) {
  200. # Recalculate probabilities based on new alpha
  201. cv_probs = word_counts %>%
  202. mutate(
  203. # Calculate the probabilities from the counts based on new alpha
  204. spam_prob = (spam_count + alpha / (n_spam + alpha * n_vocabulary)),
  205. ham_prob = (ham_count + alpha) / (n_ham + alpha * n_vocabulary)
  206. )
  207. # Predict the classification of each message in cross validation
  208. cv = spam_cv %>%
  209. mutate(
  210. prediction = map_chr(sms, function(m) { classify(m, alpha = alpha) })
  211. )
  212. # Assess the accuracy of the classifier on cross-validation set
  213. confusion = table(cv$label, cv$prediction)
  214. acc = (confusion[1,1] + confusion[2,2]) / nrow(cv)
  215. cv_accuracy = c(cv_accuracy, acc)
  216. }
  217. # Check out what the best alpha value is
  218. tibble(
  219. alpha = alpha_grid,
  220. accuracy = cv_accuracy
  221. )
  222. ```
Judging from the cross-validation set, accuracy tends to fall as $\alpha$ increases. We'll go with $\alpha = 0.1$, since it produced the highest cross-validation accuracy.
  224. # Test Set Performance
  225. ```{r}
  226. # Reestablishing the proper parameters
  227. optimal_alpha = 0.1
  228. # Using optimal alpha with training parameters, perform final predictions
  229. spam_test = spam_test %>%
  230. mutate(
  231. prediction = map_chr(sms, function(m) { classify(m, alpha = optimal_alpha)} )
  232. )
  233. confusion = table(spam_test$label, spam_test$prediction)
  234. test_accuracy = (confusion[1,1] + confusion[2,2]) / nrow(spam_test)
  235. test_accuracy
  236. ```
We achieve an accuracy of `r round(test_accuracy * 100)`% on the test set. Not bad!