使用各种评分指标评估MySQL HeatWave ML模型
机器学习模型的质量不仅必须在模型开发阶段进行评估,还必须在模型部署到生产中之后进行评估。在开发过程中,模型开发人员必须牢记模型产生的结果的影响,并使用适当的评分指标来揭示潜在的问题。例如,在使用模型检测癌症等疾病的情况下,假阴性和假阳性会产生严重后果,使用适当的评分标准变得至关重要。一旦模型部署到生产中,就必须经常测量模型质量,因为随着时间的推移,由于趋势变化、异常值处理不当和原始培训数据集中的类别不平衡等因素,模型可能会产生次优结果。为了暴露部署模型的模型降级背后的问题,用户需要选择适当的评分指标。
MySQL HeatWave ML支持一系列评分指标来计算模型质量。支持的评分指标可以用于本博客中描述的特定目的。如果模型得分不是最优的,用户可以使用MySQL HeatWave ML中的解释性功能来理解每个功能的影响,并基于此从上下文构建新功能。可以快速准确地重新训练新模型,使用户可以迭代地创建模型。最重要的是,用户不必将数据或模型移出数据库,因为MySQL HeatWave ML提供了真正的数据库机器学习。
评分指标
MySQL HeatWave ML支持各种评分指标,如表1所示。用户需要记住这些指标的特点,以便有效地使用它们。
评分度量“准确性”和“平衡精度”适用于二元分类和多类分类。在二元分类中,预测两类中的一类。在多类分类中,预测两个以上类别中的一个。评分指标“f1”、“精度”、“召回率”只能用于二元分类。然而,通过将问题视为二进制分类问题的集合,每个类一个,这些度量被扩展为支持多类分类。
有几种方法来平均跨类集合的二进制度量计算,
1.宏:计算二进制度量的平均值,为每个类赋予相等的权重。在同等对待所有类别的情况下,应使用宏平均。
2.微:在多类分类中,微平均是首选,其中大多数类需要被忽略,而获得少数类的准确性很重要。
数据集
本文中提供的HeatWave ML评分度量的示例基于,
1.来自scikit learn的虹膜数据集-该数据集由三种虹膜(刚毛虹膜、弗吉尼亚虹膜和花色虹膜)中的每一种的50个样本组成。从每个样本中测量了四个特征:萼片和花瓣的长度和宽度,以厘米为单位。ML任务是基于这四个特征的组合,开发一个ML模型来区分物种。
2.钻石数据集——该数据集包含近54000颗钻石的价格和其他10个属性。ML任务是基于数据集中提供的特征,开发一个ML模型来预测钻石价格。
用户需要创建一个HeatWave集群,并使用上述数据集使用本博客中提供的评分指标尝试评分功能。
分类
基于Iris数据集建立回归模型并加载模型。
CALL sys.ML_TRAIN('Iris.iris_train', 'class',JSON_OBJECT('task', 'classification'), @iris_model);
CALL sys.ML_MODEL_LOAD(@iris_model, NULL);
下面提供的评分度量示例是使用iris_model创建的。
Confusion Matrix
混淆矩阵是用于评估分类模型性能的N x N矩阵,其中N是目标类的数量。该矩阵将实际目标值与机器学习模型预测的值进行比较。这为我们提供了一个关于分类模型性能如何以及它正在产生何种错误的整体视图。本节讨论的分类度量可以使用混淆矩阵表示。
为了用一个例子解释下面给出的混淆矩阵,让我们考虑一项任务,对一个人是否怀孕进行分类。实际怀孕(阳性)并被归类为怀孕(阳性的)的人称为真阳性(TP)。实际上未怀孕(阴性)并被归类为未怀孕(阳性)的人称为真阴性(TN)。实际上未怀孕(阴性)并被归类为怀孕(阳性)的人被称为假阳性(FP)。实际怀孕(阳性)并被归类为未怀孕(阴性)的人被称为假阴性(FN)。
通过将正确预测数除以预测总数来计算精度分数。精度可用于二进制和多类分类。当真正数和真负数更重要且类分布相似时使用。
CALL sys.ML_SCORE('Iris.iris_validate', 'class', @iris_model, 'accuracy', @score);
SELECT @score;
@score
0.9666666388511658
precision
精度是从所有预测阳性病例中正确识别阳性病例的度量。当假阳性的成本较高时,这是有用的。精度只能用于二元分类,但扩展度量(Precision_micro、Precision _macro)可用于多类分类。
CALL sys.ML_SCORE('Iris.iris_validate', 'class', @iris_model, 'precision_micro', @score);
SELECT @score;
@score
0.9666666388511658
CALL sys.ML_SCORE('Iris.iris_validate', 'class', @iris_model, 'precision_macro’, @score);
SELECT @score;
@score
0.9777777791023254
recall
召回是从所有实际阳性病例中正确识别阳性病例的量度。当假阴性的成本高时,这一点很重要。这是衡量模型能够正确预测的正类数量的指标。召回度量只能用于二元分类,但扩展度量(召回微、召回宏)可用于多类分类。
CALL sys.ML_SCORE('Iris.iris_validate', 'class', @iris_model, recall_micro', @score);
SELECT @score;
@score
0.966666638851
CALL sys.ML_SCORE('Iris.iris_validate', 'class', @iris_model, recall_macro’, @score);
SELECT @score;
@score
0.958333313465
balanced_accuracy
平衡精度矩阵避免了对不平衡数据集的膨胀性能估计。它用于二进制和多类分类问题。它定义为在每个类上获得的召回率的平均值。
CALL sys.ML_SCORE('Iris.iris_validate', 'class', @iris_model, 'balanced_accuracy', @score);
SELECT @score;
@score
0.9583333134651184
f1-score
f1-score是精确性和召回率的调和平均值,与准确度度量相比,它可以更好地衡量分类错误的情况。当假阴性和假阳性至关重要时使用。例如,在测试威胁生命的死亡时,假阴性和假阳性至关重要。当存在不平衡类时,f1分数是更好的度量。f1分数度量只能用于二元分类,但扩展度量(f1_micro,f1_macro)可用于多类分类。
CALL sys.ML_SCORE('Iris.iris_validate', 'class', @iris_model, 'f1_macro’, @score);
SELECT @score;
@score
0.9662835001945496
回归
基于Diamonds数据集构建回归模型并加载模型。
CALL sys.ML_TRAIN('Diamonds.diamonds_train', 'price', JSON_OBJECT('task', 'regression'), @diamonds_model);
CALL sys.ML_MODEL_LOAD(@diamonds_model, NULL);
下面提供的评分度量示例是使用diamonds_model创建的。
neg_mean_squared_error
对应于平方(二次)误差或损失值的风险度量。
CALL sys.ML_SCORE('Diamonds.diamonds_test', 'price', @diamonds_model, 'neg_mean_squared_error', @score);
SELECT @score;
@score
-0.007937146350741386
neg_mean_absolute_error
对应于绝对误差损失值的风险度量。
CALL sys.ML_SCORE('Diamonds.diamonds_test', 'price', @diamonds_model, 'neg_mean_absolute_error', @score);
SELECT @score;
@score
-0.02085324004292488
r2
r2分数表示模型中自变量解释的因变量方差的比例。它提供了拟合优度的指示,因此,通过解释方差的比例,可以衡量模型预测不可见样本的可能性。
CALL sys.ML_SCORE('Diamonds.diamonds_test', 'price', @diamonds_model, 'r2', @score)
SELECT @score;
@score
0.9831354022026062
neg_mean_squared_log_error
与对数(二次)误差或损失的平方值相对应的风险度量。当目标呈指数增长时,如人口数量、多年平均商品销售额等,最好使用此指标。
CALL sys.ML_SCORE('Diamonds.diamonds_test', 'price', @diamonds_model, 'neg_mean_squared_log_error', @score);
SELECT @score;
@score
-0.0009180943598039448
neg_median_absolute_error
该度量是通过取目标和预测之间的所有绝对差的中值来计算的。
CALL sys.ML_SCORE('Diamonds.diamonds_test', 'price', @diamonds_model, 'neg_median_absolute_error', @score);
SELECT @score;
@score
-0.011675260029733181
总之,MySQL HeatWave ML支持分类(二进制和多类)和回归模型的各种评分指标,以计算模型质量。当类分布相似时,应使用精度度量。当误报成本较高时,精度度量非常有用。当假阴性的成本较高时,应使用召回。当数据集不平衡时,平衡精度矩阵非常有用。当假阴性和假阳性至关重要时,使用F1分数度量。在回归模型中,R2分数表示因变量方差的比例,该比例已由模型中的自变量解释。用户需要记住,这些度量中的每一个都应该用于计算模型质量。
附录:
1.资料来源:https://scikit-learn.org/stable/modules/model_evaluation.html
2.度量公式
-
精度=(TP+TN)/(TP+TN+FP+FN)
-
精度=(TP)/(TP+FP)
-
召回=(TP)/(TP+FN)
-
平衡精度=½*(TP)/(TP+FP)+(TN)/(TN+FP)
-
f1-score=2 x(精确性x召回率)/(精确性+召回率)
关于作者
原文标题:Using various scoring metrics to evaluate MySQL HeatWave ML models
原文作者:Salil Pradhan
原文链接:https://blogs.oracle.com/mysql/post/using-various-scoring-metrics-to-evaluate-mysql-heatwave-ml-models
免责声明:
1、本站资源由自动抓取工具收集整理于网络。
2、本站不承担由于内容的合法性及真实性所引起的一切争议和法律责任。
3、电子书、小说等仅供网友预览使用,书籍版权归作者或出版社所有。
4、如作者、出版社认为资源涉及侵权,请联系本站,本站将在收到通知书后尽快删除您认为侵权的作品。
5、如果您喜欢本资源,请您支持作者,购买正版内容。
6、资源失效,请下方留言,欢迎分享资源链接
文章评论