Code
library(tidyverse)
library(gt)
# Document-wide ggplot theme: classic look, larger base text, and a paper
# colour matching the "#eceadf" fill used in the gt tables.
theme_classic(base_size = 16, paper = "#eceadf") |>
  set_theme()
Search algorithm, LASSO regression, and {xgboost}
Jessica Helmer
April 2, 2026
# Inputs: item-level response data, the search-algorithm (v5) results, and the
# item importances from the xgboost and LASSO models.
v5_dat <- readRDS(here::here("Data", "v5_dat.rds"))
v5 <- readRDS(here::here("Data", "Search Algorithms", "v5test_dat.rds"))
xgb <- readRDS(here::here("Data", "xgboost", "xgb_item-importance.rds"))
lasso <- readRDS(here::here("Data", "lasso_regression.rds"))
# Not referenced directly in this section (NOTE(review): presumably consumed
# by `get_startup()` — confirm).
startup_costs <- readRDS(here::here("Data", "startup_costs.rds"))

# Output directory for the cached R^2-over-time results. `recursive` creates
# any missing parents; `showWarnings = FALSE` makes this a no-op if it exists.
dir.create(here::here("Data", "Method Comparison"),
           recursive = TRUE, showWarnings = FALSE)

# Method prefixes: `<method>_rank`, `<method>_ranks_w_timing` and
# `<method>_test` objects are all keyed by these names.
methods <- c("v5", "lasso", "xgb")
# Search-algorithm ranking: the `rep` column is used directly as the rank
# (NOTE(review): presumably the step at which the search picked the item, so
# 1 = best — confirm). Sorting first and then selecting gives the same rows.
v5_ranked <- v5 |>
  arrange(rep) |>
  select(item, v5_rank = rep)
# xgboost ranking. Higher importance should mean a lower (better) rank, so the
# ascending rank is reversed; the +1 makes ranks start at 1, like v5_rank.
xgb_ranked <- xgb |>
  mutate(xgb_rank = n() + 1 - rank(Importance)) |>
  arrange(xgb_rank) |>
  select(xgb_rank, item = Variable)
# LASSO ranking, built the same way as the xgboost ranking: larger `Importance`
# (NOTE(review): presumably coefficient magnitude — confirm) maps to a lower
# (better) rank, shifted by +1 so ranks start at 1.
lasso_ranked <- lasso |>
  mutate(lasso_rank = n() + 1 - rank(Importance)) |>
  arrange(lasso_rank) |>
  select(lasso_rank, item = Variable)
# One row per item with every method's rank (NA where a method has no rank for
# the item) plus the item's task and time from v5_dat. The final inner join
# keeps only items present in v5_dat.
ranks <- v5_ranked |>
  full_join(xgb_ranked, by = "item") |>
  full_join(lasso_ranked, by = "item") |>
  inner_join(distinct(v5_dat, task, item, time), by = "item")
Correlations between item ranks from each method.
Spearman Rank
# Spearman correlation between each pair of methods' item rankings.
# `ranks` comes from full joins, so an item one method did not rank carries an
# NA rank; pairwise-complete observations stop such gaps from turning the whole
# correlation into NA.
ranks |>
  select(`Search Algorithm` = v5_rank,
         xgboost = xgb_rank,
         LASSO = lasso_rank) |>
  cor(method = "spearman", use = "pairwise.complete.obs") |>
  as.data.frame() |>
  rownames_to_column() |>
  gt::gt() |>
  gt::fmt_number() |>
  gt::tab_style(
    style = list(
      gt::cell_fill(color = "#eceadf"),
      gt::cell_borders(sides = c("left", "right", "top", "bottom"),
                       color = scales::col_darker("#eceadf", 15))
    ),
    locations = list(gt::cells_body(), gt::cells_column_labels(),
                     gt::cells_title(), gt::cells_stub(),
                     gt::cells_stubhead())
  )
| | Search Algorithm | xgboost | LASSO |
|---|---|---|---|
| Search Algorithm | 1.00 | 0.23 | 0.49 |
| xgboost | 0.23 | 1.00 | 0.40 |
| LASSO | 0.49 | 0.40 | 1.00 |
Kendall’s Tau
# Kendall's tau between each pair of methods' item rankings.
# `ranks` comes from full joins, so an item one method did not rank carries an
# NA rank; pairwise-complete observations stop such gaps from turning the whole
# correlation into NA.
ranks |>
  select(`Search Algorithm` = v5_rank,
         xgboost = xgb_rank,
         LASSO = lasso_rank) |>
  cor(method = "kendall", use = "pairwise.complete.obs") |>
  as.data.frame() |>
  rownames_to_column() |>
  gt::gt() |>
  gt::fmt_number() |>
  gt::tab_style(
    style = list(
      gt::cell_fill(color = "#eceadf"),
      gt::cell_borders(sides = c("left", "right", "top", "bottom"),
                       color = scales::col_darker("#eceadf", 15))
    ),
    locations = list(gt::cells_body(), gt::cells_column_labels(),
                     gt::cells_title(), gt::cells_stub(),
                     gt::cells_stubhead())
  )
| | Search Algorithm | xgboost | LASSO |
|---|---|---|---|
| Search Algorithm | 1.00 | 0.14 | 0.34 |
| xgboost | 0.14 | 1.00 | 0.27 |
| LASSO | 0.34 | 0.27 | 1.00 |
# Search-algorithm rank vs xgboost rank, one text label per item, coloured by
# task (the item-name prefix before the first "_").
ranks |>
  mutate(task = str_split_i(item, "_", 1)) |>
  ggplot(aes(v5_rank, xgb_rank, label = item, colour = task)) +
  geom_text()
# Search-algorithm rank vs LASSO rank, one text label per item, coloured by
# task (the item-name prefix before the first "_").
ranks |>
  mutate(task = str_split_i(item, "_", 1)) |>
  ggplot(aes(v5_rank, lasso_rank, label = item, colour = task)) +
  geom_text()
# LASSO rank vs xgboost rank, one text label per item, coloured by task.
ranks |>
  mutate(task = str_split_i(item, "_", 1)) |>
  ggplot(aes(x = lasso_rank, y = xgb_rank, label = item, color = task)) +
  geom_text()

# ---- Cumulative administration time per method ----------------------------
# For each method, order items by that method's rank and accumulate the time
# needed to administer the top-k items. The first item administered from each
# task also carries that task's start-up cost from `get_startup()`, which is
# not defined in this section (NOTE(review): presumably a lookup into
# `startup_costs` by task name — confirm). `task` is re-derived from the item
# name (text before the first "_").
# Side effect: creates `v5_ranks_w_timing`, `xgb_ranks_w_timing` and
# `lasso_ranks_w_timing` in the global environment; later chunks use them.
walk(methods, \(method) {
  rank_col <- paste0(method, "_rank")
  ranks_w_timing <- ranks |>
    mutate(task = str_split_i(item, "_", 1)) |>
    arrange(.data[[rank_col]]) |>
    mutate(.by = task,
           time_w_startup = time + ifelse(item == first(item),
                                          get_startup(first(task)), 0)) |>
    mutate(cum_time = cumsum(time_w_startup))
  assign(paste0(method, "_ranks_w_timing"), ranks_w_timing,
         envir = globalenv())
})

# ---- R^2 as items are added in rank order ---------------------------------
# For one `*_ranks_w_timing` table (rows in rank order, with `item`, `task`,
# `cum_time`), fit `sscore ~ <one predictor per task>` using only the first i
# items, for i = 1, ..., n_items (capped at the number of ranked items). Each
# predictor is a subject's mean `score` over the included items of that task;
# subjects missing any included task are dropped.
# Returns one row per prefix length: `rep` (= i), `item` (the item added last),
# `test_r2` (NOTE: the in-sample R^2 of the lm fit, despite its name) and
# `cum_time` (cumulative time of the included items).
r2_over_ranked_items <- function(ranks_w_timing, n_items = 99) {
  # Only these response-data columns are needed; select them once, not per i.
  item_scores <- v5_dat |>
    select(subject_id, sscore, item, score)

  map(seq_len(min(n_items, nrow(ranks_w_timing))), \(i) {
    included_items <- ranks_w_timing |>
      select(item, task, cum_time) |>
      head(i)
    # cum_time is a running total, so the last row is the item just added.
    # slice_tail() always yields exactly one row; filtering on max(cum_time)
    # would return several rows if trailing times tie.
    last_item <- slice_tail(included_items, n = 1)

    task_means <- item_scores |>
      inner_join(included_items, by = "item") |>
      summarize(.by = c(subject_id, task),
                sscore = first(sscore),
                score = mean(score)) |>
      pivot_wider(names_from = task, values_from = score) |>
      drop_na()

    # One predictor per task column produced by the pivot.
    predictors <- setdiff(names(task_means), c("subject_id", "sscore"))
    model_formula <- reformulate(predictors, response = "sscore")

    tibble(rep = i,
           item = last_item$item,
           test_r2 = summary(lm(model_formula, data = task_means))$r.squared,
           cum_time = last_item$cum_time)
  }) |>
    list_rbind()
}

v5_r2_w_time <- r2_over_ranked_items(v5_ranks_w_timing)
saveRDS(v5_r2_w_time,
        here::here("Data", "Method Comparison", "v5_r2_w_time.rds"))
xgb_r2_w_time <- r2_over_ranked_items(xgb_ranks_w_timing)
saveRDS(xgb_r2_w_time,
        here::here("Data", "Method Comparison", "xgb_r2_w_time.rds"))
lasso_r2_w_time <- r2_over_ranked_items(lasso_ranks_w_timing)
saveRDS(lasso_r2_w_time,
        here::here("Data", "Method Comparison", "lasso_r2_w_time.rds"))

# Reload the cached results from disk.
v5_r2_w_time <- readRDS(here::here("Data", "Method Comparison", "v5_r2_w_time.rds"))
xgb_r2_w_time <- readRDS(here::here("Data", "Method Comparison", "xgb_r2_w_time.rds"))
lasso_r2_w_time <- readRDS(here::here("Data", "Method Comparison", "lasso_r2_w_time.rds"))

# R^2 against cumulative time for each method; points coloured by the task of
# the item added at that step.
walk2(list(v5_r2_w_time, xgb_r2_w_time, lasso_r2_w_time),
      list("Search Algorithm", "xgboost", "LASSO"),
      \(d, name) {
        r2_plot <- d |>
          mutate(task = str_split_i(item, "_", 1)) |>
          ggplot(aes(x = cum_time, y = test_r2)) +
          geom_line() +
          geom_point(aes(color = task)) +
          coord_cartesian(ylim = c(.1, .5)) +
          labs(title = name, y = "R<sup>2</sup>", x = "Time") +
          theme(axis.title.y = ggtext::element_markdown())
        print(r2_plot)
      })

# ---- Items reachable within a 15-unit time budget -------------------------
# Per-task count of items each method administers with cum_time <= 15; one
# column per method (v5, xgb, lasso). Tasks a method never reaches are NA.
budget_task_counts <- list(v5 = v5_ranks_w_timing,
                           xgb = xgb_ranks_w_timing,
                           lasso = lasso_ranks_w_timing) |>
  imap(\(d, method) d |>
         filter(cum_time <= 15) |>
         count(task, name = method)) |>
  reduce(full_join, by = "task")

budget_task_counts |>
  arrange(task) |>
  gt(rowname_col = "task") |>
  sub_missing(missing_text = "") |>
  tab_style(style = cell_text(weight = "bold"),
            locations = cells_column_labels())

# Same counts collapsed to the task family (text before the first "."), so for
# example bu.e and bu.h both count toward bu.
budget_task_counts |>
  mutate(task = str_split_i(task, "\\.", 1)) |>
  summarize(.by = task,
            across(everything(), \(n) sum(n, na.rm = TRUE))) |>
  arrange(task) |>
  gt(rowname_col = "task") |>
  sub_missing(missing_text = "") |>
  sub_zero(zero_text = " ") |>
  tab_style(style = cell_text(weight = "bold"),
            locations = cells_column_labels())
| | v5 | xgb | lasso |
|---|---|---|---|
| admc.dr | 7 | 3 | 3 |
| bu.e | 2 | 5 | 3 |
| bu.h | | 5 | 4 |
| dn.c | 10 | | 3 |
| ns | 1 | | |
| ts | 3 | 2 | 3 |
| | v5 | xgb | lasso |
|---|---|---|---|
| admc | 7 | 3 | 3 |
| bu | 2 | 10 | 7 |
| dn | 10 | | 3 |
| ns | 1 | | |
| ts | 3 | 2 | 3 |
# For each method, the items it administers within the 15-unit cumulative time
# budget (cum_time <= 15).
# Side effect: creates `v5_test`, `xgb_test` and `lasso_test` (vectors of item
# names) in the global environment.
walk(methods, \(method) {
  # get() looks the timing table up by name, replacing the sym()/eval() hack.
  items_within_budget <- get(paste0(method, "_ranks_w_timing")) |>
    filter(cum_time <= 15) |>
    pull(item)
  # Spell out `envir`; the original relied on partial matching of `env`.
  assign(paste0(method, "_test"), items_within_budget, envir = globalenv())
})
# How many of the three methods (1-3) select each item within the time budget,
# using `v5_test`, `xgb_test` and `lasso_test`. Items no method selects are
# dropped. Vectorized: summing the three logical %in% tests avoids building a
# one-row tibble per item.
tibble(item = pull(ranks, item)) |>
  mutate(count = (item %in% v5_test) +
           (item %in% xgb_test) +
           (item %in% lasso_test)) |>
  filter_out(count == 0) |>
  arrange(desc(count)) |>
  mutate(task = str_split_i(item, "_", 1),
         .before = item) |>
  gt(rowname_col = c("task", "item")) |>
  tab_style(style = cell_text(weight = "bold"),
            locations = cells_column_labels())
| | | count |
|---|---|---|
| admc.dr | admc.dr_2 | 3 |
| | admc.dr_9 | 3 |
| ts | ts_7 | 3 |
| bu.e | bu.e_0 | 3 |
| | bu.e_2 | 3 |
| admc.dr | admc.dr_4 | 2 |
| dn.c | dn.c_1 | 2 |
| | dn.c_16 | 2 |
| | dn.c_19 | 2 |
| ts | ts_9 | 2 |
| | ts_2 | 2 |
| admc.dr | admc.dr_8 | 2 |
| bu.h | bu.h_6 | 2 |
| | bu.h_2 | 2 |
| bu.e | bu.e_6 | 2 |
| dn.c | dn.c_13 | 1 |
| | dn.c_10 | 1 |
| | dn.c_24 | 1 |
| | dn.c_12 | 1 |
| | dn.c_18 | 1 |
| | dn.c_14 | 1 |
| admc.dr | admc.dr_1 | 1 |
| | admc.dr_6 | 1 |
| | admc.dr_3 | 1 |
| ns | ns_5 | 1 |
| dn.c | dn.c_23 | 1 |
| bu.e | bu.e_4 | 1 |
| bu.h | bu.h_5 | 1 |
| | bu.h_4 | 1 |
| | bu.h_1 | 1 |
| | bu.h_3 | 1 |
| | bu.h_0 | 1 |
| bu.e | bu.e_7 | 1 |
| ts | ts_8 | 1 |
# Descriptive properties of every item: one row per item, with its properties
# nested in a `properties` tibble. Used to describe the items each method
# selects. The helpers live inside local() so only `item_properties` is created.
item_properties <- local({
  # Denominator-neglect items for one task version. `dat_file` holds stimulus
  # properties; `top_file` maps each rank to its conflict and harmony items.
  # Items are renamed "<prefix>_<rank>". Each nested row keeps an `item_ct`
  # label ("<prefix>_<rank>_<choice_type>") so both versions of a rank stay
  # distinguishable.
  dn_properties <- function(dat_file, top_file, prefix) {
    top_items <- readRDS(here::here("Data", "Denominator Neglect", top_file)) |>
      select(rank, conf_item, harm_item) |>
      pivot_longer(c(conf_item, harm_item),
                   names_to = "choice_type", values_to = "item") |>
      select(-choice_type)

    readRDS(here::here("Data", "Denominator Neglect", dat_file)) |>
      select(item, choice_type, proportion_difference,
             small_lottery_gold_prop) |>
      unique() |>
      left_join(top_items, by = "item") |>
      mutate(item_ct = paste0(prefix, "_", rank, "_", choice_type),
             item = paste0(prefix, "_", rank),
             .keep = "unused") |>
      nest(.by = item, properties = !item)
  }

  # Bayesian-update items for one difficulty file, named "<prefix>_<trial>".
  bu_properties <- function(dat_file, prefix) {
    readRDS(here::here("Data", "Bayesian Update", dat_file)) |>
      mutate(item = paste0(prefix, "_", unique_trial)) |>
      select(item, ball_split) |>
      unique() |>
      nest(properties = !item)
  }

  bind_rows(
    dn_properties("dn_dat_c.rds", "top_items.rds", "dn.c"),
    dn_properties("dn_dat_s.rds", "top_items_s.rds", "dn.s"),
    bu_properties("bu_dat_e.rds", "bu.e"),
    bu_properties("bu_dat_h.rds", "bu.h"),
    readRDS(here::here("Data", "ADMC Decision Rules", "admc.dr_dat.rds")) |>
      mutate(item = paste0("admc.dr_", parse_number(admc_id))) |>
      select(item, n_correct_options) |>
      unique() |>
      nest(properties = !item),
    readRDS(here::here("Data", "Time Series", "ts_dat.rds")) |>
      mutate(item = paste0("ts_", ts_id)) |>
      select(item, noise = noise_condition, datapoints, func, direction) |>
      unique() |>
      nest(properties = !item),
    # Number-series items: the property is the series text shown to subjects.
    readRDS(here::here("Data", "Number Series", "ns_dat.rds")) |>
      mutate(number_series = recode_values(parse_number(ns_id),
                                           1 ~ "10, 4, ____, -8, -14, -20",
                                           2 ~ "3, 6, 10, 15, 21, ____",
                                           3 ~ "121, 100, 81, ____, 49",
                                           4 ~ "3, 10, 16, 23, ____, 36",
                                           5 ~ "3/21, ____, 13/11, 18/6, 23/1, 28/-4",
                                           6 ~ "200, 198, 192, 174, ____",
                                           7 ~ "3, 2, 10, 4, 19, 6, 30, 8, ____",
                                           8 ~ "10000, 9000, ____, 8890, 8889",
                                           9 ~ "3/4, 4/6, 6/8, 8/12, ____")) |>
      select(item = ns_id, number_series) |>
      unique() |>
      nest(properties = !item)
  )
})

# Loop version of the per-method property tables (disabled):
# methods |>
#   map(\(method) tibble(item = sym(paste0(method, "_test")) |> eval()) |>
#         left_join(item_properties,
#                   by = "item") |>
#         arrange(item) |>
#         pmap(\(item, properties) tibble(item, properties) |>
#                unnest(cols = c()) |>
#                mutate(item = ifelse(str_detect(item, "dn"), item_ct, item),
#                       .keep = "unused") |>
#                nest(properties = !item)) |>
#         list_rbind() |>
#         gt() |>
#         tab_header(method))
# Properties of the items the search algorithm selects within the time budget.
# Denominator-neglect items are relabelled with their `item_ct` name (rank
# plus choice type). Other tasks have no `item_ct` column; that is safe only
# because ifelse() evaluates its `yes` argument only when some test is TRUE.
tibble(item = v5_test) |>
  left_join(item_properties, by = "item") |>
  arrange(item) |>
  pmap(\(item, properties) {
    expanded <- tibble(item, properties) |>
      unnest(cols = c())
    expanded |>
      mutate(item = ifelse(str_detect(item, "dn"), item_ct, item),
             .keep = "unused") |>
      nest(properties = !item)
  }) |>
  list_rbind() |>
  gt() |>
  tab_header("Search Algorithm")
# Properties of the items LASSO selects within the time budget.
# Denominator-neglect items are relabelled with their `item_ct` name (rank
# plus choice type). Other tasks have no `item_ct` column; that is safe only
# because ifelse() evaluates its `yes` argument only when some test is TRUE.
tibble(item = lasso_test) |>
  left_join(item_properties, by = "item") |>
  arrange(item) |>
  pmap(\(item, properties) {
    expanded <- tibble(item, properties) |>
      unnest(cols = c())
    expanded |>
      mutate(item = ifelse(str_detect(item, "dn"), item_ct, item),
             .keep = "unused") |>
      nest(properties = !item)
  }) |>
  list_rbind() |>
  gt() |>
  tab_header("LASSO")
# Properties of the items xgboost selects within the time budget.
# Denominator-neglect items are relabelled with their `item_ct` name (rank
# plus choice type). Other tasks have no `item_ct` column; that is safe only
# because ifelse() evaluates its `yes` argument only when some test is TRUE.
tibble(item = xgb_test) |>
  left_join(item_properties, by = "item") |>
  arrange(item) |>
  pmap(\(item, properties) {
    expanded <- tibble(item, properties) |>
      unnest(cols = c())
    expanded |>
      mutate(item = ifelse(str_detect(item, "dn"), item_ct, item),
             .keep = "unused") |>
      nest(properties = !item)
  }) |>
  list_rbind() |>
  gt() |>
  tab_header("xgboost")
| Search Algorithm | |
| item | properties |
|---|---|
| admc.dr_1 | 1 |
| admc.dr_2 | 1 |
| admc.dr_3 | 1 |
| admc.dr_4 | 1 |
| admc.dr_6 | 1 |
| admc.dr_8 | 2 |
| admc.dr_9 | 3 |
| bu.e_0 | 40,60 |
| bu.e_2 | 30,70 |
| dn.c_1_conflict | 0.08, 0.1 |
| dn.c_1_harmony | 0.07, 0.3 |
| dn.c_10_conflict | 0.05, 0.2 |
| dn.c_10_harmony | 0.06, 0.1 |
| dn.c_12_conflict | 0.01, 0.1 |
| dn.c_12_harmony | 0.04, 0.2 |
| dn.c_13_conflict | 0.04, 0.2 |
| dn.c_13_harmony | 0.05, 0.1 |
| dn.c_14_conflict | 0.02, 0.3 |
| dn.c_14_harmony | 0.01, 0.1 |
| dn.c_16_conflict | 0.01, 0.2 |
| dn.c_16_harmony | 0.08, 0.1 |
| dn.c_18_conflict | 0.05, 0.1 |
| dn.c_18_harmony | 0.05, 0.3 |
| dn.c_19_conflict | 0.02, 0.2 |
| dn.c_19_harmony | 0.08, 0.3 |
| dn.c_23_conflict | 0.07, 0.2 |
| dn.c_23_harmony | 0.02, 0.3 |
| dn.c_24_conflict | 0.08, 0.2 |
| dn.c_24_harmony | 0.01, 0.3 |
| ns_5 | 3/21, ____, 13/11, 18/6, 23/1, 28/-4 |
| ts_2 | low, datapoints_30, linear, negative |
| ts_7 | high, datapoints_30, linear, negative |
| ts_9 | low, datapoints_10, linear, positive |
| LASSO | |
| item | properties |
|---|---|
| admc.dr_2 | 1 |
| admc.dr_8 | 2 |
| admc.dr_9 | 3 |
| bu.e_0 | 40,60 |
| bu.e_2 | 30,70 |
| bu.e_6 | 40,60 |
| bu.h_0 | 40,60 |
| bu.h_2 | 40,60 |
| bu.h_4 | 30,70 |
| bu.h_6 | 40,60 |
| dn.c_1_conflict | 0.08, 0.1 |
| dn.c_1_harmony | 0.07, 0.3 |
| dn.c_16_conflict | 0.01, 0.2 |
| dn.c_16_harmony | 0.08, 0.1 |
| dn.c_19_conflict | 0.02, 0.2 |
| dn.c_19_harmony | 0.08, 0.3 |
| ts_7 | high, datapoints_30, linear, negative |
| ts_8 | high, datapoints_10, linear, negative |
| ts_9 | low, datapoints_10, linear, positive |
| xgboost | |
| item | properties |
|---|---|
| admc.dr_2 | 1 |
| admc.dr_4 | 1 |
| admc.dr_9 | 3 |
| bu.e_0 | 40,60 |
| bu.e_2 | 30,70 |
| bu.e_4 | 30,70 |
| bu.e_6 | 40,60 |
| bu.e_7 | 30,70 |
| bu.h_1 | 30,70 |
| bu.h_2 | 40,60 |
| bu.h_3 | 40,60 |
| bu.h_5 | 30,70 |
| bu.h_6 | 40,60 |
| ts_2 | low, datapoints_30, linear, negative |
| ts_7 | high, datapoints_30, linear, negative |