If you collaborated with anyone, you must include “Collaborated with: FIRSTNAME LASTNAME” at the top of your lab!
Use the following code to generate data:
library(ggplot2)
# generate data
set.seed(302)
n <- 30
x <- sort(runif(n, -3, 3))
y <- 2*x + 2*rnorm(n)
x_test <- sort(runif(n, -3, 3))
y_test <- 2*x_test + 2*rnorm(n)
df_train <- data.frame("x" = x, "y" = y)
df_test <- data.frame("x" = x_test, "y" = y_test)
# store a theme
my_theme <- theme_bw(base_size = 16) +
theme(plot.title = element_text(hjust = 0.5, face = "bold"),
plot.subtitle = element_text(hjust = 0.5))
# generate plots
g_train <- ggplot(df_train, aes(x = x, y = y)) + geom_point() +
xlim(-3, 3) + ylim(min(y, y_test), max(y, y_test)) +
labs(title = "Training Data") + my_theme
g_test <- ggplot(df_test, aes(x = x, y = y)) + geom_point() +
xlim(-3, 3) + ylim(min(y, y_test), max(y, y_test)) +
labs(title = "Test Data") + my_theme
g_train
g_test
1a. For every k between 1 and 10, fit a degree-k polynomial linear regression model with `y` as the response and `x` as the explanatory variable(s). (Hint: Use `poly()`, as in the lecture slides.)
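One way to set up the fitting loop for (a) is sketched below; the list name `models` is my own choice, not something the lab prescribes, and `df_train` is rebuilt here so the chunk runs on its own (in your lab it already exists from the setup code).

```r
# df_train as generated by the setup code above
set.seed(302)
x <- sort(runif(30, -3, 3))
y <- 2 * x + 2 * rnorm(30)
df_train <- data.frame(x = x, y = y)

# fit a degree-k polynomial regression for each k in 1..10
models <- lapply(1:10, function(k) {
  lm(y ~ poly(x, degree = k), data = df_train)
})
```

Each element of `models` is an `lm` fit whose degree matches its index, which makes the later error calculations a simple loop over the list.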
1b. For each model from (a), record the training error. Then predict `y_test` using `x_test` and also record the test error.
1c. Present the 10 values for both training error and test error on a single table. Comment on what you notice about the relative magnitudes of training and test error, as well as the trends in both types of error as \(k\) increases.
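A sketch of (b) and (c) together, assuming mean squared error as the error metric (a common choice; confirm against your course's definition). The data-generation lines repeat the setup code above so the chunk is self-contained.

```r
# rebuild the training and test data from the setup code above
set.seed(302)
n <- 30
x <- sort(runif(n, -3, 3)); y <- 2 * x + 2 * rnorm(n)
x_test <- sort(runif(n, -3, 3)); y_test <- 2 * x_test + 2 * rnorm(n)
df_train <- data.frame(x = x, y = y)
df_test <- data.frame(x = x_test, y = y_test)

# fit the ten models and record mean squared error on both splits
train_err <- numeric(10)
test_err <- numeric(10)
for (k in 1:10) {
  fit <- lm(y ~ poly(x, degree = k), data = df_train)
  train_err[k] <- mean((df_train$y - fitted(fit))^2)
  test_err[k] <- mean((df_test$y - predict(fit, newdata = df_test))^2)
}

# present both errors in a single table
err_table <- data.frame(k = 1:10,
                        train_error = train_err,
                        test_error = test_err)
err_table
```

Note that `predict()` handles the `poly()` basis correctly for new data, so no manual basis construction is needed for the test predictions.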
1d. If you were going to choose a model based on training error, which would you choose? Plot the data, colored by split. Add a line to the plot representing your selection for model fit. Add a subtitle to this plot with the (rounded!) test error. (Hint: See Lecture Slides 9 for example code.)
1e. If you were going to choose a model based on test error, which would you choose? Plot the data, colored by split. Add a line to the plot representing your selection for model fit. Add a subtitle to this plot with the (rounded!) test error.
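The plots in (d) and (e) can be built along the following lines. Here `k_star` is a hypothetical placeholder for whichever degree you selected, and stacking the splits into one data frame is just one convenient way to color by split, not the required approach.

```r
library(ggplot2)
# data as generated by the setup code above
set.seed(302)
n <- 30
x <- sort(runif(n, -3, 3)); y <- 2 * x + 2 * rnorm(n)
x_test <- sort(runif(n, -3, 3)); y_test <- 2 * x_test + 2 * rnorm(n)
df_train <- data.frame(x = x, y = y)
df_test <- data.frame(x = x_test, y = y_test)

k_star <- 3  # hypothetical choice; substitute the degree you selected
fit <- lm(y ~ poly(x, degree = k_star), data = df_train)
test_mse <- mean((df_test$y - predict(fit, newdata = df_test))^2)

# stack both splits so points can be colored by split
df_all <- rbind(cbind(df_train, split = "train"),
                cbind(df_test, split = "test"))
g_fit <- ggplot(df_all, aes(x = x, y = y, color = split)) +
  geom_point() +
  geom_line(data = data.frame(x = df_train$x, y = fitted(fit)),
            aes(x = x, y = y), inherit.aes = FALSE) +
  labs(title = "Chosen Model",
       subtitle = paste("Test error:", round(test_mse, 2)))
g_fit
```

Because `x` was sorted at generation time, plotting the fitted values directly with `geom_line()` traces a smooth curve without extra interpolation.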
1f. What do you notice about the shape of the curves from parts (d) and (e)? Which model do you think has lower bias? Lower variance? Why?
For this part, note that there are tidyverse methods to perform cross-validation in R (see the `rsample` package). However, your goal is to understand and be able to implement the algorithm “by hand”, meaning that automated procedures from the `rsample` package, or similar packages, will not be accepted.
To begin, load in the popular `penguins` data set from the package `palmerpenguins`.
library(palmerpenguins)
data("penguins", package = "palmerpenguins")
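The `penguins` data contain missing values in these measurement columns. A common preliminary step (my own suggestion, not mandated by the lab) is to keep only the response and the four covariates and drop incomplete rows before running k-NN:

```r
library(palmerpenguins)

# keep only the response and the four covariates, dropping rows with NAs
vars <- c("species", "bill_length_mm", "bill_depth_mm",
          "flipper_length_mm", "body_mass_g")
peng <- na.omit(penguins[, vars])
dim(peng)
```

The cleaned data frame (`peng` here is a hypothetical name) can then be split into a covariate data frame and the `species` vector when calling your function.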
Our goal here is to predict output class `species` using covariates `bill_length_mm`, `bill_depth_mm`, `flipper_length_mm`, and `body_mass_g`. All your code should be within a function `my_knn_cv`.
Input:

- `train`: input data frame
- `cl`: true class value of your training data
- `k_nn`: integer representing the number of neighbors
- `k_cv`: integer representing the number of folds

Please note the distinction between `k_nn` and `k_cv`!
Output: a list with objects

- `class`: a vector of the predicted class \(\hat{Y}_{i}\) for all observations
- `cv_err`: a numeric with the cross-validation misclassification error

You will need to include the following steps:

- Define a variable `fold` that randomly assigns observations to folds \(1, \ldots, k\) with equal probability. (Hint: see the example code on the slides for k-fold cross-validation.)
- Use `knn()` from the `class` package to predict the class of the \(i\)th fold using all other folds as the training data.
- Store the vector `class` as the output of `knn()` with the full data as both the training and the test data, and the value `cv_err` as the average misclassification rate from your cross-validation.

Submission: To prove your function works, apply it to the `penguins` data. Predict output class `species` using covariates `bill_length_mm`, `bill_depth_mm`, `flipper_length_mm`, and `body_mass_g`. Use \(5\)-fold cross-validation (`k_cv = 5`). Use a table to show the `cv_err` values for 1-nearest neighbor and 5-nearest neighbors (`k_nn = 1` and `k_nn = 5`). Comment on which value had lower CV misclassification error and which had lower training set error (compare your output `class` to the true class, `penguins$species`).
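A skeleton following the steps above is sketched below. This is one reasonable implementation, not verbatim from the slides: the balanced fold-assignment line is standard k-fold practice (a literal equal-probability version would be `sample(1:k_cv, n, replace = TRUE)`), and misclassification rate is taken as the fraction of mismatched labels.

```r
library(class)  # for knn()

my_knn_cv <- function(train, cl, k_nn, k_cv) {
  n <- nrow(train)
  # randomly assign observations to folds 1..k_cv (balanced random assignment)
  fold <- sample(rep(1:k_cv, length.out = n))

  # predict each fold using all other folds as training data
  cv_misclass <- numeric(k_cv)
  for (i in 1:k_cv) {
    in_fold <- fold == i
    pred_i <- knn(train = train[!in_fold, ], test = train[in_fold, ],
                  cl = cl[!in_fold], k = k_nn)
    cv_misclass[i] <- mean(pred_i != cl[in_fold])
  }

  # full data as both training and test for the reported predictions
  class <- knn(train = train, test = train, cl = cl, k = k_nn)
  list(class = class, cv_err = mean(cv_misclass))
}
```

For the submission, you would call this with a data frame of the four numeric covariates (with missing rows removed) as `train`, the corresponding `species` values as `cl`, `k_cv = 5`, and `k_nn = 1` and `k_nn = 5` in turn, collecting the two `cv_err` values into a table.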