Skip to content

Commit 6c86653

Browse files
committed
Update v02_Evaluation_using_Trio.Rmd
1 parent bdc5bce commit 6c86653

1 file changed

Lines changed: 73 additions & 53 deletions

File tree

vignettes/v02_Evaluation_using_Trio.Rmd

Lines changed: 73 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -77,74 +77,94 @@ Using the `CVindices`, user can subset to training and test data. As an example,
7777

7878
Once predictions are obtained for the test set, we pass them to Trio using the same evidence name as stored in the Trio object (i.e., `Diagnosis`). Specifically, we call `trio$evaluate(list(lasso = list(Diagnosis = pred)))`, which instructs `evaluate()` to compare `pred` against the reference labels `Diagnosis` stored in `trio`, and then compute the specified metric (Balanced Accuracy).
7979

80+
We first construct an explicit cross-validation plan that records fold and repeat identifiers, and then iterate over this plan to evaluate each split.
81+
8082
```{r}
8183
set.seed(1234)
8284
83-
# Loop through 2 folds x 5 repeats = 10 runs
84-
result <- do.call(
85-
rbind,
86-
mapply(
87-
function(trainIDs, crossValID) {
88-
x_train <- x[trainIDs, ]
89-
x_test <- x[-trainIDs, ]
90-
y_train <- y[trainIDs]
91-
y_test <- y[-trainIDs]
92-
93-
# Find the best lambda for LASSO regression
94-
cv_lasso <- cv.glmnet(
95-
x = as.matrix(x_train),
96-
y = y_train,
97-
alpha = 1,
98-
family = "binomial"
99-
)
100-
lam <- cv_lasso$lambda.1se
101-
102-
# Fit a model with the best lambda on training data
103-
fit <- glmnet(
104-
x = as.matrix(x_train),
105-
y = y_train,
106-
alpha = 1,
107-
lambda = lam,
108-
family = "binomial"
109-
)
110-
111-
# Evaluate the model on test data
112-
pred <- predict(
113-
fit,
114-
newx = as.matrix(x_test),
115-
s = lam,
116-
type = "class"
117-
)
118-
pred <- setNames(as.factor(as.vector(pred)), rownames(x_test))
119-
120-
# Get the chosen evaluation metric from the Trio
121-
eval_res <- trio$evaluate(list(lasso = list(Diagnosis = pred)))
122-
123-
# Keep track of the repeat and fold information
124-
eval_res$track <- crossValID
125-
eval_res
126-
},
127-
CVindices,
128-
names(CVindices),
129-
SIMPLIFY = FALSE
85+
# Build an explicit cross-validation plan: one row per train/test split,
# carrying the fold and repeat identifiers parsed from the names of
# `CVindices` (assumes names look like "<fold>.<repeat>" -- TODO confirm
# against the output of the CV-index generator used earlier).
# Parse the names once instead of once per column.
split_names <- strsplit(names(CVindices), ".", fixed = TRUE)

cv_plan <- tibble::tibble(
  trainIDs = CVindices,
  fold = vapply(split_names, `[`, character(1), 1),
  repeat_id = vapply(split_names, `[`, character(1), 2)
)
132100
101+
102+
# Run a single cross-validation split: fit a LASSO logistic model on the
# training rows, predict class labels for the held-out rows, and score the
# predictions via the Trio object.
#
# trainIDs  - row indices of the training samples in `x` / `y`
# fold      - fold identifier, attached to the result for bookkeeping
# repeat_id - repeat identifier, attached to the result for bookkeeping
# trio      - Trio object; `evaluate()` compares predictions against the
#             reference labels it stores under the evidence name `Diagnosis`
# x, y      - full feature matrix and response vector
#
# Returns the evaluation result from `trio$evaluate()` with `fold` and
# `repeat_id` columns appended.
run_one_cv <- function(trainIDs, fold, repeat_id, trio, x, y) {
  x_train <- x[trainIDs, ]
  x_test <- x[-trainIDs, ]
  y_train <- y[trainIDs]
  # No `y_test` is needed here: trio$evaluate() scores the predictions
  # against the reference labels already stored in `trio`.

  # Choose the regularisation strength by internal cross-validation,
  # taking the more conservative 1-SE lambda.
  cv_lasso <- glmnet::cv.glmnet(
    x = as.matrix(x_train),
    y = y_train,
    alpha = 1,
    family = "binomial"
  )
  lam <- cv_lasso$lambda.1se

  # Refit on the training data at the selected lambda only.
  fit <- glmnet::glmnet(
    x = as.matrix(x_train),
    y = y_train,
    alpha = 1,
    lambda = lam,
    family = "binomial"
  )

  # Predict class labels for the held-out samples; keep the sample names
  # so trio can align predictions with its reference labels.
  pred <- predict(
    fit,
    newx = as.matrix(x_test),
    s = lam,
    type = "class"
  )
  pred <- setNames(as.factor(as.vector(pred)), rownames(x_test))

  eval_res <- trio$evaluate(list(lasso = list(Diagnosis = pred)))

  # Attach split metadata explicitly so downstream summaries can group by
  # fold and repeat without re-parsing names.
  eval_res$fold <- fold
  eval_res$repeat_id <- repeat_id

  eval_res
}
140+
141+
# Evaluate every train/test split in the plan, collecting one result
# table per split (an unnamed list, in plan order).
result_list <- lapply(seq_len(nrow(cv_plan)), function(row_idx) {
  run_one_cv(
    trainIDs = cv_plan$trainIDs[[row_idx]],
    fold = cv_plan$fold[[row_idx]],
    repeat_id = cv_plan$repeat_id[[row_idx]],
    trio = trio,
    x = x,
    y = y
  )
})
133153
```
134154

135155
After cross-validation, we can visualise performance by averaging the results across folds within each repeat.
136156

137157
```{r fig.cap = "Mean cross-validation accuracy across repeats."}
138-
result$fold <- unlist(lapply(strsplit(result$track, ".", fixed = TRUE), `[`, 1))
139-
result$repeats <- unlist(lapply(strsplit(result$track, ".", fixed = TRUE), `[`, 2))
158+
# Stack the per-split evaluation tables into one data frame.
result <- dplyr::bind_rows(result_list)

# Average the metric across folds within each repeat, keeping one row per
# dataset/method/evidence/metric/repeat combination.
result_summary <- dplyr::summarise(
  dplyr::group_by(result, datasetID, method, evidence, metric, repeat_id),
  result = mean(result),
  .groups = "drop"
)

# visualise the result
boxplot(
  result_summary$result,
  ylab = "Accuracy",
  main = "Cross-validation performance"
)

0 commit comments

Comments
 (0)