決策樹 - 基於 RDD 的 API

決策樹及其合奏是機器學習中分類和回歸任務的熱門方法。決策樹廣泛使用,因為它們易於解讀、處理分類特徵、延伸至多類別分類設定、不需要特徵縮放,而且能夠捕捉非線性和特徵互動。樹狀合奏演算法(例如隨機森林和提升)是分類和回歸任務的頂尖執行者。

spark.mllib 支援用於二元和多類別分類以及回歸的決策樹,同時使用連續和分類特徵。實作會按列分割資料,允許使用數百萬個實例進行分散式訓練。

樹狀合奏(隨機森林和梯度提升樹)在合奏指南中說明。

基本演算法

決策樹是一種貪婪演算法,對特徵空間執行遞迴二元分割。樹狀會預測每個最底層(葉子)分割的相同標籤。每個分割都是透過從一組可能的分割中選擇最佳分割來貪婪地選擇,以最大化樹狀節點的資訊增益。換句話說,在每個樹狀節點選擇的分割會從集合中選擇 $\underset{s}{\operatorname{argmax}} IG(D,s)$,其中 $IG(D,s)$ 是將分割 $s$ 套用到資料集 $D$ 時的資訊增益。

節點不純度和資訊增益

節點雜質是節點中標籤同質性的量測。目前的實作提供了兩種分類雜質量測(吉尼雜質和熵)和一種迴歸雜質量測(變異數)。

雜質任務公式說明
吉尼雜質 分類 $\sum_{i=1}^{C} f_i(1-f_i)$$f_i$ 是標籤 $i$ 在節點中的頻率,而 $C$ 是唯一標籤的數量。
分類 $\sum_{i=1}^{C} -f_ilog(f_i)$$f_i$ 是標籤 $i$ 在節點中的頻率,而 $C$ 是唯一標籤的數量。
變異數 回歸 $\frac{1}{N} \sum_{i=1}^{N} (y_i - \mu)^2$$y_i$ 是實例的標籤,$N$ 是實例的數量,而 $\mu$ 是由 $\frac{1}{N} \sum_{i=1}^N y_i$ 給出的平均值。

資訊增益是父節點雜質和兩個子節點雜質的加權總和之間的差異。假設分割 $s$ 將大小為 $N$ 的資料集 $D$ 分割成大小分別為 $N_{left}$$N_{right}$ 的兩個資料集 $D_{left}$$D_{right}$,則資訊增益為

$IG(D,s) = Impurity(D) - \frac{N_{left}}{N} Impurity(D_{left}) - \frac{N_{right}}{N} Impurity(D_{right})$

分割候選

連續特徵

對於單機實作中的小型資料集,每個連續特徵的分割候選通常是該特徵的唯一值。有些實作會先對特徵值排序,然後使用已排序的唯一值作為分割候選,以加快樹狀結構的計算。

對於大型分布式資料集,對特徵值排序的成本很高。此實作會透過對資料的抽樣部分執行分位數計算,來計算一組近似的分割候選。已排序的分割會建立「區間」,而此類區間的最大數量可以使用 maxBins 參數指定。

請注意,區間的數量不能大於實例數量 $N$(由於預設 maxBins 值為 32,因此這是一個罕見的情況)。如果未滿足此條件,樹狀結構演算法會自動減少區間數量。

類別特徵

對於一個具有 $M$ 個可能值(類別)的分類特徵,可以產生 $2^{M-1}-1$ 個分割候選。對於二元(0/1)分類和迴歸,我們可以透過根據平均標籤對分類特徵值進行排序,將分割候選的數量減少到 $M-1$。(有關詳細資訊,請參閱 統計機器學習元素 中的第 9.2.4 節。)例如,對於一個二元分類問題,其中一個分類特徵具有三個類別 A、B 和 C,其對應的標籤 1 比例分別為 0.2、0.6 和 0.4,則分類特徵的排序為 A、C、B。這兩個分割候選是 A | C, B 和 A , C | B,其中 | 表示分割。

在多類別分類中,所有 $2^{M-1}-1$ 個可能的分割都會在可能的情況下使用。當 $2^{M-1}-1$ 大於 maxBins 參數時,我們會使用類似於二元分類和迴歸所用方法的(啟發式)方法。將 $M$ 個分類特徵值按雜質排序,並考慮產生的 $M-1$ 個分割候選。

停止規則

當符合下列條件之一時,遞迴樹的建構會在節點停止

  1. 節點深度等於 maxDepth 訓練參數。
  2. 沒有分割候選會導致資訊增益大於 minInfoGain
  3. 沒有分割候選會產生子節點,而每個子節點至少有 minInstancesPerNode 個訓練實例。

使用提示

我們透過討論各種參數,提供一些使用決策樹的指南。這些參數大致按重要性遞減的順序列出。新使用者應主要考慮「問題規格參數」部分和 maxDepth 參數。

問題規格參數

這些參數描述您要解決的問題和您的資料集。它們應該被指定,不需要調整。

停止準則

這些參數決定樹何時停止建立(新增新節點)。在調整這些參數時,請務必驗證保留的測試資料,以避免過度擬合。

可調整參數

這些參數可以調整。調整時請務必驗證保留的測試資料,以避免過度擬合。

快取和檢查點

MLlib 1.2 為擴展至較大的(較深的)樹和樹集合新增了數個功能。當 maxDepth 設為較大值時,開啟節點 ID 快取和檢查點會很有用。當 numTrees 設為較大值時,這些參數對於 RandomForest 也很有用。

節點 ID 快取會產生一系列 RDD(每個反覆運算 1 個)。此長譜系可能會造成效能問題,但檢查點中介 RDD 可以緩解這些問題。請注意,只有在 useNodeIdCache 設為 true 時,檢查點才適用。

擴充

運算會隨著訓練實例數、特徵數和 maxBins 參數近似線性擴展。通訊會隨著特徵數和 maxBins 近似線性擴展。

實作的演算法會讀取稀疏和稠密資料。然而,它並未針對稀疏輸入最佳化。

範例

分類

以下範例說明如何載入 LIBSVM 資料檔案,將其解析為 LabeledPoint 的 RDD,然後使用決定樹執行分類,其中 Gini 不純度為不純度量度,最大樹深度為 5。會計算測試誤差來測量演算法準確度。

請參閱 DecisionTree Python 文件DecisionTreeModel Python 文件,以取得有關 API 的更多詳細資訊。

from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from pyspark.mllib.util import MLUtils

# Load and parse the data file into an RDD of LabeledPoint.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])

# Train a DecisionTree model.
#  Empty categoricalFeaturesInfo indicates all features are continuous.
model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
                                     impurity='gini', maxDepth=5, maxBins=32)

# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testErr = labelsAndPredictions.filter(
    lambda lp: lp[0] != lp[1]).count() / float(testData.count())
print('Test Error = ' + str(testErr))
print('Learned classification tree model:')
print(model.toDebugString())

# Save and load model
model.save(sc, "target/tmp/myDecisionTreeClassificationModel")
sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel")
在 Spark 儲存庫中的「examples/src/main/python/mllib/decision_tree_classification_example.py」中,找出完整的範例程式碼。

有關 API 的詳細資訊,請參閱 DecisionTree Scala 文件DecisionTreeModel Scala 文件

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

// Train a DecisionTree model.
//  Empty categoricalFeaturesInfo indicates all features are continuous.
val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "gini"
val maxDepth = 5
val maxBins = 32

val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
  impurity, maxDepth, maxBins)

// Evaluate model on test instances and compute test error
val labelAndPreds = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count()
println(s"Test Error = $testErr")
println(s"Learned classification tree model:\n ${model.toDebugString}")

// Save and load model
model.save(sc, "target/tmp/myDecisionTreeClassificationModel")
val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel")
在 Spark 儲存庫中的「examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala」中找到完整的範例程式碼。

有關 API 的詳細資訊,請參閱 DecisionTree Java 文件DecisionTreeModel Java 文件

import java.util.HashMap;
import java.util.Map;

import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;

SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample");
JavaSparkContext jsc = new JavaSparkContext(sparkConf);

// Load and parse the data file.
String datapath = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
// Split the data into training and test sets (30% held out for testing)
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];

// Set parameters.
//  Empty categoricalFeaturesInfo indicates all features are continuous.
int numClasses = 2;
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
String impurity = "gini";
int maxDepth = 5;
int maxBins = 32;

// Train a DecisionTree model for classification.
DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
  categoricalFeaturesInfo, impurity, maxDepth, maxBins);

// Evaluate model on test instances and compute test error
JavaPairRDD<Double, Double> predictionAndLabel =
  testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
double testErr =
  predictionAndLabel.filter(pl -> !pl._1().equals(pl._2())).count() / (double) testData.count();

System.out.println("Test Error: " + testErr);
System.out.println("Learned classification tree model:\n" + model.toDebugString());

// Save and load model
model.save(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel");
DecisionTreeModel sameModel = DecisionTreeModel
  .load(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel");
在 Spark 儲存庫中的「examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java」中找到完整的範例程式碼。

回歸

以下範例示範如何載入 LIBSVM 資料檔,將其剖析為 LabeledPoint 的 RDD,然後使用變異數作為不純淨度量度和最大樹狀深度為 5 的決策樹執行回歸。平均平方誤差 (MSE) 會在最後計算,以評估 擬合優度

請參閱 DecisionTree Python 文件DecisionTreeModel Python 文件,以取得有關 API 的更多詳細資訊。

from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from pyspark.mllib.util import MLUtils

# Load and parse the data file into an RDD of LabeledPoint.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])

# Train a DecisionTree model.
#  Empty categoricalFeaturesInfo indicates all features are continuous.
model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={},
                                    impurity='variance', maxDepth=5, maxBins=32)

# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testMSE = labelsAndPredictions.map(lambda lp: (lp[0] - lp[1]) * (lp[0] - lp[1])).sum() /\
    float(testData.count())
print('Test Mean Squared Error = ' + str(testMSE))
print('Learned regression tree model:')
print(model.toDebugString())

# Save and load model
model.save(sc, "target/tmp/myDecisionTreeRegressionModel")
sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeRegressionModel")
在 Spark 儲存庫中的「examples/src/main/python/mllib/decision_tree_regression_example.py」中找到完整的範例程式碼。

有關 API 的詳細資訊,請參閱 DecisionTree Scala 文件DecisionTreeModel Scala 文件

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

// Train a DecisionTree model.
//  Empty categoricalFeaturesInfo indicates all features are continuous.
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "variance"
val maxDepth = 5
val maxBins = 32

val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity,
  maxDepth, maxBins)

// Evaluate model on test instances and compute test error
val labelsAndPredictions = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val testMSE = labelsAndPredictions.map{ case (v, p) => math.pow(v - p, 2) }.mean()
println(s"Test Mean Squared Error = $testMSE")
println(s"Learned regression tree model:\n ${model.toDebugString}")

// Save and load model
model.save(sc, "target/tmp/myDecisionTreeRegressionModel")
val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeRegressionModel")
在 Spark 儲存庫中的「examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala」中找到完整的範例程式碼。

有關 API 的詳細資訊,請參閱 DecisionTree Java 文件DecisionTreeModel Java 文件

import java.util.HashMap;
import java.util.Map;

import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;

SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeRegressionExample");
JavaSparkContext jsc = new JavaSparkContext(sparkConf);

// Load and parse the data file.
String datapath = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
// Split the data into training and test sets (30% held out for testing)
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];

// Set parameters.
// Empty categoricalFeaturesInfo indicates all features are continuous.
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
String impurity = "variance";
int maxDepth = 5;
int maxBins = 32;

// Train a DecisionTree model.
DecisionTreeModel model = DecisionTree.trainRegressor(trainingData,
  categoricalFeaturesInfo, impurity, maxDepth, maxBins);

// Evaluate model on test instances and compute test error
JavaPairRDD<Double, Double> predictionAndLabel =
  testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
double testMSE = predictionAndLabel.mapToDouble(pl -> {
  double diff = pl._1() - pl._2();
  return diff * diff;
}).mean();
System.out.println("Test Mean Squared Error: " + testMSE);
System.out.println("Learned regression tree model:\n" + model.toDebugString());

// Save and load model
model.save(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel");
DecisionTreeModel sameModel = DecisionTreeModel
  .load(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel");
在 Spark 儲存庫中的「examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java」中找到完整的範例程式碼。