User-Defined Aggregate Functions (UDAFs)
Description
User-Defined Aggregate Functions (UDAFs) are user-programmable routines that act on multiple rows at once and return a single aggregated value as a result. This documentation lists the classes that are required for creating and registering UDAFs. It also contains examples that demonstrate how to define and register UDAFs in Scala and invoke them in Spark SQL.
Aggregator[-IN, BUF, OUT]
A base class for user-defined aggregations, which can be used in Dataset operations to take all of the elements of a group and reduce them to a single value.
IN - The input type for the aggregation.
BUF - The type of the intermediate value of the reduction.
OUT - The type of the final output result.
bufferEncoder: Encoder[BUF]
Specifies the Encoder for the intermediate value type.
finish(reduction: BUF): OUT
Transform the output of the reduction.
merge(b1: BUF, b2: BUF): BUF
Merge two intermediate values.
outputEncoder: Encoder[OUT]
Specifies the Encoder for the final output value type.
reduce(b: BUF, a: IN): BUF
Aggregate input value a into the current intermediate value. For performance, the function may modify b and return it instead of constructing a new object for b.
zero: BUF
The initial value of the intermediate result for this aggregation.
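Conceptually, each partial aggregation starts from zero, folds each input value into the buffer with reduce, combines partial buffers with merge, and applies finish to the merged buffer to produce the result. As a minimal sketch of this contract (CountAgg is a hypothetical example, not part of the Spark API), a typed count aggregator could look like this:
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
// A minimal sketch of the Aggregator contract: a typed count over String inputs.
// CountAgg is a hypothetical example, not part of the Spark API.
object CountAgg extends Aggregator[String, Long, Long] {
  // Initial buffer value; should satisfy merge(b, zero) == b
  def zero: Long = 0L
  // Fold one input value into the current buffer
  def reduce(b: Long, a: String): Long = b + 1L
  // Combine two partial buffers
  def merge(b1: Long, b2: Long): Long = b1 + b2
  // Produce the final output from the merged buffer
  def finish(r: Long): Long = r
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
Such an object can be turned into a TypedColumn with CountAgg.toColumn, exactly as in the examples below.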
Examples
Type-Safe User-Defined Aggregate Functions
User-defined aggregations for strongly typed Datasets revolve around the Aggregator abstract class. For example, a type-safe user-defined average can look like this:
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
case class Employee(name: String, salary: Long)
case class Average(var sum: Long, var count: Long)
object MyAverage extends Aggregator[Employee, Average, Double] {
// A zero value for this aggregation. Should satisfy the property that any b + zero = b
def zero: Average = Average(0L, 0L)
// Combine two values to produce a new value. For performance, the function may modify `buffer`
// and return it instead of constructing a new object
def reduce(buffer: Average, employee: Employee): Average = {
buffer.sum += employee.salary
buffer.count += 1
buffer
}
// Merge two intermediate values
def merge(b1: Average, b2: Average): Average = {
b1.sum += b2.sum
b1.count += b2.count
b1
}
// Transform the output of the reduction
def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
// Specifies the Encoder for the intermediate value type
def bufferEncoder: Encoder[Average] = Encoders.product
// Specifies the Encoder for the final output value type
def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee]
ds.show()
// +-------+------+
// | name|salary|
// +-------+------+
// |Michael| 3000|
// | Andy| 4500|
// | Justin| 3500|
// | Berta| 4000|
// +-------+------+
// Convert the function to a `TypedColumn` and give it a name
val averageSalary = MyAverage.toColumn.name("average_salary")
val result = ds.select(averageSalary)
result.show()
// +--------------+
// |average_salary|
// +--------------+
// | 3750.0|
// +--------------+
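Because the aggregation is strongly typed, the same TypedColumn can also be applied per key via groupByKey; a brief sketch (perPersonAverage is just an illustrative name, and each name is unique in this sample data, so this only demonstrates the mechanics):
// Sketch: apply the typed aggregation per group rather than over the whole Dataset.
// Assumes import spark.implicits._ is in scope, as in the full example file.
val perPersonAverage = ds.groupByKey(_.name).agg(averageSalary)
perPersonAverage.show()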
Find full example code at "examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala" in the Spark repo. The equivalent Java implementation follows:
import java.io.Serializable;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.TypedColumn;
import org.apache.spark.sql.expressions.Aggregator;
public static class Employee implements Serializable {
private String name;
private long salary;
// Constructors, getters, setters...
}
public static class Average implements Serializable {
private long sum;
private long count;
// Constructors, getters, setters...
}
public static class MyAverage extends Aggregator<Employee, Average, Double> {
// A zero value for this aggregation. Should satisfy the property that any b + zero = b
@Override
public Average zero() {
return new Average(0L, 0L);
}
// Combine two values to produce a new value. For performance, the function may modify `buffer`
// and return it instead of constructing a new object
@Override
public Average reduce(Average buffer, Employee employee) {
long newSum = buffer.getSum() + employee.getSalary();
long newCount = buffer.getCount() + 1;
buffer.setSum(newSum);
buffer.setCount(newCount);
return buffer;
}
// Merge two intermediate values
@Override
public Average merge(Average b1, Average b2) {
long mergedSum = b1.getSum() + b2.getSum();
long mergedCount = b1.getCount() + b2.getCount();
b1.setSum(mergedSum);
b1.setCount(mergedCount);
return b1;
}
// Transform the output of the reduction
@Override
public Double finish(Average reduction) {
return ((double) reduction.getSum()) / reduction.getCount();
}
// Specifies the Encoder for the intermediate value type
@Override
public Encoder<Average> bufferEncoder() {
return Encoders.bean(Average.class);
}
// Specifies the Encoder for the final output value type
@Override
public Encoder<Double> outputEncoder() {
return Encoders.DOUBLE();
}
}
Encoder<Employee> employeeEncoder = Encoders.bean(Employee.class);
String path = "examples/src/main/resources/employees.json";
Dataset<Employee> ds = spark.read().json(path).as(employeeEncoder);
ds.show();
// +-------+------+
// | name|salary|
// +-------+------+
// |Michael| 3000|
// | Andy| 4500|
// | Justin| 3500|
// | Berta| 4000|
// +-------+------+
MyAverage myAverage = new MyAverage();
// Convert the function to a `TypedColumn` and give it a name
TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary");
Dataset<Double> result = ds.select(averageSalary);
result.show();
// +--------------+
// |average_salary|
// +--------------+
// | 3750.0|
// +--------------+
Find full example code at "examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java" in the Spark repo.
Untyped User-Defined Aggregate Functions
Typed aggregations, as described above, may also be registered as untyped aggregating UDFs for use with DataFrames. For example, a user-defined average for untyped DataFrames can look like this:
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions
case class Average(var sum: Long, var count: Long)
object MyAverage extends Aggregator[Long, Average, Double] {
// A zero value for this aggregation. Should satisfy the property that any b + zero = b
def zero: Average = Average(0L, 0L)
// Combine two values to produce a new value. For performance, the function may modify `buffer`
// and return it instead of constructing a new object
def reduce(buffer: Average, data: Long): Average = {
buffer.sum += data
buffer.count += 1
buffer
}
// Merge two intermediate values
def merge(b1: Average, b2: Average): Average = {
b1.sum += b2.sum
b1.count += b2.count
b1
}
// Transform the output of the reduction
def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
// Specifies the Encoder for the intermediate value type
def bufferEncoder: Encoder[Average] = Encoders.product
// Specifies the Encoder for the final output value type
def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
// Register the function to access it
spark.udf.register("myAverage", functions.udaf(MyAverage))
val df = spark.read.json("examples/src/main/resources/employees.json")
df.createOrReplaceTempView("employees")
df.show()
// +-------+------+
// | name|salary|
// +-------+------+
// |Michael| 3000|
// | Andy| 4500|
// | Justin| 3500|
// | Berta| 4000|
// +-------+------+
val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
result.show()
// +--------------+
// |average_salary|
// +--------------+
// | 3750.0|
// +--------------+
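The Aggregator wrapped by functions.udaf can also be used through the DataFrame API without SQL, since udaf returns a UserDefinedFunction that is applied to Columns; a brief sketch (myAverageUdf is just an illustrative name, and an implicit Encoder[Long] is assumed to be in scope, e.g. via import spark.implicits._, as in the full example file):
import org.apache.spark.sql.functions.col
// Sketch: invoke the wrapped aggregator directly on a DataFrame column.
val myAverageUdf = functions.udaf(MyAverage)
df.agg(myAverageUdf(col("salary")).as("average_salary")).show()
// The same UserDefinedFunction works with groupBy
df.groupBy("name").agg(myAverageUdf(col("salary")).as("average_salary")).show()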
Find full example code at "examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala" in the Spark repo. The equivalent Java implementation follows:
import java.io.Serializable;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.functions;
public static class Average implements Serializable {
private long sum;
private long count;
// Constructors, getters, setters...
public Average() {
}
public Average(long sum, long count) {
this.sum = sum;
this.count = count;
}
public long getSum() {
return sum;
}
public void setSum(long sum) {
this.sum = sum;
}
public long getCount() {
return count;
}
public void setCount(long count) {
this.count = count;
}
}
public static class MyAverage extends Aggregator<Long, Average, Double> {
// A zero value for this aggregation. Should satisfy the property that any b + zero = b
@Override
public Average zero() {
return new Average(0L, 0L);
}
// Combine two values to produce a new value. For performance, the function may modify `buffer`
// and return it instead of constructing a new object
@Override
public Average reduce(Average buffer, Long data) {
long newSum = buffer.getSum() + data;
long newCount = buffer.getCount() + 1;
buffer.setSum(newSum);
buffer.setCount(newCount);
return buffer;
}
// Merge two intermediate values
@Override
public Average merge(Average b1, Average b2) {
long mergedSum = b1.getSum() + b2.getSum();
long mergedCount = b1.getCount() + b2.getCount();
b1.setSum(mergedSum);
b1.setCount(mergedCount);
return b1;
}
// Transform the output of the reduction
@Override
public Double finish(Average reduction) {
return ((double) reduction.getSum()) / reduction.getCount();
}
// Specifies the Encoder for the intermediate value type
@Override
public Encoder<Average> bufferEncoder() {
return Encoders.bean(Average.class);
}
// Specifies the Encoder for the final output value type
@Override
public Encoder<Double> outputEncoder() {
return Encoders.DOUBLE();
}
}
// Register the function to access it
spark.udf().register("myAverage", functions.udaf(new MyAverage(), Encoders.LONG()));
Dataset<Row> df = spark.read().json("examples/src/main/resources/employees.json");
df.createOrReplaceTempView("employees");
df.show();
// +-------+------+
// | name|salary|
// +-------+------+
// |Michael| 3000|
// | Andy| 4500|
// | Justin| 3500|
// | Berta| 4000|
// +-------+------+
Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
result.show();
// +--------------+
// |average_salary|
// +--------------+
// | 3750.0|
// +--------------+
Find full example code at "examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java" in the Spark repo. A UDAF compiled into a JAR can also be registered and invoked directly from SQL:
-- Compile and place UDAF MyAverage in a JAR file called `MyAverage.jar` in /tmp.
CREATE FUNCTION myAverage AS 'MyAverage' USING JAR '/tmp/MyAverage.jar';
SHOW USER FUNCTIONS;
+------------------+
| function|
+------------------+
| default.myAverage|
+------------------+
CREATE TEMPORARY VIEW employees
USING org.apache.spark.sql.json
OPTIONS (
path "examples/src/main/resources/employees.json"
);
SELECT * FROM employees;
+-------+------+
| name|salary|
+-------+------+
|Michael| 3000|
| Andy| 4500|
| Justin| 3500|
| Berta| 4000|
+-------+------+
SELECT myAverage(salary) as average_salary FROM employees;
+--------------+
|average_salary|
+--------------+
| 3750.0|
+--------------+
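Once registered this way, the function can be used like any built-in aggregate, so grouped queries such as SELECT name, myAverage(salary) FROM employees GROUP BY name should work as well (each name is unique in this sample data, so such a query would simply echo each salary as its own average).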