決策樹 - 基於 RDD 的 API
決策樹及其合奏是機器學習中分類和回歸任務的熱門方法。決策樹廣泛使用,因為它們易於解讀、處理分類特徵、延伸至多類別分類設定、不需要特徵縮放,而且能夠捕捉非線性和特徵互動。樹狀合奏演算法(例如隨機森林和提升)是分類和回歸任務的頂尖執行者。
spark.mllib
支援用於二元和多類別分類以及回歸的決策樹,同時使用連續和分類特徵。實作會按列分割資料,允許使用數百萬個實例進行分散式訓練。
樹狀合奏(隨機森林和梯度提升樹)在合奏指南中說明。
基本演算法
決策樹是一種貪婪演算法,對特徵空間執行遞迴二元分割。樹狀會預測每個最底層(葉子)分割的相同標籤。每個分割都是透過從一組可能的分割中選擇最佳分割來貪婪地選擇,以最大化樹狀節點的資訊增益。換句話說,在每個樹狀節點選擇的分割會從集合中選擇 $\underset{s}{\operatorname{argmax}} IG(D,s)$
,其中 $IG(D,s)$
是將分割 $s$
套用到資料集 $D$
時的資訊增益。
節點不純度和資訊增益
節點雜質是節點中標籤同質性的量測。目前的實作提供了兩種分類雜質量測(吉尼雜質和熵)和一種迴歸雜質量測(變異數)。
雜質 | 任務 | 公式 | 說明 |
---|---|---|---|
吉尼雜質 | 分類 | $\sum_{i=1}^{C} f_i(1-f_i)$ | $f_i$ 是標籤 $i$ 在節點中的頻率,而 $C$ 是唯一標籤的數量。 |
熵 | 分類 | $\sum_{i=1}^{C} -f_ilog(f_i)$ | $f_i$ 是標籤 $i$ 在節點中的頻率,而 $C$ 是唯一標籤的數量。 |
變異數 | 回歸 | $\frac{1}{N} \sum_{i=1}^{N} (y_i - \mu)^2$ | $y_i$ 是實例的標籤,$N$ 是實例的數量,而 $\mu$ 是由 $\frac{1}{N} \sum_{i=1}^N y_i$ 給出的平均值。 |
資訊增益是父節點雜質和兩個子節點雜質的加權總和之間的差異。假設分割 $s$ 將大小為 $N$
的資料集 $D$
分割成大小分別為 $N_{left}$
和 $N_{right}$
的兩個資料集 $D_{left}$
和 $D_{right}$
,則資訊增益為
$IG(D,s) = Impurity(D) - \frac{N_{left}}{N} Impurity(D_{left}) - \frac{N_{right}}{N} Impurity(D_{right})$
分割候選
連續特徵
對於單機實作中的小型資料集,每個連續特徵的分割候選通常是該特徵的唯一值。有些實作會先對特徵值排序,然後使用已排序的唯一值作為分割候選,以加快樹狀結構的計算。
對於大型分布式資料集,對特徵值排序的成本很高。此實作會透過對資料的抽樣部分執行分位數計算,來計算一組近似的分割候選。已排序的分割會建立「區間」,而此類區間的最大數量可以使用 maxBins
參數指定。
請注意,區間的數量不能大於實例數量 $N$
(由於預設 maxBins
值為 32,因此這是一個罕見的情況)。如果未滿足此條件,樹狀結構演算法會自動減少區間數量。
類別特徵
對於一個具有 $M$
個可能值(類別)的分類特徵,可以產生 $2^{M-1}-1$
個分割候選。對於二元(0/1)分類和迴歸,我們可以透過根據平均標籤對分類特徵值進行排序,將分割候選的數量減少到 $M-1$
。(有關詳細資訊,請參閱 統計機器學習元素 中的第 9.2.4 節。)例如,對於一個二元分類問題,其中一個分類特徵具有三個類別 A、B 和 C,其對應的標籤 1 比例分別為 0.2、0.6 和 0.4,則分類特徵的排序為 A、C、B。這兩個分割候選是 A | C, B 和 A , C | B,其中 | 表示分割。
在多類別分類中,所有 $2^{M-1}-1$
個可能的分割都會在可能的情況下使用。當 $2^{M-1}-1$
大於 maxBins
參數時,我們會使用類似於二元分類和迴歸所用方法的(啟發式)方法。將 $M$
個分類特徵值按雜質排序,並考慮產生的 $M-1$
個分割候選。
停止規則
當符合下列條件之一時,遞迴樹的建構會在節點停止
- 節點深度等於
maxDepth
訓練參數。 - 沒有分割候選會導致資訊增益大於
minInfoGain
。 - 沒有分割候選會產生子節點,而每個子節點至少有
minInstancesPerNode
個訓練實例。
使用提示
我們透過討論各種參數,提供一些使用決策樹的指南。這些參數大致按重要性遞減的順序列出。新使用者應主要考慮「問題規格參數」部分和 maxDepth
參數。
問題規格參數
這些參數描述您要解決的問題和您的資料集。它們應該被指定,不需要調整。
-
algo
:決策樹類型,可以是分類
或回歸
。 -
numClasses
:類別數目(僅適用於分類
)。 -
categoricalFeaturesInfo
:指定哪些特徵是分類的,以及這些特徵中的每個特徵可以採用多少個分類值。這以特徵索引對應到特徵基數(類別數目)的映射方式提供。此映射中沒有的特徵將被視為連續的。- 例如,
Map(0 -> 2, 4 -> 10)
指定特徵0
是二元的(採用值0
或1
),而特徵4
有 10 個類別(值{0, 1, ..., 9}
)。請注意,特徵索引是從 0 開始的:特徵0
和4
是實例特徵向量的第 1 個和第 5 個元素。 - 請注意,您不必指定
categoricalFeaturesInfo
。演算法仍會執行,並且可能會獲得合理的結果。但是,如果正確指定分類特徵,效能應該會更好。
- 例如,
停止準則
這些參數決定樹何時停止建立(新增新節點)。在調整這些參數時,請務必驗證保留的測試資料,以避免過度擬合。
-
maxDepth
:樹的最大深度。較深的樹更具表達力(可能允許更高的準確度),但它們的訓練成本也更高,而且更有可能過度擬合。 -
minInstancesPerNode
:對於要進一步分割的節點,其每個子節點都必須接收至少這麼多個訓練實例。這通常與 RandomForest 一起使用,因為它們通常訓練得比個別樹更深入。 -
minInfoGain
:對於要進一步分割的節點,分割必須至少改善這麼多(在資訊增益方面)。
可調整參數
這些參數可以調整。調整時請務必驗證保留的測試資料,以避免過度擬合。
maxBins
:離散化連續特徵時使用的垃圾桶數量。- 增加
maxBins
可讓演算法考慮更多分割候選,並做出細緻的分割決策。但是,它也會增加運算和通訊。 - 請注意,
maxBins
參數必須至少為任何類別特徵的最大類別數$M$
。
- 增加
maxMemoryInMB
:收集足夠統計資料要使用的記憶體量。- 預設值保守地選擇為 256 MiB,以允許決策演算法在大部分情況下運作。增加
maxMemoryInMB
可讓資料傳遞次數減少,進而加快訓練速度(如果記憶體足夠)。但是,隨著maxMemoryInMB
增加,報酬遞減,因為每次反覆運算的通訊量可能與maxMemoryInMB
成正比。 - 實作細節:為了加快處理速度,決策樹演算法會收集關於要分割的節點群組(而非一次 1 個節點)的統計資料。一個群組可以處理的節點數由記憶體需求決定(因特徵而異)。
maxMemoryInMB
參數以百萬位元組為單位指定每個工作者可使用於這些統計資料的記憶體限制。
- 預設值保守地選擇為 256 MiB,以允許決策演算法在大部分情況下運作。增加
-
subsamplingRate
:用於學習決策樹的訓練資料比例。此參數與訓練樹系(使用RandomForest
和GradientBoostedTrees
)最相關,其中對原始資料進行子抽樣可能會很有用。對於訓練單一決策樹,此參數較不實用,因為訓練實例數通常不是主要限制。 impurity
:用於選擇候選分割的雜質測量(如上所述)。此測量必須與algo
參數相符。
快取和檢查點
MLlib 1.2 為擴展至較大的(較深的)樹和樹集合新增了數個功能。當 maxDepth
設為較大值時,開啟節點 ID 快取和檢查點會很有用。當 numTrees
設為較大值時,這些參數對於 RandomForest 也很有用。
useNodeIdCache
:如果設為 true,演算法會避免在每次反覆運算中將目前的模型(樹或樹)傳遞給執行器。- 這對於深度樹(加速工作站上的運算)和大型隨機森林(減少每次反覆運算的通訊)很有用。
- 實作詳細資訊:預設情況下,演算法會將目前的模型傳遞給執行器,以便執行器可以將訓練實例與樹節點配對。當此設定開啟時,演算法會快取此資訊。
節點 ID 快取會產生一系列 RDD(每個反覆運算 1 個)。此長譜系可能會造成效能問題,但檢查點中介 RDD 可以緩解這些問題。請注意,只有在 useNodeIdCache
設為 true 時,檢查點才適用。
-
checkpointDir
:用於檢查點節點 ID 快取 RDD 的目錄。 -
checkpointInterval
:檢查點節點 ID 快取 RDD 的頻率。將此設定設得太低會因為寫入 HDFS 而造成額外負擔;將此設定設得太高可能會在執行器發生故障且 RDD 需要重新運算時造成問題。
擴充
運算會隨著訓練實例數、特徵數和 maxBins
參數近似線性擴展。通訊會隨著特徵數和 maxBins
近似線性擴展。
實作的演算法會讀取稀疏和稠密資料。然而,它並未針對稀疏輸入最佳化。
範例
分類
以下範例說明如何載入 LIBSVM 資料檔案,將其解析為 LabeledPoint
的 RDD,然後使用決定樹執行分類,其中 Gini 不純度為不純度量度,最大樹深度為 5。會計算測試誤差來測量演算法準確度。
請參閱 DecisionTree
Python 文件 和 DecisionTreeModel
Python 文件,以取得有關 API 的更多詳細資訊。
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from pyspark.mllib.util import MLUtils
# Load and parse the data file into an RDD of LabeledPoint.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])
# Train a DecisionTree model.
# Empty categoricalFeaturesInfo indicates all features are continuous.
model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
impurity='gini', maxDepth=5, maxBins=32)
# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testErr = labelsAndPredictions.filter(
lambda lp: lp[0] != lp[1]).count() / float(testData.count())
print('Test Error = ' + str(testErr))
print('Learned classification tree model:')
print(model.toDebugString())
# Save and load model
model.save(sc, "target/tmp/myDecisionTreeClassificationModel")
sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel")
有關 API 的詳細資訊,請參閱 DecisionTree
Scala 文件 和 DecisionTreeModel
Scala 文件。
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils
// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))
// Train a DecisionTree model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "gini"
val maxDepth = 5
val maxBins = 32
val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
impurity, maxDepth, maxBins)
// Evaluate model on test instances and compute test error
val labelAndPreds = testData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}
val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count()
println(s"Test Error = $testErr")
println(s"Learned classification tree model:\n ${model.toDebugString}")
// Save and load model
model.save(sc, "target/tmp/myDecisionTreeClassificationModel")
val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel")
有關 API 的詳細資訊,請參閱 DecisionTree
Java 文件 和 DecisionTreeModel
Java 文件。
import java.util.HashMap;
import java.util.Map;
import scala.Tuple2;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;
SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample");
JavaSparkContext jsc = new JavaSparkContext(sparkConf);
// Load and parse the data file.
String datapath = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
// Split the data into training and test sets (30% held out for testing)
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];
// Set parameters.
// Empty categoricalFeaturesInfo indicates all features are continuous.
int numClasses = 2;
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
String impurity = "gini";
int maxDepth = 5;
int maxBins = 32;
// Train a DecisionTree model for classification.
DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
categoricalFeaturesInfo, impurity, maxDepth, maxBins);
// Evaluate model on test instances and compute test error
JavaPairRDD<Double, Double> predictionAndLabel =
testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
double testErr =
predictionAndLabel.filter(pl -> !pl._1().equals(pl._2())).count() / (double) testData.count();
System.out.println("Test Error: " + testErr);
System.out.println("Learned classification tree model:\n" + model.toDebugString());
// Save and load model
model.save(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel");
DecisionTreeModel sameModel = DecisionTreeModel
.load(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel");
回歸
以下範例示範如何載入 LIBSVM 資料檔,將其剖析為 LabeledPoint
的 RDD,然後使用變異數作為不純淨度量度和最大樹狀深度為 5 的決策樹執行回歸。平均平方誤差 (MSE) 會在最後計算,以評估 擬合優度。
請參閱 DecisionTree
Python 文件 和 DecisionTreeModel
Python 文件,以取得有關 API 的更多詳細資訊。
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from pyspark.mllib.util import MLUtils
# Load and parse the data file into an RDD of LabeledPoint.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])
# Train a DecisionTree model.
# Empty categoricalFeaturesInfo indicates all features are continuous.
model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={},
impurity='variance', maxDepth=5, maxBins=32)
# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testMSE = labelsAndPredictions.map(lambda lp: (lp[0] - lp[1]) * (lp[0] - lp[1])).sum() /\
float(testData.count())
print('Test Mean Squared Error = ' + str(testMSE))
print('Learned regression tree model:')
print(model.toDebugString())
# Save and load model
model.save(sc, "target/tmp/myDecisionTreeRegressionModel")
sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeRegressionModel")
有關 API 的詳細資訊,請參閱 DecisionTree
Scala 文件 和 DecisionTreeModel
Scala 文件。
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils
// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))
// Train a DecisionTree model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "variance"
val maxDepth = 5
val maxBins = 32
val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity,
maxDepth, maxBins)
// Evaluate model on test instances and compute test error
val labelsAndPredictions = testData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}
val testMSE = labelsAndPredictions.map{ case (v, p) => math.pow(v - p, 2) }.mean()
println(s"Test Mean Squared Error = $testMSE")
println(s"Learned regression tree model:\n ${model.toDebugString}")
// Save and load model
model.save(sc, "target/tmp/myDecisionTreeRegressionModel")
val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeRegressionModel")
有關 API 的詳細資訊,請參閱 DecisionTree
Java 文件 和 DecisionTreeModel
Java 文件。
import java.util.HashMap;
import java.util.Map;
import scala.Tuple2;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;
SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeRegressionExample");
JavaSparkContext jsc = new JavaSparkContext(sparkConf);
// Load and parse the data file.
String datapath = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
// Split the data into training and test sets (30% held out for testing)
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];
// Set parameters.
// Empty categoricalFeaturesInfo indicates all features are continuous.
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
String impurity = "variance";
int maxDepth = 5;
int maxBins = 32;
// Train a DecisionTree model.
DecisionTreeModel model = DecisionTree.trainRegressor(trainingData,
categoricalFeaturesInfo, impurity, maxDepth, maxBins);
// Evaluate model on test instances and compute test error
JavaPairRDD<Double, Double> predictionAndLabel =
testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
double testMSE = predictionAndLabel.mapToDouble(pl -> {
double diff = pl._1() - pl._2();
return diff * diff;
}).mean();
System.out.println("Test Mean Squared Error: " + testMSE);
System.out.println("Learned regression tree model:\n" + model.toDebugString());
// Save and load model
model.save(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel");
DecisionTreeModel sameModel = DecisionTreeModel
.load(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel");