diff --git a/src/main/scala/com/amazon/deequ/analyzers/RatioOfSums.scala b/src/main/scala/com/amazon/deequ/analyzers/RatioOfSums.scala
new file mode 100644
index 00000000..593d358d
--- /dev/null
+++ b/src/main/scala/com/amazon/deequ/analyzers/RatioOfSums.scala
@@ -0,0 +1,92 @@
+/**
+ * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not
+ * use this file except in compliance with the License. A copy of the License
+ * is located at
+ *
+ *     http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ *
+ */
+
+package com.amazon.deequ.analyzers
+
+import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNumeric}
+import com.amazon.deequ.metrics.Entity
+import org.apache.spark.sql.{Column, Row}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DoubleType, StructType}
+import Analyzers._
+
+case class RatioOfSumsState(
+    numerator: Double,
+    denominator: Double
+) extends DoubleValuedState[RatioOfSumsState] {
+
+  override def sum(other: RatioOfSumsState): RatioOfSumsState = {
+    RatioOfSumsState(numerator + other.numerator, denominator + other.denominator)
+  }
+
+  override def metricValue(): Double = {
+    numerator / denominator
+  }
+}
+
+/** Sums each of the two given columns and divides the first sum by the second,
+  * returning the result as a Double. The columns can contain a mix of positive
+  * and negative numbers. Dividing by zero is allowed and results in
+  * Double.PositiveInfinity or Double.NegativeInfinity.
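+  *
+  * A minimal usage sketch (the DataFrame `df` and the column names below are
+  * only illustrative):
+  * {{{
+  *   val metric = RatioOfSums("revenue", "clicks").calculate(df)
+  *   // metric.value is a Try[Double] holding sum(revenue) / sum(clicks)
+  * }}}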
+  *
+  * @param numerator
+  *   First input column for computation
+  * @param denominator
+  *   Second input column for computation
+  */
+case class RatioOfSums(
+    numerator: String,
+    denominator: String,
+    where: Option[String] = None
+) extends StandardScanShareableAnalyzer[RatioOfSumsState](
+      "RatioOfSums",
+      s"$numerator,$denominator",
+      Entity.Multicolumn
+    )
+    with FilterableAnalyzer {
+
+  override def aggregationFunctions(): Seq[Column] = {
+    val firstSelection = conditionalSelection(numerator, where)
+    val secondSelection = conditionalSelection(denominator, where)
+    sum(firstSelection).cast(DoubleType) :: sum(secondSelection).cast(DoubleType) :: Nil
+  }
+
+  override def fromAggregationResult(
+      result: Row,
+      offset: Int
+  ): Option[RatioOfSumsState] = {
+    if (result.isNullAt(offset) || result.isNullAt(offset + 1)) {
+      None
+    } else {
+      Some(
+        RatioOfSumsState(
+          result.getDouble(offset),
+          result.getDouble(offset + 1)
+        )
+      )
+    }
+  }
+
+  override protected def additionalPreconditions(): Seq[StructType => Unit] = {
+    hasColumn(numerator) :: isNumeric(numerator) :: hasColumn(denominator) :: isNumeric(denominator) :: Nil
+  }
+
+  override def filterCondition: Option[String] = where
+}
diff --git a/src/main/scala/com/amazon/deequ/repository/AnalysisResultSerde.scala b/src/main/scala/com/amazon/deequ/repository/AnalysisResultSerde.scala
index e9bb4f7d..eb0db536 100644
--- a/src/main/scala/com/amazon/deequ/repository/AnalysisResultSerde.scala
+++ b/src/main/scala/com/amazon/deequ/repository/AnalysisResultSerde.scala
@@ -256,6 +256,12 @@ private[deequ] object AnalyzerSerializer
         result.addProperty(COLUMN_FIELD, sum.column)
         result.addProperty(WHERE_FIELD, sum.where.orNull)
 
+      case ratioOfSums: RatioOfSums =>
+        result.addProperty(ANALYZER_NAME_FIELD, "RatioOfSums")
+        result.addProperty("numerator", ratioOfSums.numerator)
+        result.addProperty("denominator", ratioOfSums.denominator)
+        result.addProperty(WHERE_FIELD, ratioOfSums.where.orNull)
+
       case mean: Mean =>
         result.addProperty(ANALYZER_NAME_FIELD, "Mean")
         result.addProperty(COLUMN_FIELD, mean.column)
@@ -412,6 +418,12 @@ private[deequ] object AnalyzerDeserializer
           json.get(COLUMN_FIELD).getAsString,
           getOptionalWhereParam(json))
 
+      case "RatioOfSums" =>
+        RatioOfSums(
+          json.get("numerator").getAsString,
+          json.get("denominator").getAsString,
+          getOptionalWhereParam(json))
+
       case "Mean" =>
         Mean(
           json.get(COLUMN_FIELD).getAsString,
diff --git a/src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala b/src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala
index abd68dde..4e0573b7 100644
--- a/src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala
+++ b/src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala
@@ -840,6 +840,23 @@ class AnalyzerTests extends AnyWordSpec with Matchers with SparkContextSpec with
       analyzer.calculate(df).value shouldBe Success(2.0 / 8.0)
       assert(analyzer.calculate(df).fullColumn.isDefined)
     }
+
+    "compute ratio of sums correctly for numeric data" in withSparkSession { sparkSession =>
+      val df = getDfWithNumericValues(sparkSession)
+      RatioOfSums("att1", "att2").calculate(df).value shouldBe Success(21.0 / 18.0)
+    }
+
+    "fail to compute ratio of sums for non numeric type" in withSparkSession { sparkSession =>
+      val df = getDfFull(sparkSession)
+      assert(RatioOfSums("att1", "att2").calculate(df).value.isFailure)
+    }
+
+    "divide by zero" in withSparkSession { sparkSession =>
+      val df = getDfWithNumericValues(sparkSession)
+      val testVal = RatioOfSums("att1", "att2", Some("item IN ('1', '2')")).calculate(df)
+      assert(testVal.value.isSuccess)
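+      // The filter leaves only rows whose att2 values sum to zero, so the ratio is expected to be infinite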
+      assert(testVal.value.toOption.get.isInfinite)
+    }
   }
 }
diff --git a/src/test/scala/com/amazon/deequ/repository/AnalysisResultSerdeTest.scala b/src/test/scala/com/amazon/deequ/repository/AnalysisResultSerdeTest.scala
index 05f4d47b..1000ff8e 100644
--- a/src/test/scala/com/amazon/deequ/repository/AnalysisResultSerdeTest.scala
+++ b/src/test/scala/com/amazon/deequ/repository/AnalysisResultSerdeTest.scala
@@ -76,6 +76,8 @@ class AnalysisResultSerdeTest extends FlatSpec with Matchers {
         DoubleMetric(Entity.Column, "Completeness", "ColumnA", Success(5.0)),
       Sum("ColumnA") ->
         DoubleMetric(Entity.Column, "Completeness", "ColumnA", Success(5.0)),
+      RatioOfSums("ColumnA", "ColumnB") ->
+        DoubleMetric(Entity.Column, "RatioOfSums", "ColumnA", Success(5.0)),
       StandardDeviation("ColumnA") ->
         DoubleMetric(Entity.Column, "Completeness", "ColumnA", Success(5.0)),
       DataType("ColumnA") ->