Queer European MD passionate about IT

Mission475Solutions.Rmd 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. ---
  2. title: "Conditional Probability in R: Guided Project Solutions"
  3. output: html_document
  4. ---
  5. ```{r, warning = FALSE, message = FALSE }
  6. library(tidyverse)
  7. set.seed(1)
  8. ```
  9. # Introduction
  10. This analysis is an application of what we've learned in Dataquest's Conditional Probability course. Using a dataset of pre-labeled SMS messages, we'll create a spam filter using the Naive Bayes algorithm.
  11. # Data
  12. ```{r}
  13. spam = read.csv("./data/SMSSpamCollection", sep = "\t", header = F)
  14. colnames(spam) = c("label", "sms")
  15. ```
  16. The `spam` dataset has `r nrow(spam)` rows and `r ncol(spam)` columns. Of these messages, `r mean(spam$label == "ham") * 100`% are labeled as ham (not spam); the rest are spam.
  17. # Dividing Up Into Training and Test Sets
  18. ```{r}
  19. n = nrow(spam)
  20. n.training = 2547
  21. n.cv = 318
  22. n.test = 319
  23. # Create the random indices for training set
  24. train.indices = sample(1:n, size = n.training, replace = FALSE)
  25. # Get indices not used by the training set
  26. remaining.indices = setdiff(1:n, train.indices)
  27. # Remaining indices are already randomized, just allocate correctly
  28. cv.indices = remaining.indices[1:318]
  29. test.indices = remaining.indices[319:length(remaining.indices)]
  30. # Use the indices to create each of the datasets
  31. spam.train = spam[train.indices,]
  32. spam.cv = spam[cv.indices,]
  33. spam.test = spam[test.indices,]
  34. # Sanity check: are the ratios of ham to spam relatively constant?
  35. print(mean(spam.train$label == "ham"))
  36. print(mean(spam.cv$label == "ham"))
  37. print(mean(spam.test$label == "ham"))
  38. ```
  39. The proportion of ham messages in each dataset is close to the 87% observed in the full dataset. These datasets look good for future analysis.
  40. # Data Cleaning
  41. ```{r}
  42. # To lowercase, removal of punctuation
  43. tidy.train = spam.train %>%
  44. mutate(
  45. sms = tolower(sms),
  46. sms = str_replace_all(sms, "[[:punct:]]", ""),
  47. sms = str_replace_all(sms, "[[:digit:]]", " "),
  48. sms = str_replace_all(sms, "[\u0094\u0092\n\t]", " ")
  49. )
  50. # Creating the vocabulary
  51. vocabulary = NULL
  52. messages = pull(tidy.train, sms)
  53. # Iterate through the messages and add to the vocabulary
  54. for (m in messages) {
  55. words = str_split(m, " ")[[1]]
  56. words = words[!words %in% ""]
  57. vocabulary = c(vocabulary, words)
  58. }
  59. vocabulary = unique(vocabulary)
  60. ```
  61. # Calculating Constants and Parameters
  62. ```{r}
  63. # Calculating Constants
  64. # Mean of a vector of logicals is a percentage
  65. p.spam = mean(tidy.train$label == "spam")
  66. p.ham = mean(tidy.train$label == "ham")
  67. # Isolate the spam and ham messages
  68. spam.messages = tidy.train %>%
  69. filter(label == "spam") %>%
  70. pull("sms")
  71. ham.messages = tidy.train %>%
  72. filter(label == "ham") %>%
  73. pull("sms")
  74. spam.words = NULL
  75. for (sm in spam.messages) {
  76. words = str_split(sm, " ")[[1]]
  77. spam.words = c(spam.words, words)
  78. }
  79. ham.words = NULL
  80. for (hm in ham.messages) {
  81. words = str_split(hm, " ")[[1]]
  82. ham.words = c(ham.words, words)
  83. }
  84. n.spam = length(unique(spam.words))
  85. n.ham = length(unique(ham.words))
  86. n.vocabulary = length(vocabulary)
  87. alpha = 1
  88. ```
  89. ```{r}
  90. # Calculating Parameters
  91. spam.counts = list()
  92. ham.counts = list()
  93. spam.probs = list()
  94. ham.probs = list()
  95. # This might take a while to run with so many words
  96. for (vocab in vocabulary) {
  97. # Initialize the counts for the word
  98. spam.counts[[vocab]] = 0
  99. ham.counts[[vocab]] = 0
  100. # Break up the message and count how many times word appears
  101. for (sm in spam.messages) {
  102. words = str_split(sm, " ")[[1]]
  103. spam.counts[[vocab]] = spam.counts[[vocab]] + sum(words == vocab)
  104. }
  105. for (hm in ham.messages) {
  106. words = str_split(hm, " ")[[1]]
  107. ham.counts[[vocab]] = ham.counts[[vocab]] + sum(words == vocab)
  108. }
  109. # Use the counts to calculate the probability
  110. spam.probs[[vocab]] = (spam.counts[[vocab]] + alpha) / (n.spam + alpha * n.vocabulary)
  111. ham.probs[[vocab]] = (ham.counts[[vocab]] + alpha) / (n.ham + alpha * n.vocabulary)
  112. }
  113. ```
  114. # Classifying New Messages
  115. ```{r}
  116. classify = function(message) {
  117. # Initializing the probability product
  118. p.spam.given.message = p.spam
  119. p.ham.given.message = p.ham
  120. # Splitting and cleaning the new message
  121. clean.message = tolower(message)
  122. clean.message = str_replace_all(clean.message, "[[:punct:]]", "")
  123. clean.message = str_replace_all(clean.message, "[[:digit:]]", " ")
  124. clean.message = str_replace_all(clean.message, "[\u0094\u0092\n\t]", " ")
  125. words = str_split(clean.message, " ")[[1]]
  126. for (word in words) {
  127. # Extra check if word is not in vocabulary
  128. wi.spam.prob = ifelse(word %in% vocabulary,
  129. spam.probs[[word]],
  130. 1)
  131. wi.ham.prob = ifelse(word %in% vocabulary,
  132. ham.probs[[word]],
  133. 1)
  134. p.spam.given.message = p.spam.given.message * wi.spam.prob
  135. p.ham.given.message = p.ham.given.message * wi.ham.prob
  136. }
  137. result = case_when(
  138. p.spam.given.message >= p.ham.given.message ~ "spam",
  139. p.spam.given.message < p.ham.given.message ~ "ham")
  140. return(result)
  141. }
  142. final.train = tidy.train %>%
  143. mutate(
  144. prediction = unlist(map(sms, classify))
  145. ) %>%
  146. select(label, prediction, sms)
  147. # Results of classification on training
  148. confusion = table(final.train$label, final.train$prediction)
  149. accuracy = (confusion[1,1] + confusion[2,2]) / nrow(final.train)
  150. ```
  151. The classifier achieves about 97% accuracy on the training set. We aren't interested in how well the classifier performs on the training data, though, because it has already "seen" all of these messages.
  152. # Hyperparameter Tuning
  153. ```{r}
  154. alpha.grid = seq(0.1, 1, by = 0.1)
  155. cv.accuracy = NULL
  156. for (a in alpha.grid) {
  157. spam.probs = list()
  158. ham.probs = list()
  159. # This might take a while to run with so many words
  160. for (vocab in vocabulary) {
  161. # Use the counts to calculate the probability
  162. spam.probs[[vocab]] = (spam.counts[[vocab]] + a) / (n.spam + a * n.vocabulary)
  163. ham.probs[[vocab]] = (ham.counts[[vocab]] + a) / (n.ham + a * n.vocabulary)
  164. }
  165. cv = spam.cv %>%
  166. mutate(
  167. prediction = unlist(map(sms, classify))
  168. ) %>%
  169. select(label, prediction, sms)
  170. confusion = table(cv$label, cv$prediction)
  171. acc = (confusion[1,1] + confusion[2,2]) / nrow(cv)
  172. cv.accuracy = c(cv.accuracy, acc)
  173. }
  174. cv.check = tibble(
  175. alpha = alpha.grid,
  176. accuracy = cv.accuracy
  177. )
  178. cv.check
  179. ```
  180. Judging from the cross-validation set, higher $\alpha$ values cause the accuracy to decrease. We'll go with $\alpha = 0.1$ since it produces the highest cross-validation prediction accuracy.
  181. # Test Set Performance
  182. ```{r}
  183. # Reestablishing the proper parameters
  184. optimal.alpha = 0.1
  185. for (a in alpha.grid) {
  186. spam.probs = list()
  187. ham.probs = list()
  188. # This might take a while to run with so many words
  189. for (vocab in vocabulary) {
  190. # Use the counts to calculate the probability
  191. spam.probs[[vocab]] = (spam.counts[[vocab]] + optimal.alpha) / (n.spam + optimal.alpha * n.vocabulary)
  192. ham.probs[[vocab]] = (ham.counts[[vocab]] + optimal.alpha) / (n.ham + optimal.alpha * n.vocabulary)
  193. }
  194. }
  195. spam.test = spam.test %>%
  196. mutate(
  197. prediction = unlist(map(sms, classify))
  198. ) %>%
  199. select(label, prediction, sms)
  200. confusion = table(spam.test$label, spam.test$prediction)
  201. test.accuracy = (confusion[1,1] + confusion[2,2]) / nrow(cv)
  202. test.accuracy
  203. ```
  204. We've achieved an accuracy of 93% in the test set. Not bad!