지난 포스팅에서 XAI와 DALEX 패키지에 대해 간단히 소개하고, Regression 문제에서 DALEX 패키지로 변수 중요도 뽑는 방법과 그 원리를 알아보았다.
2021.11.15 - [AI/잡지식] - [R] XIA(eXplainable AI) 패키지 중 DALEX로 변수 중요도 뽑기
이번에는 Classification 문제에서 변수중요도 계산하는 방법을 알아보고자 한다.
이전 포스팅에서 언급했지만, Regression과 Classification을 따로 진행하는 이유는, Regression 문제와 Classification 문제는 사용하는 loss function이 다른데, 그 loss에 따라 explaner 생성 시 y에 들어가는 target의 타입과 prediction 형태가 조금씩 달라야 하기 때문이다.
DALEX :: Variable Importance Measures (Classification) in R
Titanic 데이터를 사용하여 진행하겠다.
라이브러리를 불러오고, DALEX 패키지에 있는 titanic 데이터와, 이미 구현된 model을 불러온다.
library(DALEX)
library(gbm)
data <- DALEX::titanic_imputed
model <- apartments_rf <- archivist::aread("pbiecek/models/08544")
타이타닉 데이터는 2207개의 row, 8개의 열로 이루어져있으며,
성별, 나이, 티켓 등급, 요금 등으로 살아남았는지(survived)를 예측하는 문제이다.
이 데이터는 kaggle에서 다운받아 올 수도 있으나, pbiecek에 구현된 모델이 DALEX 데이터를 기반하여 만든 모델이라 DALEX에서 불러와 주었다. (데이터 자체에 큰 차이는 없다. 컬럼 명이 다 소문자로 바뀌었고, 1,2,3으로 이루어진 PClass가 1st, 2nd, 3rd로 바뀌었고.. 등등)
model은 gbm으로 구현되었다.
이 모델 출처는 이곳이며, 이전에 Regression 포스팅에서 진행했던 apartment 데이터 외에도 HR 데이터, Titanic 데이터 등에 대한 모델들과 explainer, 데이터 등을 가지고 있다.
archivist::aread() 로 불러올 수 있고, 데이터셋 이름의 앞 5글자 정도만 따오면 불러올 수 있다.
참고로 모델 성능은 다음과 같다.
(confuncion matrix와 f1 score의 positive는 1이다.)
이제 loss_function에 따라 explain()과 model_parts()로 변수 중요도 뽑는 과정을 설명하겠다.
DALEX에서 제공하는 loss function은 다음과 같다.
이 중 classification에서는 loss_cross_entropy, loss_accuracy, loss_one_minus_auc를 사용하는데,
loss_cross_entropy는 multi classification 문제에서 사용한다.
먼저, loss_cross_entropy 를 loss function으로 변수 중요도를 계산하고 싶을 때는 다음과 같이 한다.
pred <- function(model, newdata) {
predicted <- data.frame("1" = predict(model, newdata, type = "response"),
"0" = 1-predict(model, newdata, type = "response"))
names(predicted) <- c("1","0")
return(predicted)
}
explainer <- explain(
model = model,
data = data %>% dplyr::select(-survived),
y = as.factor(data$survived),
predict_function = pred,
)
set.seed(1)
vip_ce <- model_parts(explainer = explainer,
loss_function = loss_cross_entropy)
정의한 pred function과 explain()안에 y = as.factor() 를 주요하게 보아야 한다.
cross_entropy loss는 multi classification 문제에 적용하는 loss이므로 각 label별로 확률값을 뽑아주어야 한다.
타이타닉은 survived가 0과 1 뿐이라 따로 두 컬럼으로 만들어주었지만, iris의 경우 setosa, versicolor, virginica 에 따라 각각 확률값을 뽑아주어야 한다. 즉, target의 label 개수 만큼 컬럼이 나오고, 각각에 확률값 형태로 predict function이 구현되어야 한다.
그리고 expain() 안에 y에는 factor 타입으로 넣어주어야 한다.
model_parts()에 loss_function은 loss_cross_entropy로 넣어준다.
결과는 다음과 같다.
gender 변수만 random하게 순서를 섞어 prediction 했을 때 cross entropy가 가장 많이 떨어진다고 해석할 수 있다.
즉, gender 변수가 해당 모델에서는 가장 중요한 변수라는 것을 의미한다.
DALEX에서 변수 중요도를 계산하는 원리는 이전 포스팅에 설명해 놓았다.
다음으로 accuracy를 loss로 설정할 때는 다음과 같이 한다.
pred <- function(model, newdata) {
predicted <- predict(model,newdata = newdata, type = "response")
predicted <- as.factor(ifelse(predicted < 0.5,0,1))
return(predicted)
}
explainer <- explain(
model = model,
data = data %>% dplyr::select(-survived),
y = as.factor(data$survived),
predict_function = pred,
)
set.seed(1)
vip_acc <- model_parts(explainer = explainer,
loss_function = loss_accuracy)
이번에는 predict function을 0과 1로 이루어진 factor 형으로 떨어지도록 구현해야한다.
explain() 에 y값 역시 factor형으로 들어가야 한다.
model_parts()에 loss_function은 loss_accuracy로 넣어준다.
결과는 다음과 같다.
역시 gender 변수의 중요도가 가장 높게 나왔다.
마지막으로 loss_one_minus_auc는 다음과 같다.
pred <- function(model, newdata) {
predicted <- predict(model,newdata = newdata, type = "response")
return(predicted)
}
explainer <- explain(
model = model,
data = data %>% dplyr::select(-survived),
y = data$survived,
predict_function = pred,
)
set.seed(1)
vip_auc <- model_parts(explainer = explainer,
loss_function = loss_one_minus_auc)
이번에 prediction function은 1일 확률값으로 뽑아준다.
explain()에 y는 numeric 타입으로 넣어주어야 한다.
model_parts()에 loss_function은 loss_one_minus_auc로 넣어준다.
결과는 다음과 같다.
AUC는 ROC커브 밑의 넓이를 의미하고, 이는 클 수록 좋은 모델이므로, 1-AUC는 작을수록 좋은 모델이다.
즉, gender 변수만 random하게 순서를 섞은 데이터로 prediction 했을 때 1-AUC가 가장 커졌으므로, 변수 중요도가 가장 높다 할 수 있다.
loss function에 따라 변수 중요도를 해석하는 것이 어려울 수 있는데, plot을 그렸을 때 bar의 길이가 길 수록 중요한 변수라고 해석하면 된다.
'AI > 잡지식' 카테고리의 다른 글
AI(Artificial Intelligence) VS ML(Machine Learning) VS DL(Deep Learning) (1) | 2024.09.02 |
---|---|
[R] XIA(eXplainable AI) 패키지 중 DALEX로 변수 중요도 뽑기 (0) | 2021.11.15 |
넬슨 법칙이란? What is the Nelson Rules? (0) | 2020.02.04 |
KL divergence(Kullback–Leibler) (2) | 2018.09.21 |
curse of dimensionality - 차원의 저주 (2) | 2018.09.20 |