I’m writing this up because Amelia McNamara is working on this really cool project, rebooting the ISLR labs using the tidyverse. This is a meagre attempt to pitch in.

The first to answer the call was David Robinson, who, of course, answered the question comprehensively in less time than it would take me to even ponder the question over a cup of coffee.

I will look at a slightly different question, building on Amelia’s and David’s foundation with an eye on visualization. I know the material on cross-validation comes later in the book; I hope this will be OK.

library("ISLR")
library("class")
library("assertthat")

Attaching package: 'assertthat'
The following object is masked from 'package:tibble':

    has_name
library("tidyverse")
Loading tidyverse: readr
Conflicts with tidy packages -------------------------------------------------------------
filter():   dplyr, stats
has_name(): tibble, assertthat
lag():      dplyr, stats
partial():  purrr, pryr
library("modelr")
library("broom")
library("ggbeeswarm")
library("viridis")

Introduction

Following Amelia, let’s look at the ISLR Caravan example (pp. 164–167).

The goal is to apply KNN to the Caravan dataset from the ISLR package. The first thing I’m going to do is make a copy of it as a tibble, then see what we’ve got.

caravan <- 
  as_tibble(ISLR::Caravan) %>%
  print()
# A tibble: 5,822 × 86
   MOSTYPE MAANTHUI MGEMOMV MGEMLEEF MOSHOOFD MGODRK MGODPR MGODOV MGODGE MRELGE MRELSA
*    <dbl>    <dbl>   <dbl>    <dbl>    <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>
1       33        1       3        2        8      0      5      1      3      7      0
2       37        1       2        2        8      1      4      1      4      6      2
3       37        1       2        2        8      0      4      2      4      3      2
4        9        1       3        3        3      2      3      2      4      5      2
5       40        1       4        2       10      1      4      1      4      7      1
6       23        1       2        1        5      0      5      0      5      0      6
7       39        2       3        2        9      2      2      0      5      7      2
8       33        1       2        3        8      0      7      0      2      7      2
9       33        1       2        4        8      0      1      3      6      6      0
10      11        2       3        3        3      3      5      0      2      7      0
# ... with 5,812 more rows, and 75 more variables: MRELOV <dbl>, MFALLEEN <dbl>,
#   MFGEKIND <dbl>, MFWEKIND <dbl>, MOPLHOOG <dbl>, MOPLMIDD <dbl>, MOPLLAAG <dbl>,
#   MBERHOOG <dbl>, MBERZELF <dbl>, MBERBOER <dbl>, MBERMIDD <dbl>, MBERARBG <dbl>,
#   MBERARBO <dbl>, MSKA <dbl>, MSKB1 <dbl>, MSKB2 <dbl>, MSKC <dbl>, MSKD <dbl>,
#   MHHUUR <dbl>, MHKOOP <dbl>, MAUT1 <dbl>, MAUT2 <dbl>, MAUT0 <dbl>, MZFONDS <dbl>,
#   MZPART <dbl>, MINKM30 <dbl>, MINK3045 <dbl>, MINK4575 <dbl>, MINK7512 <dbl>,
#   MINK123M <dbl>, MINKGEM <dbl>, MKOOPKLA <dbl>, PWAPART <dbl>, PWABEDR <dbl>,
#   PWALAND <dbl>, PPERSAUT <dbl>, PBESAUT <dbl>, PMOTSCO <dbl>, PVRAAUT <dbl>,
#   PAANHANG <dbl>, PTRACTOR <dbl>, PWERKT <dbl>, PBROM <dbl>, PLEVEN <dbl>,
#   PPERSONG <dbl>, PGEZONG <dbl>, PWAOREG <dbl>, PBRAND <dbl>, PZEILPL <dbl>,
#   PPLEZIER <dbl>, PFIETS <dbl>, PINBOED <dbl>, PBYSTAND <dbl>, AWAPART <dbl>,
#   AWABEDR <dbl>, AWALAND <dbl>, APERSAUT <dbl>, ABESAUT <dbl>, AMOTSCO <dbl>,
#   AVRAAUT <dbl>, AAANHANG <dbl>, ATRACTOR <dbl>, AWERKT <dbl>, ABROM <dbl>,
#   ALEVEN <dbl>, APERSONG <dbl>, AGEZONG <dbl>, AWAOREG <dbl>, ABRAND <dbl>,
#   AZEILPL <dbl>, APLEZIER <dbl>, AFIETS <dbl>, AINBOED <dbl>, ABYSTAND <dbl>,
#   Purchase <fctr>

Yikes! That’s a lot of variables. Following Amelia, let’s standardise the numeric variables of the dataframe.

caravan_standard <- 
  caravan %>%
  select(-Purchase) %>%
  dmap(~as.vector(scale(.x))) %>%
  print() 
dmap() is deprecated. Please use the new colwise family in dplyr.
E.g., summarise_all(), mutate_all(), etc.
# A tibble: 5,822 × 85
       MOSTYPE   MAANTHUI    MGEMOMV    MGEMLEEF   MOSHOOFD     MGODRK     MGODPR
         <dbl>      <dbl>      <dbl>       <dbl>      <dbl>      <dbl>      <dbl>
1   0.68084775 -0.2725565  0.4066617 -1.21685949  0.7793384 -0.6942510  0.2174254
2   0.99221162 -0.2725565 -0.8594262 -1.21685949  0.7793384  0.3025256 -0.3653787
3   0.99221162 -0.2725565 -0.8594262 -1.21685949  0.7793384 -0.6942510 -0.3653787
4  -1.18733547 -0.2725565  0.4066617  0.01075374 -0.9708962  1.2993023 -0.9481828
5   1.22573452 -0.2725565  1.6727497 -1.21685949  1.4794323  0.3025256 -0.3653787
6  -0.09756193 -0.2725565 -0.8594262 -2.44447272 -0.2708024 -0.6942510  0.2174254
7   1.14789355  2.1914562  0.4066617 -1.21685949  1.1293853  1.2993023 -1.5309868
8   0.68084775 -0.2725565 -0.8594262  0.01075374  0.7793384 -0.6942510  1.3830335
9   0.68084775 -0.2725565 -0.8594262  1.23836697  0.7793384 -0.6942510 -2.1137909
10 -1.03165354  2.1914562  0.4066617  0.01075374 -0.9708962  2.2960789  0.2174254
# ... with 5,812 more rows, and 78 more variables: MGODOV <dbl>, MGODGE <dbl>,
#   MRELGE <dbl>, MRELSA <dbl>, MRELOV <dbl>, MFALLEEN <dbl>, MFGEKIND <dbl>,
#   MFWEKIND <dbl>, MOPLHOOG <dbl>, MOPLMIDD <dbl>, MOPLLAAG <dbl>, MBERHOOG <dbl>,
#   MBERZELF <dbl>, MBERBOER <dbl>, MBERMIDD <dbl>, MBERARBG <dbl>, MBERARBO <dbl>,
#   MSKA <dbl>, MSKB1 <dbl>, MSKB2 <dbl>, MSKC <dbl>, MSKD <dbl>, MHHUUR <dbl>,
#   MHKOOP <dbl>, MAUT1 <dbl>, MAUT2 <dbl>, MAUT0 <dbl>, MZFONDS <dbl>, MZPART <dbl>,
#   MINKM30 <dbl>, MINK3045 <dbl>, MINK4575 <dbl>, MINK7512 <dbl>, MINK123M <dbl>,
#   MINKGEM <dbl>, MKOOPKLA <dbl>, PWAPART <dbl>, PWABEDR <dbl>, PWALAND <dbl>,
#   PPERSAUT <dbl>, PBESAUT <dbl>, PMOTSCO <dbl>, PVRAAUT <dbl>, PAANHANG <dbl>,
#   PTRACTOR <dbl>, PWERKT <dbl>, PBROM <dbl>, PLEVEN <dbl>, PPERSONG <dbl>,
#   PGEZONG <dbl>, PWAOREG <dbl>, PBRAND <dbl>, PZEILPL <dbl>, PPLEZIER <dbl>,
#   PFIETS <dbl>, PINBOED <dbl>, PBYSTAND <dbl>, AWAPART <dbl>, AWABEDR <dbl>,
#   AWALAND <dbl>, APERSAUT <dbl>, ABESAUT <dbl>, AMOTSCO <dbl>, AVRAAUT <dbl>,
#   AAANHANG <dbl>, ATRACTOR <dbl>, AWERKT <dbl>, ABROM <dbl>, ALEVEN <dbl>,
#   APERSONG <dbl>, AGEZONG <dbl>, AWAOREG <dbl>, ABRAND <dbl>, AZEILPL <dbl>,
#   APLEZIER <dbl>, AFIETS <dbl>, AINBOED <dbl>, ABYSTAND <dbl>

Now, let’s follow David by using k-fold cross-validation.

So, I sat here staring at the screen for twenty minutes, because I could not see how to go forward with modelr’s framework for cross-validation using knn(); I could not see how to get there from here. So I went to run some errands, and a solution appeared (as happens from time to time).

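The problem (I think) is that the API of the knn() function is different from that of the lm() function: lm() takes a formula and a single data frame and returns a model object, whereas knn() takes the training predictors, the test predictors, and the training classes as separate arguments and returns the predictions directly. My solution is to back up and write a function that wraps knn() so that the API will be “close enough”. As I am starting to learn, “write a function” seems to be the way out of a lot of R pickles (and into others).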
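To make the difference concrete, here is a small sketch of the two call styles side by side. It uses the built-in iris data as a stand-in (not the Caravan data), and the split into every fifth row as a test set is just an assumption for illustration.

```r
library(class)

# lm(): one data frame holds the response and the predictors,
# and a formula picks them out; the result is a model object
fit <- lm(Sepal.Length ~ Sepal.Width, data = iris)

# knn(): predictors and classes are passed separately, and
# "fitting" and predicting happen in a single call
idx_test <- seq(1, nrow(iris), by = 5)  # illustrative split only
pred <- knn(
  train = iris[-idx_test, 1:4],
  test  = iris[idx_test, 1:4],
  cl    = iris$Species[-idx_test],
  k     = 3
)

class(fit)   # a model object
class(pred)  # a factor of predicted classes for the test rows
```

So a wrapper only has to accept the two data frames plus the name of the target column, and do the splitting-apart internally.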

To act like lm(), we need to keep the target variable in a data-frame alongside the predictor variables. So let’s do that.

caravan_standard_new <- 
  caravan %>%
  dmap_if(is.numeric, ~as.vector(scale(.x))) %>%
  print() 
dmap_if() is deprecated. Please use the new colwise family in dplyr.
E.g., summarise_if(), mutate_if(), etc.
dmap() is deprecated. Please use the new colwise family in dplyr.
E.g., summarise_all(), mutate_all(), etc.
# A tibble: 5,822 × 86
       MOSTYPE   MAANTHUI    MGEMOMV    MGEMLEEF   MOSHOOFD     MGODRK     MGODPR
*        <dbl>      <dbl>      <dbl>       <dbl>      <dbl>      <dbl>      <dbl>
1   0.68084775 -0.2725565  0.4066617 -1.21685949  0.7793384 -0.6942510  0.2174254
2   0.99221162 -0.2725565 -0.8594262 -1.21685949  0.7793384  0.3025256 -0.3653787
3   0.99221162 -0.2725565 -0.8594262 -1.21685949  0.7793384 -0.6942510 -0.3653787
4  -1.18733547 -0.2725565  0.4066617  0.01075374 -0.9708962  1.2993023 -0.9481828
5   1.22573452 -0.2725565  1.6727497 -1.21685949  1.4794323  0.3025256 -0.3653787
6  -0.09756193 -0.2725565 -0.8594262 -2.44447272 -0.2708024 -0.6942510  0.2174254
7   1.14789355  2.1914562  0.4066617 -1.21685949  1.1293853  1.2993023 -1.5309868
8   0.68084775 -0.2725565 -0.8594262  0.01075374  0.7793384 -0.6942510  1.3830335
9   0.68084775 -0.2725565 -0.8594262  1.23836697  0.7793384 -0.6942510 -2.1137909
10 -1.03165354  2.1914562  0.4066617  0.01075374 -0.9708962  2.2960789  0.2174254
# ... with 5,812 more rows, and 79 more variables: MGODOV <dbl>, MGODGE <dbl>,
#   MRELGE <dbl>, MRELSA <dbl>, MRELOV <dbl>, MFALLEEN <dbl>, MFGEKIND <dbl>,
#   MFWEKIND <dbl>, MOPLHOOG <dbl>, MOPLMIDD <dbl>, MOPLLAAG <dbl>, MBERHOOG <dbl>,
#   MBERZELF <dbl>, MBERBOER <dbl>, MBERMIDD <dbl>, MBERARBG <dbl>, MBERARBO <dbl>,
#   MSKA <dbl>, MSKB1 <dbl>, MSKB2 <dbl>, MSKC <dbl>, MSKD <dbl>, MHHUUR <dbl>,
#   MHKOOP <dbl>, MAUT1 <dbl>, MAUT2 <dbl>, MAUT0 <dbl>, MZFONDS <dbl>, MZPART <dbl>,
#   MINKM30 <dbl>, MINK3045 <dbl>, MINK4575 <dbl>, MINK7512 <dbl>, MINK123M <dbl>,
#   MINKGEM <dbl>, MKOOPKLA <dbl>, PWAPART <dbl>, PWABEDR <dbl>, PWALAND <dbl>,
#   PPERSAUT <dbl>, PBESAUT <dbl>, PMOTSCO <dbl>, PVRAAUT <dbl>, PAANHANG <dbl>,
#   PTRACTOR <dbl>, PWERKT <dbl>, PBROM <dbl>, PLEVEN <dbl>, PPERSONG <dbl>,
#   PGEZONG <dbl>, PWAOREG <dbl>, PBRAND <dbl>, PZEILPL <dbl>, PPLEZIER <dbl>,
#   PFIETS <dbl>, PINBOED <dbl>, PBYSTAND <dbl>, AWAPART <dbl>, AWABEDR <dbl>,
#   AWALAND <dbl>, APERSAUT <dbl>, ABESAUT <dbl>, AMOTSCO <dbl>, AVRAAUT <dbl>,
#   AAANHANG <dbl>, ATRACTOR <dbl>, AWERKT <dbl>, ABROM <dbl>, ALEVEN <dbl>,
#   APERSONG <dbl>, AGEZONG <dbl>, AWAOREG <dbl>, ABRAND <dbl>, AZEILPL <dbl>,
#   APLEZIER <dbl>, AFIETS <dbl>, AINBOED <dbl>, ABYSTAND <dbl>, Purchase <fctr>
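Both dmap() and dmap_if() print deprecation notices that point to dplyr's colwise verbs. Here is a minimal sketch of the same standardisation without them. It assumes dplyr 1.0.0 or later for across(), which is newer than the dplyr 0.5.0 used in this post, and it uses a made-up toy tibble rather than caravan.

```r
library(dplyr)

# Hypothetical toy data: two numeric predictors and a factor target
toy <- tibble(
  a = c(1, 2, 3, 4),
  b = c(10, 20, 30, 40),
  Purchase = factor(c("No", "No", "Yes", "No"))
)

# Standardise every numeric column and leave the factor untouched
# (the analogue of dmap_if(is.numeric, ...))
toy_standard <- toy %>%
  mutate(across(where(is.numeric), ~ as.vector(scale(.x))))

toy_standard
```

With the target column dropped first, across(everything(), ...) would play the role of the plain dmap() call.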

Now, let’s work on the wrapper for the knn() function.

#' gets \code{class::knn()} to play nice with modelr
#'
#' @param train       dataframe, with (scaled) numeric columns for predictors
#'                    and a factor column for the target
#' @param test        dataframe, with (scaled) numeric columns for predictors
#'                    and a factor column for the target
#' @param str_target  string, name of the target column in the train and
#'                    test dataframes
#' @param ...         arguments passed on to \code{class::knn()}
#'
#' @return like \code{class::knn()}, factor of classifications of test set. 
#'         \code{doubt} will be returned as \code{NA}.
#'
knn_new <- function(train, test, str_target, ...){
  
  # lets us use "resample"
  train <- as.data.frame(train)
  test <- as.data.frame(test)
  
  # yes, I should be able to do this using NSE, but I forgot...
  assertthat::assert_that(str_target %in% names(train))
  assertthat::assert_that(str_target %in% names(test)) # may not need this
  
  # get target vector for train dataframe
  target_train <- train[[str_target]]
  
  # remove target column from both dataframes
  train[[str_target]] <- NULL
  test[[str_target]] <- NULL
  
  class::knn(train = train, test = test, cl = target_train, ...)
}

Let’s see if this thing works…

Using the standard method:

test_caravan = caravan_standard %>%
  slice(1:1000)
train_caravan = caravan_standard %>%
  slice(1001:5822)

Purchase = caravan %>%
  select(Purchase)

test_purchase = Purchase %>%
  slice(1:1000) %>%
  .$Purchase

train_purchase = Purchase %>%
  slice(1001:5822) %>%
  .$Purchase
set.seed(1)
knn_pred = knn(train_caravan, test_caravan, train_purchase, k=1)
mean(test_purchase != knn_pred) # KNN error rate
[1] 0.118
mean(test_purchase != "No")  
[1] 0.059

Now, let’s try with the “new” function:

test_caravan_new = caravan_standard_new %>%
  slice(1:1000)
train_caravan_new = caravan_standard_new %>%
  slice(1001:5822)

set.seed(1)
knn_pred_new = knn_new(train_caravan_new, test_caravan_new, "Purchase", k=1)
mean(test_purchase != knn_pred_new) # KNN error rate
[1] 0.118
mean(test_purchase != "No") 
[1] 0.059

Promising… just to make (more) sure:

all(knn_pred == knn_pred_new)
[1] TRUE

Whew! Next let’s use modelr to do some cross-validations:

I suspect I am doing something bad here by not requiring that the proportions of the levels of the response variable are consistent among the train and test sets. I’ll leave that as an exercise for later.

# more hackery
get_resample_column <- function(df, str_var){
  df <- as.data.frame(df)
  df[[str_var]]
}

caravan_summary <- 
  caravan_standard_new %>%
  crossv_kfold(k = 20) %>%
  mutate(
    pred = map2(train, test, knn_new, "Purchase", k = 1),
    resp = map(test, get_resample_column, "Purchase")
  ) %>%
  unnest(pred, resp) %>%
  group_by(.id, pred, resp) %>%
  summarise(count = n()) %>%
  print()
Source: local data frame [78 x 4]
Groups: .id, pred [?]

     .id   pred   resp count
   <chr> <fctr> <fctr> <int>
1     01     No     No   259
2     01     No    Yes    14
3     01    Yes     No    16
4     01    Yes    Yes     3
5     02     No     No   258
6     02     No    Yes    18
7     02    Yes     No    14
8     02    Yes    Yes     2
9     03     No     No   251
10    03     No    Yes    23
# ... with 68 more rows
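On the worry about class proportions: one way to keep them consistent is to assign fold labels separately within each level of the response. The sketch below shows only that idea on made-up data; it does not produce modelr resample objects, and the names toy, n_folds and fold are hypothetical.

```r
library(dplyr)

set.seed(1)

# Hypothetical toy data with an imbalanced response, like Purchase
toy <- tibble(
  x = rnorm(200),
  Purchase = factor(rep(c("No", "Yes"), times = c(180, 20)))
)

n_folds <- 5

# Within each level, deal rows out to folds evenly, in random order
toy_folds <- toy %>%
  group_by(Purchase) %>%
  mutate(fold = sample(rep_len(seq_len(n_folds), n()))) %>%
  ungroup()

# Each fold holds 36 "No" rows and 4 "Yes" rows
count(toy_folds, fold, Purchase)
```

Each fold could then serve as the test set in turn, with the remaining folds as the training set.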

At this point, we could visualize the confusion matrix over all of the cross-validations.

caravan_summary %>%
  mutate(k = "1") %>%
  ggplot(aes(x = k, y = count)) +
  geom_beeswarm(alpha = 0.5) + 
  facet_grid(pred ~ resp, scales = "free")

I don’t know if such a visualization is useful or not. There are doubtless things that could be done to make it more useful, but these may be revealed (to me at least) only with coffee. At the very least, I ought to label the facets to show which is the prediction and which is the response.
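For the facet labels, ggplot2's label_both labeller puts the variable name next to each value. A small sketch follows, using the first two folds of the counts printed earlier as a stand-in data frame, and geom_point() in place of geom_beeswarm() only so that the example needs nothing beyond ggplot2.

```r
library(ggplot2)

# First two folds, copied from the summary printed earlier
toy_summary <- data.frame(
  .id   = rep(c("01", "02"), each = 4),
  pred  = factor(rep(c("No", "No", "Yes", "Yes"), times = 2)),
  resp  = factor(rep(c("No", "Yes"), times = 4)),
  count = c(259, 14, 16, 3, 258, 18, 14, 2)
)

# label_both renders strip labels as "pred: No", "resp: Yes", etc.
p <- ggplot(toy_summary, aes(x = "1", y = count)) +
  geom_point(alpha = 0.5) +
  facet_grid(pred ~ resp, scales = "free", labeller = label_both) +
  labs(x = "k", y = "count per fold")

p
```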

This method can be extended to looking at different values of \(k\), as well. I will have to get to that later.
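As a rough idea of what that extension might look like, here is a sketch that loops over a few values of \(k\) and records the test error for each. It uses the built-in iris data with a single train/test split rather than Caravan with cross-validation, and the particular values in k_values are arbitrary.

```r
library(class)

set.seed(1)

# iris stands in for the standardised Caravan predictors
x <- scale(iris[, 1:4])
y <- iris$Species

# One random train/test split (illustrative, not k-fold)
idx_test <- sample(nrow(iris), 50)

k_values <- c(1, 3, 5, 9)

# Test-set error rate for each value of k
error_rate <- vapply(
  k_values,
  function(k) {
    pred <- knn(x[-idx_test, ], x[idx_test, ], cl = y[-idx_test], k = k)
    mean(pred != y[idx_test])
  },
  numeric(1)
)

data.frame(k = k_values, error_rate = error_rate)
```

In the cross-validated version, the same loop would sit around the per-fold predictions, so that k becomes a real variable on the plot's x-axis rather than the constant "1".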

devtools::session_info()
Session info ----------------------------------------------------------------------------
 setting  value                       
 version  R version 3.3.1 (2016-06-21)
 system   x86_64, darwin13.4.0        
 ui       RStudio (1.0.44)            
 language (EN)                        
 collate  en_US.UTF-8                 
 tz       America/Chicago             
 date     2016-12-07                  
Packages --------------------------------------------------------------------------------
 package         * version     date       source                            
 assertthat      * 0.1         2013-12-06 CRAN (R 3.3.0)                    
 backports         1.0.4       2016-10-24 cran (@1.0.4)                     
 beeswarm          0.2.3       2016-04-25 CRAN (R 3.3.0)                    
 broom           * 0.4.1       2016-06-24 cran (@0.4.1)                     
 class           * 7.3-14      2015-08-30 CRAN (R 3.3.1)                    
 codetools         0.2-14      2015-07-15 CRAN (R 3.3.1)                    
 colorspace        1.2-6       2015-03-11 CRAN (R 3.3.0)                    
 DBI               0.5-1       2016-09-10 CRAN (R 3.3.0)                    
 devtools        * 1.12.0.9000 2016-11-21 Github (hadley/devtools@2e3c4b6)  
 digest            0.6.10      2016-08-02 cran (@0.6.10)                    
 dplyr           * 0.5.0       2016-06-24 cran (@0.5.0)                     
 evaluate          0.10        2016-10-11 cran (@0.10)                      
 foreign           0.8-66      2015-08-19 CRAN (R 3.3.1)                    
 ggbeeswarm      * 0.5.0       2016-02-21 CRAN (R 3.3.0)                    
 ggplot2         * 2.1.0       2016-03-01 CRAN (R 3.3.0)                    
 gridExtra         2.2.1       2016-02-29 CRAN (R 3.3.0)                    
 gtable            0.2.0       2016-02-26 CRAN (R 3.3.0)                    
 htmlDocumentIJL   0.0.0.9000  2016-09-04 local                             
 htmltools         0.3.5       2016-03-21 CRAN (R 3.3.0)                    
 htmlwidgets       0.6         2016-02-25 CRAN (R 3.3.0)                    
 ISLR            * 1.0         2013-06-11 CRAN (R 3.3.0)                    
 jsonlite          1.1         2016-09-14 CRAN (R 3.3.0)                    
 knitr           * 1.15.1      2016-11-22 cran (@1.15.1)                    
 labeling          0.3         2014-08-23 CRAN (R 3.3.0)                    
 lattice           0.20-33     2015-07-14 CRAN (R 3.3.1)                    
 lazyeval          0.2.0.9000  2016-09-22 Github (hadley/lazyeval@c155c3d)  
 listviewer      * 1.0         2016-06-15 CRAN (R 3.3.0)                    
 magrittr        * 1.5         2014-11-22 CRAN (R 3.3.0)                    
 memoise           1.0.0       2016-01-29 CRAN (R 3.3.0)                    
 mnormt            1.5-4       2016-03-09 cran (@1.5-4)                     
 modelr          * 0.1.0       2016-08-31 CRAN (R 3.3.0)                    
 munsell           0.4.3       2016-02-13 CRAN (R 3.3.0)                    
 nlme              3.1-128     2016-05-10 CRAN (R 3.3.1)                    
 pkgbuild          0.0.0.9000  2016-11-21 Github (r-pkgs/pkgbuild@65eace0)  
 pkgload           0.0.0.9000  2016-11-21 Github (r-pkgs/pkgload@def2b10)   
 plyr              1.8.4       2016-06-08 cran (@1.8.4)                     
 pryr            * 0.1.2       2015-06-20 CRAN (R 3.3.0)                    
 psych             1.6.9       2016-09-17 CRAN (R 3.3.0)                    
 purrr           * 0.2.2.9000  2016-11-21 Github (hadley/purrr@5360143)     
 R6                2.2.0       2016-10-05 cran (@2.2.0)                     
 Rcpp              0.12.8      2016-11-17 cran (@0.12.8)                    
 readr           * 1.0.0       2016-08-03 CRAN (R 3.3.0)                    
 reshape2          1.4.2       2016-10-22 CRAN (R 3.3.1)                    
 rmarkdown         1.2.9000    2016-12-01 Github (rstudio/rmarkdown@de08391)
 rprojroot         1.1         2016-10-29 cran (@1.1)                       
 rsconnect         0.5         2016-10-17 CRAN (R 3.3.1)                    
 rstudioapi        0.6         2016-06-27 CRAN (R 3.3.0)                    
 scales            0.4.0       2016-02-26 CRAN (R 3.3.0)                    
 stringi           1.1.2       2016-10-01 CRAN (R 3.3.0)                    
 stringr           1.1.0       2016-08-19 CRAN (R 3.3.0)                    
 tibble          * 1.2         2016-08-26 CRAN (R 3.3.0)                    
 tidyr           * 0.6.0.9000  2016-09-07 Github (hadley/tidyr@3c9335b)     
 tidyverse       * 0.0.0.9000  2016-09-07 Github (hadley/tidyverse@6ca05a7) 
 user2016docdemo * 0.0.0.9000  2016-09-06 local                             
 utilrSE         * 0.1.99      2016-11-15 local                             
 vipor             0.4.3       2016-07-27 CRAN (R 3.3.0)                    
 viridis         * 0.3.4       2016-03-12 CRAN (R 3.3.0)                    
 withr             1.0.2       2016-06-20 CRAN (R 3.3.0)                    
 yaml              2.1.14      2016-11-12 cran (@2.1.14)