Plot Shapley values
Since R2021a
Description
plot(explainer) creates a horizontal bar graph of the Shapley values of the shapley object explainer. These values are stored in the object's ShapleyValues property. Each bar shows the Shapley value of each feature in the blackbox model (explainer.BlackboxModel) for the query point (explainer.QueryPoint).
plot(explainer,Name,Value) specifies additional options using one or more name-value arguments. For example, specify 'NumImportantPredictors',5 to plot the Shapley values of the five features with the highest absolute Shapley values.
b = plot(___) returns a bar graph object b using any of the input argument combinations in the previous syntaxes. Use b to query or modify properties of the bar graph after it is created.
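As a minimal sketch of working with the returned object, the following code trains a small tree classifier on the fisheriris sample data, plots the Shapley values for one query point, and then changes the appearance of the bars through b. The data set, model type, and color values are illustrative assumptions and are not part of the examples on this page. FaceColor and LineWidth are standard Bar properties.

```matlab
% Illustrative setup (assumption): a small classifier on the fisheriris data.
load fisheriris                      % meas (150x4 predictors), species (labels)
rng('default')                       % For reproducibility
mdl = fitctree(meas,species);        % The tree model stores its training data

% Explain the prediction for the first observation.
explainer = shapley(mdl,'QueryPoint',meas(1,:));

% Plot the Shapley values and keep the bar graph object.
b = plot(explainer);

% Query and then modify properties of the first Bar object.
originalColor = b(1).FaceColor;
b(1).FaceColor = [0.2 0.6 0.5];
b(1).LineWidth = 1.5;
```

Indexing with b(1) keeps the sketch valid even if the function returns more than one bar object.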
Examples
Plot Shapley Values for All Classes
Train a classification model and create a shapley object. Then plot the Shapley values by using the object function plot.
Load the CreditRating_Historical data set. The data set contains customer IDs and their financial ratios, industry labels, and credit ratings.
tbl = readtable('CreditRating_Historical.dat');
Display the first three rows of the table.
head(tbl,3)
     ID      WC_TA    RE_TA    EBIT_TA    MVE_BVTD    S_TA     Industry    Rating
    _____    _____    _____    _______    ________    _____    ________    ______
    62394    0.013    0.104     0.036      0.447      0.142       3        {'BB'}
    48608    0.232    0.335     0.062      1.969      0.281       8        {'A' }
    42444    0.311    0.367     0.074      1.935      0.366       1        {'A' }
Train a blackbox model of credit ratings by using the fitcecoc function. Use the variables from the second through seventh columns in tbl as the predictor variables. A recommended practice is to specify the class names to set the order of the classes.
blackbox = fitcecoc(tbl,'Rating', ...
    'PredictorNames',tbl.Properties.VariableNames(2:7), ...
    'CategoricalPredictors','Industry', ...
    'ClassNames',{'AAA' 'AA' 'A' 'BBB' 'BB' 'B' 'CCC'});
Create a shapley object that explains the prediction for the last observation. For faster computation, subsample 25% of the observations from tbl with stratification and use the samples to compute the Shapley values.
queryPoint = tbl(end,:)
queryPoint=1×8 table
     ID      WC_TA    RE_TA    EBIT_TA    MVE_BVTD    S_TA    Industry    Rating
    _____    _____    _____    _______    ________    ____    ________    ______
    73104    0.239    0.463     0.065      2.924      0.34       2        {'AA'}
rng('default') % For reproducibility
c = cvpartition(tbl.Rating,'Holdout',0.25);
tbl_s = tbl(test(c),:);
explainer = shapley(blackbox,tbl_s,'QueryPoint',queryPoint);
For a classification model, shapley computes Shapley values using the predicted class score for each class. Display the values in the ShapleyValues property.
explainer.ShapleyValues
ans=6×8 table
    Predictor   AAA          AA           A            BBB          BB           B            CCC
    __________  _________    __________   ___________  __________   ___________  __________   __________
    "WC_TA"     0.051045     0.022644     0.0096138    0.0015954    -0.027857    -0.04134     -0.039476
    "RE_TA"     0.16729      0.09479      0.05308      -0.011178    -0.087689    -0.20847     -0.29204
    "EBIT_TA"   0.0012015    0.00053338   0.00043344   0.00012321   -0.00066994  -0.0013388   -0.0011793
    "MVE_BVTD"  1.3377       1.338        0.67839      -0.027654    -0.55142     -0.75327     -0.59578
    "S_TA"      -0.012484    -0.009098    -0.00074119  -0.0035582   -7.3462e-05  0.0014495    -0.0020609
    "Industry"  -0.099117    -0.046867    0.0031376    0.080071     0.089726     0.099699     0.15691
The ShapleyValues property contains the Shapley values of all features for each class.
Plot the Shapley values for the predicted class by using the plot function.
plot(explainer)
The horizontal bar graph shows the Shapley values for all variables, sorted by their absolute values. Each Shapley value explains the deviation of the score for the query point from the average score of the predicted class, due to the corresponding variable.
Plot the Shapley values for all classes by specifying all class names in explainer.BlackboxModel.
plot(explainer,'ClassNames',explainer.BlackboxModel.ClassNames)
Specify Number of Important Predictors to Plot
Train a regression model and create a shapley object. Use the object function fit to compute the Shapley values for the specified query point. Then plot the Shapley values of the predictors by using the object function plot. Specify the number of important predictors to plot when you call the plot function.
Load the carbig data set, which contains measurements of cars made in the 1970s and early 1980s.
load carbig
Create a table containing the predictor variables Acceleration, Cylinders, and so on, as well as the response variable MPG.
tbl = table(Acceleration,Cylinders,Displacement,Horsepower,Model_Year,Weight,MPG);
Removing missing values in a training set can help reduce memory consumption and speed up training for the fitrkernel function. Remove missing values in tbl.
tbl = rmmissing(tbl);
Train a blackbox model of MPG by using the fitrkernel function.
rng('default') % For reproducibility
mdl = fitrkernel(tbl,'MPG','CategoricalPredictors',[2 5]);
Create a shapley object. Specify the data set tbl, because mdl does not contain training data.
explainer = shapley(mdl,tbl)
explainer =
  shapley with properties:
            BlackboxModel: [1x1 RegressionKernel]
               QueryPoint: []
           BlackboxFitted: []
            ShapleyValues: []
               NumSubsets: 64
                        X: [392x7 table]
    CategoricalPredictors: [2 5]
                   Method: 'interventional-kernel'
                Intercept: 22.6202
explainer stores the training data tbl in the X property.
Compute the Shapley values of all predictor variables for the first observation in tbl.
queryPoint = tbl(1,:)
queryPoint=1×7 table
    Acceleration    Cylinders    Displacement    Horsepower    Model_Year    Weight    MPG
    ____________    _________    ____________    __________    __________    ______    ___
         12             8            307            130            70         3504     18
explainer = fit(explainer,queryPoint);
For a regression model, shapley computes Shapley values using the predicted response, and stores them in the ShapleyValues property. Display the values in the ShapleyValues property.
explainer.ShapleyValues
ans=6×2 table
      Predictor       ShapleyValue
    ______________    ____________
    "Acceleration"    -0.1561
    "Cylinders"       -0.18306
    "Displacement"    -0.34203
    "Horsepower"      -0.27291
    "Model_Year"      -0.2926
    "Weight"          -0.32402
Plot the Shapley values for the query point by using the plot function. Specify 'NumImportantPredictors',5 to plot only the five most important predictors for the predicted response.
plot(explainer,'NumImportantPredictors',5)
The horizontal bar graph shows the Shapley values for the five most important predictors, sorted by their absolute values. Each Shapley value explains the deviation of the prediction for the query point from the average, due to the corresponding variable.
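To check which predictors the plot keeps, you can rank the values by absolute magnitude yourself. The sketch below rebuilds the ShapleyValues table from the output displayed in this example instead of recomputing it, so it runs without a trained model. The variable names sv, idx, and top5 are illustrative.

```matlab
% Shapley values copied from the ShapleyValues output shown in this example.
Predictor = ["Acceleration";"Cylinders";"Displacement"; ...
    "Horsepower";"Model_Year";"Weight"];
ShapleyValue = [-0.1561;-0.18306;-0.34203;-0.27291;-0.2926;-0.32402];
sv = table(Predictor,ShapleyValue);

% Rank the predictors by absolute Shapley value, largest first.
[~,idx] = sort(abs(sv.ShapleyValue),'descend');

% Keep the five highest-ranked predictors, mirroring 'NumImportantPredictors',5.
top5 = sv(idx(1:5),:)
```

With these values, Acceleration has the smallest absolute Shapley value (0.1561), so it is the one predictor left out of the five-bar plot.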
Input Arguments
explainer - Object explaining blackbox model
shapley object
Object explaining the blackbox model, specified as a shapley object.
Name-Value Arguments
Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.
Before R2021a, use commas to separate each name and value, and enclose Name in quotes.
Example: plot(explainer,'NumImportantPredictors',5,'ClassNames',c) creates a bar graph containing the Shapley values of the five most important predictors for the class c.
NumImportantPredictors - Number of important predictors to plot
min(M,10) where M is the number of predictors (default) | positive integer
Number of important predictors to plot, specified as a positive integer. The plot function plots the Shapley values of the specified number of predictors with the highest absolute Shapley values.
Example: 'NumImportantPredictors',5 specifies to plot the five most important predictors. The plot function determines the order of importance by using the absolute Shapley values.
Data Types: single | double
ClassNames - Class labels to plot
explainer.BlackboxFitted (default) | categorical array | character array | logical vector | numeric vector | cell array of character vectors
Class labels to plot, specified as a categorical or character array, logical or numeric vector, or cell array of character vectors. The values and data types in the 'ClassNames' value must match those of the class names in the ClassNames property of the machine learning model in explainer (explainer.BlackboxModel.ClassNames).
You can specify one or more labels. If you specify multiple class labels, the function plots multiple bars for each feature with different colors.
The default value is the predicted class for the query point (the BlackboxFitted property of explainer).
This argument is valid only when the machine learning model (BlackboxModel) in explainer is a classification model.
Example: 'ClassNames',{'red','blue'}
Example: 'ClassNames',explainer.BlackboxModel.ClassNames specifies 'ClassNames' as all classes in BlackboxModel.
Data Types: single | double | logical | char | cell | categorical
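The following sketch contrasts the default class selection with explicit 'ClassNames' values. It uses the fisheriris sample data and a tree classifier purely for illustration; neither appears in the examples on this page, and the choice of query point is arbitrary. Because the class labels in this model are stored as a cell array of character vectors, the 'ClassNames' values are given in the same form.

```matlab
% Illustrative setup (assumption): a tree classifier on the fisheriris data.
load fisheriris                      % meas (predictors), species (labels)
rng('default')                       % For reproducibility
mdl = fitctree(meas,species);
explainer = shapley(mdl,'QueryPoint',meas(60,:));

% Default: plot the Shapley values for the predicted class only.
figure
plot(explainer)

% Two selected classes, given in the same data type as mdl.ClassNames.
figure
plot(explainer,'ClassNames',{'setosa','versicolor'})

% All classes, taken directly from the model.
figure
plot(explainer,'ClassNames',mdl.ClassNames)
```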
More About
Shapley Values
In game theory, the Shapley value of a player is the average marginal contribution of the player in a cooperative game. In the context of machine learning prediction, the Shapley value of a feature for a query point explains the contribution of the feature to a prediction (response for regression or score of each class for classification) at the specified query point.
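For reference, the game-theoretic definition can be written out explicitly. The notation is introduced here for illustration and is not used elsewhere on this page: $\mathcal{M}$ is the set of all $M$ players, $v(S)$ is the value of a coalition $S$, and $\varphi_i(v)$ is the Shapley value of player $i$.

```latex
\varphi_i(v) = \frac{1}{M} \sum_{S \subseteq \mathcal{M} \setminus \{i\}}
    \binom{M-1}{|S|}^{-1} \bigl( v(S \cup \{i\}) - v(S) \bigr)
```

Each term in the sum is the marginal contribution of player $i$ to a coalition $S$ that does not contain it. The weights average these contributions over all coalition sizes and, within each size, over all coalitions of that size.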
The Shapley value of a feature for a query point is the contribution of the feature to the deviation from the average prediction. For a query point, the sum of the Shapley values for all features corresponds to the total deviation of the prediction from the average. That is, the sum of the average prediction and the Shapley values for all features corresponds to the prediction for the query point.
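This additivity property can be stated as an equation. The symbols are illustrative: $f(x)$ is the prediction for the query point $x$, $\varphi_0$ is the average prediction, and $\varphi_i(x)$ is the Shapley value of feature $i$ out of $M$ features.

```latex
f(x) = \varphi_0 + \sum_{i=1}^{M} \varphi_i(x)
```

As a rough check against the regression example on this page, the six displayed Shapley values sum to about $-1.5707$. If the Intercept value shown in the object display (22.6202) is taken as the average prediction $\varphi_0$, which is an assumption here, the implied prediction for the query point is approximately $22.6202 - 1.5707 \approx 21.05$.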
For more details, see Shapley Values for Machine Learning Model.
Version History
Introduced in R2021a