Rで何かをしたり、読書をするブログ

政府統計の総合窓口のデータや、OECDやUCIやのデータを使って、Rの練習をしています。ときどき、読書記録も載せています。

UCI Machine Learning Repository の Spambase のデータの分析3 - Random Forest で分類

www.crosshyou.info

の続きです。今回はランダムフォレストで分類してみます。

今回も tidymodels のワークフローでやります。

recipe() 関数でレシピを作成します。

ランダムフォレストのモデルは、rand_forest() 関数で作ります。

レシピとモデルを統合してワークフローを作成します。

チューングリッドの作成をします。

チューニングの実行をします。前回の glmnet エンジンでの Elastic-Net Logistic Regression と比べると長い時間、10分以上かかりました。

最良のパラメータを確認します。

この最良のパラメータで最終ワークフローを確定します。

fit() 関数でモデルをフィットさせます。

テスト用のデータで予測しましょう。

混合行列を作って結果を確認します。

間違いは52個です。正解率はどうでしょうか?

正解率は 95.5% でした。前回の Elastic-Net Logistic Regression よりもいい正解率です。

ROCのAUCを計算します。

 

0.983です。

ROC曲線を描きます。

vup パラメータの vip() 関数で重要な変数を確認します。

! マークや $ マーク、remove, free などが重要な変数です。前回のモデルと同じような感じですね。

今回は以上です。

はじめから読むには、

www.crosshyou.info

です。

今回のコードは以下になります。

#
# 2. Random Forest による分類
# 2.1 レシピを作成
rf_rec <- recipe(spam ~ ., data = df_train) |> 
  step_normalize(all_numeric_predictors())
#
# 2.2 モデルを作成
rf_mod <- rand_forest(
  mtry = tune(),
  min_n = tune(),
  trees = 1000
) |> 
  set_engine("ranger", importance = "impurity") |> 
  set_mode("classification")
#
# 2.3 ワークフローを作成
rf_wf <- workflow() |> 
  add_recipe(rf_rec) |> 
  add_model(rf_mod)
#
# 2.4 チューングリッドの作成
rf_grid <- grid_regular(
  mtry(range = c(3, 9)),
  min_n(range = c(2, 30)),
  levels = 7
)
#
# 2.5 チューニングの実行
set.seed(1)
rf_tuned <- tune_grid(
  rf_wf,
  grid = rf_grid,
  resamples = cvfolds,
  metrics = metric_set(roc_auc)
)
#
# 2.6 最良のパラメータ
rf_params <- select_best(rf_tuned, metric = "roc_auc")
rf_params
#
# 2.7 最終ワークフローを作成
rf_final_wf <- finalize_workflow(rf_wf, rf_params)
#
# 2.8 モデルをフィット
rf_fit <- fit(rf_final_wf, data = df_train)
#
# 2.9 テスト用のデータで予測
rf_pred <- bind_cols(
  df_test |> select(spam),
  predict(rf_fit, new_data = df_test, type = "class"),
  predict(rf_fit, new_data = df_test, type = "prob")
)
#
# 2.10 混合行列
table(rf_pred$.pred_class, rf_pred$spam)
#
# 2.11 正解率
rf_pred |> 
  summarize(accuracy = mean(spam == .pred_class))
#
# 2.12 ROCのAUC
rf_pred |> roc_auc(truth = spam, .pred_0)
#
# 2.13 ROC曲線
rf_roc <- rf_pred |> roc_curve(truth = spam, .pred_0)
rf_roc |> 
  ggplot(aes(x = 1 - specificity, y = sensitivity)) +
  geom_line(color = "blue", linewidth = 1) +
  geom_abline(color = "red", lty = "dashed", linewidth = 1) +
  labs(title = "Random Forest ROC Curve",
       x = "False Positive Rate",
       y = "True Positive Rate") +
  theme_minimal()
#
# 重要な変数の可視化
library(vip)
rf_fit |> extract_fit_engine() |> 
  vip(num_features = 10) +
  theme_minimal()
#

(冒頭の画像は Bing Image Creator で生成しました。プロンプトは、Lovely landscape of natural field, close up of Veronica persica flowers, blue sky, photo です。)