Method Comparison

Search algorithm, LASSO regression, and {xgboost}

Author

Jessica Helmer

Published

April 2, 2026

Code
library(tidyverse)
library(gt)

# Global ggplot2 theme for every figure in this document; the paper color is
# the same hex (#eceadf) used for the gt table cell fill further down.
page_theme <- theme_classic(base_size = 16, paper = "#eceadf")
set_theme(page_theme)
Code
# Small reader so each input is named by its path below the project's Data/
# directory.
read_data <- \(...) readRDS(here::here("Data", ...))

# Response-level data; later chunks read subject_id, sscore, task, item,
# score, and time from it.
v5_dat <- read_data("v5_dat.rds")

# Per-method item results: the search algorithm output plus the xgboost and
# LASSO item importances.
v5 <- read_data("Search Algorithms", "v5test_dat.rds")
xgb <- read_data("xgboost", "xgb_item-importance.rds")
lasso <- read_data("lasso_regression.rds")

startup_costs <- read_data("startup_costs.rds")

# Output directory for the cached R^2-by-time results written further down.
comparison_dir <- here::here("Data", "Method Comparison")
if (!dir.exists(comparison_dir)) {
  dir.create(comparison_dir)
}
Code
# Presumably defines get_startup(), which the timing chunk below calls with a
# task name -- confirm against Scripts/get_startup.R.
source(here::here("Scripts", "get_startup.R"))
Code
# Method keys; each key has a `<key>_rank` column in `ranks` below.
methods <- list("v5", "lasso", "xgb")

# Turn an importance vector into a rank where the largest importance gets
# rank 1 (the "+ 1" makes ranks start at 1, matching the v5 ranks).
importance_to_rank <- \(importance) length(importance) - rank(importance) + 1

# Search algorithm: `rep` is used directly as the item's rank.
v5_ranked <- v5 |>
  arrange(rep) |>
  select(item, v5_rank = rep)

xgb_ranked <- xgb |>
  mutate(xgb_rank = importance_to_rank(Importance)) |>
  select(xgb_rank, item = Variable) |>
  arrange(xgb_rank)

lasso_ranked <- lasso |>
  mutate(lasso_rank = importance_to_rank(Importance)) |>
  select(lasso_rank, item = Variable) |>
  arrange(lasso_rank)

# One row per distinct (task, item, time) combination found in the data.
item_times <- v5_dat |>
  select(task, item, time) |>
  unique()

# All three rankings side by side (full joins keep items that any method
# ranked), restricted to items that have timing information.
ranks <- list(v5_ranked, xgb_ranked, lasso_ranked) |>
  reduce(\(joined, nxt) full_join(joined, nxt, by = "item")) |>
  inner_join(item_times, by = "item")

Method Comparison

Correlations between item ranks from each method.

Spearman Rank

Code
# Spearman rank correlations between the three methods' item ranks.
spearman_cor <- ranks |>
  select(`Search Algorithm` = v5_rank,
         `xgboost` = xgb_rank,
         LASSO = lasso_rank) |>
  cor(method = "spearman")

# Fill every table region with the page color and outline it in a slightly
# darker shade of the same color.
spearman_cell_style <- list(
  cell_fill(color = "#eceadf"),
  cell_borders(sides = c("left", "right", "top", "bottom"),
               color = scales::col_darker("#eceadf", 15))
)
spearman_cell_locations <- list(cells_body(), cells_column_labels(),
                                cells_title(), cells_stub(), cells_stubhead())

spearman_cor |>
  as.data.frame() |>
  rownames_to_column() |>
  gt::gt() |>
  gt::fmt_number() |>
  tab_style(style = spearman_cell_style,
            locations = spearman_cell_locations)
Search Algorithm xgboost LASSO
Search Algorithm 1.00 0.23 0.49
xgboost 0.23 1.00 0.40
LASSO 0.49 0.40 1.00

Kendall’s Tau

Code
# Kendall's tau between the three methods' item ranks.
kendall_cor <- ranks |>
  select(`Search Algorithm` = v5_rank,
         `xgboost` = xgb_rank,
         LASSO = lasso_rank) |>
  cor(method = "kendall")

# Fill every table region with the page color and outline it in a slightly
# darker shade of the same color.
kendall_cell_style <- list(
  cell_fill(color = "#eceadf"),
  cell_borders(sides = c("left", "right", "top", "bottom"),
               color = scales::col_darker("#eceadf", 15))
)
kendall_cell_locations <- list(cells_body(), cells_column_labels(),
                               cells_title(), cells_stub(), cells_stubhead())

kendall_cor |>
  as.data.frame() |>
  rownames_to_column() |>
  gt::gt() |>
  gt::fmt_number() |>
  tab_style(style = kendall_cell_style,
            locations = kendall_cell_locations)
Search Algorithm xgboost LASSO
Search Algorithm 1.00 0.14 0.34
xgboost 0.14 1.00 0.27
LASSO 0.34 0.27 1.00
Code
# Pairwise scatter plots of item ranks, one per pair of methods. Each item is
# drawn as its own label, colored by the task prefix of the item id (the text
# before the first "_").
rank_points <- ranks |>
  mutate(task = str_split_i(item, "_", 1))

# Search algorithm vs. xgboost
ggplot(rank_points,
       aes(x = v5_rank, y = xgb_rank, label = item, color = task)) +
  geom_text()

# Search algorithm vs. LASSO
ggplot(rank_points,
       aes(x = v5_rank, y = lasso_rank, label = item, color = task)) +
  geom_text()

# LASSO vs. xgboost
ggplot(rank_points,
       aes(x = lasso_rank, y = xgb_rank, label = item, color = task)) +
  geom_text()

Method Timing

Code
# For each method key in `methods`, build `<key>_ranks_w_timing` in the global
# environment (later chunks read these objects by name):
#   1. order the items by that method's rank column (`<key>_rank`);
#   2. within each task, add get_startup(task) to the time of the first item
#      of that task in this ordering, so the startup cost is charged once;
#   3. accumulate the running total time over the whole ordering.
walk(methods, \(x) {
  rank_col <- paste0(x, "_rank")

  timed <- ranks |>
    # task is re-derived from the item id (text before the first "_")
    mutate(task = str_split_i(item, "_", 1)) |>
    # .data[[ ]] looks the column up by its string name
    arrange(.data[[rank_col]]) |>
    mutate(.by = task,
           time_w_startup = time + ifelse(item == first(item), get_startup(first(task)), 0)) |>
    mutate(cum_time = cumsum(time_w_startup))

  assign(paste0(x, "_ranks_w_timing"), timed, envir = globalenv())
})
Code
# R^2 as a function of test length for the search-algorithm ordering.
# For i = 1..99: keep the first i items of the ordering, average each
# subject's scores within task, regress sscore on those task means, and record
# the R^2 together with the item / cumulative time reached at that point.

# Response columns needed by every iteration.
v5_responses <- v5_dat |>
  select(subject_id, sscore, item, score)

v5_r2_at <- \(i) {
  included_items <- v5_ranks_w_timing |>
    select(item, task, v5_rank, cum_time) |>
    head(i)

  # Row holding the largest cumulative time among the included items.
  latest <- included_items |>
    filter(cum_time == max(cum_time))
  latest_item <- pull(latest, item)
  latest_time <- pull(latest, cum_time)

  # One row per subject, one column per included task (mean item score);
  # subjects missing any included task are dropped.
  task_means <- inner_join(v5_responses, included_items, by = "item") |>
    summarize(.by = c(subject_id, task),
              sscore = first(sscore),
              score = mean(score)) |>
    mutate(formula = paste("sscore ~", task |> unique() |> paste(collapse = "+"))) |>
    pivot_wider(names_from = task, values_from = score) |>
    drop_na()

  task_means |>
    summarize(rep = i,
              item = latest_item,
              test_r2 = summary(lm(as.formula(first(formula)),
                                   data = pick(everything())))$r.squared,
              cum_time = latest_time)
}

v5_r2_w_time <- map(1:99, v5_r2_at) |>
  list_rbind()

saveRDS(v5_r2_w_time, here::here("Data", "Method Comparison", "v5_r2_w_time.rds"))
Code
# R^2 as a function of test length for the xgboost ordering.
# For i = 1..99: keep the first i items of the ordering, average each
# subject's scores within task, regress sscore on those task means, and record
# the R^2 together with the item / cumulative time reached at that point.

# Response columns needed by every iteration (hoisted out of the loop).
xgb_responses <- v5_dat |>
  select(subject_id, sscore, item, score)

xgb_r2_at <- \(i) {
  included_items <- xgb_ranks_w_timing |>
    # carry this method's own rank (the original selected v5_rank here, a
    # copy-paste from the search-algorithm chunk)
    select(item, task, xgb_rank, cum_time) |>
    head(i)

  # Row holding the largest cumulative time among the included items;
  # looked up once and reused for both output columns.
  latest <- included_items |>
    filter(cum_time == max(cum_time))
  latest_item <- pull(latest, item)
  latest_time <- pull(latest, cum_time)

  # One row per subject, one column per included task (mean item score);
  # subjects missing any included task are dropped.
  task_means <- inner_join(xgb_responses, included_items, by = "item") |>
    summarize(.by = c(subject_id, task),
              sscore = first(sscore),
              score = mean(score)) |>
    mutate(formula = paste("sscore ~", task |> unique() |> paste(collapse = "+"))) |>
    pivot_wider(names_from = task, values_from = score) |>
    drop_na()

  task_means |>
    summarize(rep = i,
              item = latest_item,
              test_r2 = summary(lm(as.formula(first(formula)),
                                   data = pick(everything())))$r.squared,
              cum_time = latest_time)
}

xgb_r2_w_time <- map(1:99, xgb_r2_at) |>
  list_rbind()

saveRDS(xgb_r2_w_time, here::here("Data", "Method Comparison", "xgb_r2_w_time.rds"))
Code
# R^2 as a function of test length for the LASSO ordering.
# For i = 1..99: keep the first i items of the ordering, average each
# subject's scores within task, regress sscore on those task means, and record
# the R^2 together with the item / cumulative time reached at that point.

# Response columns needed by every iteration (hoisted out of the loop).
lasso_responses <- v5_dat |>
  select(subject_id, sscore, item, score)

lasso_r2_at <- \(i) {
  included_items <- lasso_ranks_w_timing |>
    # carry this method's own rank (the original selected v5_rank here, a
    # copy-paste from the search-algorithm chunk)
    select(item, task, lasso_rank, cum_time) |>
    head(i)

  # Row holding the largest cumulative time among the included items;
  # looked up once and reused for both output columns.
  latest <- included_items |>
    filter(cum_time == max(cum_time))
  latest_item <- pull(latest, item)
  latest_time <- pull(latest, cum_time)

  # One row per subject, one column per included task (mean item score);
  # subjects missing any included task are dropped.
  task_means <- inner_join(lasso_responses, included_items, by = "item") |>
    summarize(.by = c(subject_id, task),
              sscore = first(sscore),
              score = mean(score)) |>
    mutate(formula = paste("sscore ~", task |> unique() |> paste(collapse = "+"))) |>
    pivot_wider(names_from = task, values_from = score) |>
    drop_na()

  task_means |>
    summarize(rep = i,
              item = latest_item,
              test_r2 = summary(lm(as.formula(first(formula)),
                                   data = pick(everything())))$r.squared,
              cum_time = latest_time)
}

lasso_r2_w_time <- map(1:99, lasso_r2_at) |>
  list_rbind()

saveRDS(lasso_r2_w_time, here::here("Data", "Method Comparison", "lasso_r2_w_time.rds"))
Code
# Reload the cached R^2-by-time results written by the chunks above.
v5_r2_w_time <- readRDS(here::here("Data", "Method Comparison", "v5_r2_w_time.rds"))
xgb_r2_w_time <- readRDS(here::here("Data", "Method Comparison", "xgb_r2_w_time.rds"))
lasso_r2_w_time <- readRDS(here::here("Data", "Method Comparison", "lasso_r2_w_time.rds"))

# Display title -> results, in plotting order.
r2_curves <- list(
  "Search Algorithm" = v5_r2_w_time,
  "xgboost" = xgb_r2_w_time,
  "LASSO" = lasso_r2_w_time
)

# One plot per method: R^2 against cumulative time, with each point colored by
# the task of the item recorded at that step. The y-range is fixed so the three
# plots are directly comparable.
iwalk(r2_curves, \(d, name) {
  curve_plot <- d |>
    mutate(task = str_split_i(item, "_", 1)) |>
    ggplot(aes(x = cum_time, y = test_r2)) +
    geom_line() +
    geom_point(aes(color = task)) +
    coord_cartesian(ylim = c(.1, .5)) +
    labs(title = name, y = "R<sup>2</sup>", x = "Time") +
    theme(axis.title.y = ggtext::element_markdown())

  print(curve_plot)
})

Code
# How many items from each task every method includes before the cumulative
# time exceeds the budget.
time_budget <- 15

# Count, per task, the items of one ordering whose cumulative time is within
# the budget; the count column is named after the method.
items_within_budget <- \(timing, method) {
  timing |>
    filter(cum_time <= time_budget) |>
    count(task, name = method)
}

# Built once and shared by both tables below. Full joins keep tasks that only
# some methods selected (their counts are NA for the other methods).
task_counts <- items_within_budget(v5_ranks_w_timing, "v5") |>
  full_join(items_within_budget(xgb_ranks_w_timing, "xgb"),
            by = "task") |>
  full_join(items_within_budget(lasso_ranks_w_timing, "lasso"),
            by = "task")

# Table 1: counts per task, blanks where a method selected none.
task_counts |>
  arrange(task) |>
  gt(rowname_col = "task") |>
  sub_missing(missing_text = "") |>
  tab_style(style = cell_text(weight = "bold"),
            locations = cells_column_labels())

# Table 2: the same counts summed over tasks that share the text before the
# first "." in the task name; zero totals are blanked.
task_counts |>
  mutate(task = str_split_i(task, "\\.", 1)) |>
  summarize(.by = task,
            across(everything(), \(n) sum(n, na.rm = TRUE))) |>
  arrange(task) |>
  gt(rowname_col = "task") |>
  sub_missing(missing_text = "") |>
  sub_zero(zero_text = " ") |>
  tab_style(style = cell_text(weight = "bold"),
            locations = cells_column_labels())
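|         | v5 | xgb | lasso |
|---------|----|-----|-------|
| admc.dr | 7  | 3   | 3     |
| bu.e    | 2  | 5   | 3     |
| bu.h    |    | 5   | 4     |
| dn.c    | 10 |     | 3     |
| ns      | 1  |     |       |
| ts      | 3  | 2   | 3     |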
|      | v5 | xgb | lasso |
|------|----|-----|-------|
| admc | 7  | 3   | 3     |
| bu   | 2  | 10  | 7     |
| dn   | 10 |     | 3     |
| ns   | 1  |     |       |
| ts   | 3  | 2   | 3     |

Item Counts Across Tests

Code
# For each method key, publish `<key>_test` in the global environment: the
# items of that method's ordering whose cumulative time is within 15.
walk(methods, \(x) {
  test_items <- get(paste0(x, "_ranks_w_timing")) |>
    filter(cum_time <= 15) |>
    pull(item)

  assign(paste0(x, "_test"), test_items, envir = globalenv())
})

# For every item, count how many of the three tests include it (0-3), then list
# the items that appear in at least one test, most widely shared first.
tibble(item = pull(ranks, item)) |>
  # each %in% is a logical vector; adding them gives the per-item count
  mutate(count = (item %in% v5_test) +
           (item %in% xgb_test) +
           (item %in% lasso_test)) |>
  filter_out(count == 0) |>
  arrange(desc(count)) |>
  mutate(task = str_split_i(item, "_", 1),
         .before = item) |>
  gt(rowname_col = c("task", "item")) |>
  tab_style(style = cell_text(weight = "bold"),
            locations = cells_column_labels())
count
admc.dr admc.dr_2 3
admc.dr_9 3
ts ts_7 3
bu.e bu.e_0 3
bu.e_2 3
admc.dr admc.dr_4 2
dn.c dn.c_1 2
dn.c_16 2
dn.c_19 2
ts ts_9 2
ts_2 2
admc.dr admc.dr_8 2
bu.h bu.h_6 2
bu.h_2 2
bu.e bu.e_6 2
dn.c dn.c_13 1
dn.c_10 1
dn.c_24 1
dn.c_12 1
dn.c_18 1
dn.c_14 1
admc.dr admc.dr_1 1
admc.dr_6 1
admc.dr_3 1
ns ns_5 1
dn.c dn.c_23 1
bu.e bu.e_4 1
bu.h bu.h_5 1
bu.h_4 1
bu.h_1 1
bu.h_3 1
bu.h_0 1
bu.e bu.e_7 1
ts ts_8 1
Code
# Build `item_properties`: a two-column table (`item`, `properties`) in which
# `properties` is a nested tibble holding the design attributes of that item.
# Each argument to rbind() below handles one data file and produces the same
# two columns, with item ids rewritten to the `<task>_<id>` form.
item_properties <- rbind(
  # Denominator neglect ("dn.c"): original item ids are mapped to a rank via
  # the top_items lookup (conf_item / harm_item stacked into one `item`
  # column). The item id becomes "dn.c_<rank>", and `item_ct` additionally
  # carries the choice type. `.keep = "unused"` drops `rank` and `choice_type`
  # once they have been pasted into the new ids.
  readRDS(here::here("Data", "Denominator Neglect", "dn_dat_c.rds")) |>
    select(item, choice_type, proportion_difference, small_lottery_gold_prop) |>
    unique() |>
    left_join(readRDS(here::here("Data", "Denominator Neglect", "top_items.rds")) |> 
                select(rank, conf_item, harm_item) |>
                pivot_longer(c(conf_item, harm_item),
                             names_to = "choice_type", values_to = "item") |>
                select(-choice_type),
              by = "item") |>
    mutate(item_ct = paste0("dn.c_", rank, "_", choice_type),
           item = paste0("dn.c_", rank),
           .keep = "unused") |>
    nest(.by = item, properties = !item) ,
  
  # Denominator neglect ("dn.s"): same steps as above on the dn_dat_s /
  # top_items_s files, with the "dn.s_" prefix.
  readRDS(here::here("Data", "Denominator Neglect", "dn_dat_s.rds")) |>
    select(item, choice_type, proportion_difference, small_lottery_gold_prop) |>
    unique() |>
    left_join(readRDS(here::here("Data", "Denominator Neglect", "top_items_s.rds")) |> 
                select(rank, conf_item, harm_item) |>
                pivot_longer(c(conf_item, harm_item),
                             names_to = "choice_type", values_to = "item") |>
                select(-choice_type),
              by = "item") |>
    mutate(item_ct = paste0("dn.s_", rank, "_", choice_type),
           item = paste0("dn.s_", rank),
           .keep = "unused") |>
    nest(.by = item, properties = !item),
  
  # Bayesian update ("bu.e"): item id from unique_trial; property = ball_split.
  readRDS((here::here("Data", "Bayesian Update", "bu_dat_e.rds"))) |>
    mutate(item = paste0("bu.e_", unique_trial)) |>
    select(item, ball_split) |>
    unique() |>
    nest(properties = !c(item)),
  
  # Bayesian update ("bu.h"): same as bu.e on the bu_dat_h file.
  readRDS((here::here("Data", "Bayesian Update", "bu_dat_h.rds"))) |>
    mutate(item = paste0("bu.h_", unique_trial)) |>
    select(item, ball_split) |>
    unique() |>
    nest(properties = !c(item)),
  
  # ADMC decision rules: item id from the numeric part of admc_id;
  # property = n_correct_options.
  readRDS(here::here("Data", "ADMC Decision Rules", "admc.dr_dat.rds")) |>
    mutate(item = paste0("admc.dr_", parse_number(admc_id))) |>
    select(item, n_correct_options) |>
    unique() |>
    nest(properties = !c(item)),
  
  # Time series: item id from ts_id; properties = noise condition, number of
  # datapoints, function, and direction.
  readRDS(here::here("Data", "Time Series", "ts_dat.rds")) |>
    mutate(item = paste0("ts_", ts_id)) |>
    select(item, noise = noise_condition, datapoints, func, direction) |>
    unique() |>
    nest(properties = !c(item)),
  
  # Number series: ns_id is used as the item id unchanged; its numeric part is
  # mapped to the displayed series text, which is the only property.
  readRDS(here::here("Data", "Number Series", "ns_dat.rds")) |>
    mutate(number_series = recode_values(parse_number(ns_id),
                                         1 ~ "10, 4, ____, -8, -14, -20",
                                         2 ~ "3, 6, 10, 15, 21, ____",
                                         3 ~ "121, 100, 81, ____, 49",
                                         4 ~ "3, 10, 16, 23, ____, 36",
                                         5 ~ "3/21, ____, 13/11, 18/6, 23/1, 28/-4",
                                         6 ~ "200, 198, 192, 174, ____",
                                         7 ~ "3, 2, 10, 4, 19, 6, 30, 8, ____",
                                         8 ~ "10000, 9000, ____, 8890, 8889",
                                         9 ~ "3/4, 4/6, 6/8, 8/12, ____")) |>
    select(item = ns_id, number_series) |>
    unique() |>
    nest(properties = !c(item))
)
Code
# methods |>
#   map(\(method) tibble(item = sym(paste0(method, "_test")) |> eval()) |>
#          left_join(item_properties,
#                    by = "item") |>
#          arrange(item) |>
#          pmap(\(item, properties) tibble(item, properties) |>
#                 unnest(cols = c()) |>
#                 mutate(item = ifelse(str_detect(item, "dn"), item_ct, item),
#                        .keep = "unused") |>
#                 nest(properties = !item)) |>
#          list_rbind() |>
#          gt() |>
#          tab_header(method)) 

# Render a gt table listing the given items with their design properties.
#   items: vector of item ids making up one method's test (e.g. `v5_test`)
#   title: table header text
# Returns the gt table, so a top-level call prints it.
item_property_table <- \(items, title) {
  tibble(item = items) |>
    left_join(item_properties,
              by = "item") |>
    arrange(item) |>
    # Handle one item at a time: `properties` is that item's nested tibble.
    # Passed unnamed to tibble(), it is expected to be spliced into columns
    # next to `item` -- see the tibble() documentation.
    pmap(\(item, properties) tibble(item, properties) |>
           # kept from the original; appears to be a no-op with an empty
           # selection -- TODO confirm
           unnest(cols = c()) |>
           # Items whose id contains "dn" are shown under `item_ct` (the id
           # with the choice type appended). This must stay base ifelse():
           # it evaluates `yes` only when some test is TRUE, so the missing
           # `item_ct` column is never touched for the other tasks.
           mutate(item = ifelse(str_detect(item, "dn"), item_ct, item),
                  .keep = "unused") |>
           nest(properties = !item)) |>
    list_rbind() |>
    gt() |>
    tab_header(title)
}

item_property_table(v5_test, "Search Algorithm")
item_property_table(lasso_test, "LASSO")
item_property_table(xgb_test, "xgboost")
Search Algorithm
item properties
admc.dr_1 1
admc.dr_2 1
admc.dr_3 1
admc.dr_4 1
admc.dr_6 1
admc.dr_8 2
admc.dr_9 3
bu.e_0 40,60
bu.e_2 30,70
dn.c_1_conflict 0.08, 0.1
dn.c_1_harmony 0.07, 0.3
dn.c_10_conflict 0.05, 0.2
dn.c_10_harmony 0.06, 0.1
dn.c_12_conflict 0.01, 0.1
dn.c_12_harmony 0.04, 0.2
dn.c_13_conflict 0.04, 0.2
dn.c_13_harmony 0.05, 0.1
dn.c_14_conflict 0.02, 0.3
dn.c_14_harmony 0.01, 0.1
dn.c_16_conflict 0.01, 0.2
dn.c_16_harmony 0.08, 0.1
dn.c_18_conflict 0.05, 0.1
dn.c_18_harmony 0.05, 0.3
dn.c_19_conflict 0.02, 0.2
dn.c_19_harmony 0.08, 0.3
dn.c_23_conflict 0.07, 0.2
dn.c_23_harmony 0.02, 0.3
dn.c_24_conflict 0.08, 0.2
dn.c_24_harmony 0.01, 0.3
ns_5 3/21, ____, 13/11, 18/6, 23/1, 28/-4
ts_2 low, datapoints_30, linear, negative
ts_7 high, datapoints_30, linear, negative
ts_9 low, datapoints_10, linear, positive
LASSO
item properties
admc.dr_2 1
admc.dr_8 2
admc.dr_9 3
bu.e_0 40,60
bu.e_2 30,70
bu.e_6 40,60
bu.h_0 40,60
bu.h_2 40,60
bu.h_4 30,70
bu.h_6 40,60
dn.c_1_conflict 0.08, 0.1
dn.c_1_harmony 0.07, 0.3
dn.c_16_conflict 0.01, 0.2
dn.c_16_harmony 0.08, 0.1
dn.c_19_conflict 0.02, 0.2
dn.c_19_harmony 0.08, 0.3
ts_7 high, datapoints_30, linear, negative
ts_8 high, datapoints_10, linear, negative
ts_9 low, datapoints_10, linear, positive
xgboost
item properties
admc.dr_2 1
admc.dr_4 1
admc.dr_9 3
bu.e_0 40,60
bu.e_2 30,70
bu.e_4 30,70
bu.e_6 40,60
bu.e_7 30,70
bu.h_1 30,70
bu.h_2 40,60
bu.h_3 40,60
bu.h_5 30,70
bu.h_6 40,60
ts_2 low, datapoints_30, linear, negative
ts_7 high, datapoints_30, linear, negative