diff --git a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
index dd5fb07e..bc05adb5 100644
--- a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
+++ b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
@@ -262,8 +262,13 @@ case class NumMatchesAndCount(numMatches: Long, count: Long, override val fullCo
   }
 }
 
+sealed trait RowLevelStatusSource { def name: String }
+case object InScopeData extends RowLevelStatusSource { val name = "InScopeData" }
+case object FilteredData extends RowLevelStatusSource { val name = "FilteredData" }
+
 case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore,
                            filteredRow: FilteredRowOutcome = FilteredRowOutcome.TRUE)
+
 object NullBehavior extends Enumeration {
   type NullBehavior = Value
   val Ignore, EmptyString, Fail = Value
@@ -478,34 +483,34 @@ private[deequ] object Analyzers {
     if (columns.size == 1) Entity.Column else Entity.Multicolumn
   }
 
-  def conditionalSelection(selection: String, where: Option[String]): Column = {
-    conditionalSelection(col(selection), where)
+  def conditionalSelection(selection: String, condition: Option[String]): Column = {
+    conditionalSelection(col(selection), condition)
   }
 
-  def conditionSelectionGivenColumn(selection: Column, where: Option[Column], replaceWith: Double): Column = {
-    where
+  def conditionSelectionGivenColumn(selection: Column, condition: Option[Column], replaceWith: Double): Column = {
+    condition
       .map { condition => when(condition, replaceWith).otherwise(selection) }
       .getOrElse(selection)
   }
 
-  def conditionSelectionGivenColumn(selection: Column, where: Option[Column], replaceWith: String): Column = {
-    where
+  def conditionSelectionGivenColumn(selection: Column, condition: Option[Column], replaceWith: String): Column = {
+    condition
       .map { condition => when(condition, replaceWith).otherwise(selection) }
       .getOrElse(selection)
   }
 
-  def conditionSelectionGivenColumn(selection: Column, where: Option[Column], replaceWith: Boolean): Column = {
-    where
+  def conditionSelectionGivenColumn(selection: Column, condition: Option[Column], replaceWith: Boolean): Column = {
+    condition
      .map { condition => when(condition, replaceWith).otherwise(selection) }
      .getOrElse(selection)
   }
 
-  def conditionalSelection(selection: Column, where: Option[String], replaceWith: Double): Column = {
-    conditionSelectionGivenColumn(selection, where.map(expr), replaceWith)
+  def conditionalSelection(selection: Column, condition: Option[String], replaceWith: Double): Column = {
+    conditionSelectionGivenColumn(selection, condition.map(expr), replaceWith)
   }
 
-  def conditionalSelection(selection: Column, where: Option[String], replaceWith: String): Column = {
-    conditionSelectionGivenColumn(selection, where.map(expr), replaceWith)
+  def conditionalSelection(selection: Column, condition: Option[String], replaceWith: String): Column = {
+    conditionSelectionGivenColumn(selection, condition.map(expr), replaceWith)
   }
 
   def conditionalSelection(selection: Column, condition: Option[String]): Column = {
@@ -513,11 +518,20 @@ private[deequ] object Analyzers {
     conditionalSelectionFromColumns(selection, conditionColumn)
   }
 
-  def conditionalSelectionFilteredFromColumns(
-    selection: Column,
-    conditionColumn: Option[Column],
-    filterTreatment: FilteredRowOutcome)
-  : Column = {
+  def conditionalSelectionWithAugmentedOutcome(selection: Column,
+                                               condition: Option[String],
+                                               replaceWith: Double): Column = {
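+    // Each outcome is wrapped in a two-element array: the first element marks the row as in-scope or filtered,
+    // the second carries the value (the original column value, or the provided replacement for filtered rows).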
+    val origSelection = array(lit(InScopeData.name).as("source"), selection.as("selection"))
+    val filteredSelection = array(lit(FilteredData.name).as("source"), lit(replaceWith).as("selection"))
+
+    condition
+      .map { cond => when(not(expr(cond)), filteredSelection).otherwise(origSelection) }
+      .getOrElse(origSelection)
+  }
+
+  def conditionalSelectionFilteredFromColumns(selection: Column,
+                                              conditionColumn: Option[Column],
+                                              filterTreatment: FilteredRowOutcome): Column = {
     conditionColumn
       .map { condition =>
         when(not(condition), filterTreatment.getExpression).when(condition, selection)
diff --git a/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala b/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala
index c5cc33f9..abeee6d9 100644
--- a/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala
+++ b/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala
@@ -18,13 +18,11 @@ package com.amazon.deequ.analyzers
 
 import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNumeric}
 import org.apache.spark.sql.{Column, Row}
-import org.apache.spark.sql.functions.{col, max}
+import org.apache.spark.sql.functions.{col, element_at, max}
 import org.apache.spark.sql.types.{DoubleType, StructType}
 import Analyzers._
 import com.amazon.deequ.metrics.FullColumn
 import com.google.common.annotations.VisibleForTesting
-import org.apache.spark.sql.functions.expr
-import org.apache.spark.sql.functions.not
 
 case class MaxState(maxValue: Double, override val fullColumn: Option[Column] = None)
   extends DoubleValuedState[MaxState] with FullColumn {
@@ -43,13 +41,12 @@ case class Maximum(column: String, where: Option[String] = None, analyzerOptions
   with FilterableAnalyzer {
 
   override def aggregationFunctions(): Seq[Column] = {
-    max(criterion) :: Nil
+    max(element_at(criterion, 2).cast(DoubleType)) :: Nil
   }
 
   override def fromAggregationResult(result: Row, offset: Int): Option[MaxState] = {
     ifNoNullsIn(result, offset) { _ =>
-      MaxState(result.getDouble(offset), Some(rowLevelResults))
+      MaxState(result.getDouble(offset), Some(criterion))
     }
   }
 
@@ -60,19 +57,5 @@ case class Maximum(column: String, where: Option[String] = None, analyzerOptions
   override def filterCondition: Option[String] = where
 
   @VisibleForTesting
-  private def criterion: Column = conditionalSelection(column, where).cast(DoubleType)
-
-  private[deequ] def rowLevelResults: Column = {
-    val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions)
-    val whereNotCondition = where.map { expression => not(expr(expression)) }
-
-    filteredRowOutcome match {
-      case FilteredRowOutcome.TRUE =>
-        conditionSelectionGivenColumn(col(column), whereNotCondition, replaceWith = Double.MinValue).cast(DoubleType)
-      case _ =>
-        criterion
-    }
-  }
-
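+  // The Maximum metric aggregates element 2 of the augmented outcome; filtered rows carry Double.MinValue there,
+  // so they can never become the maximum.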
+  private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where, Double.MinValue)
 }
-
diff --git a/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala b/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala
index 18640dc1..b17507fc 100644
--- a/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala
+++ b/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala
@@ -18,13 +18,11 @@ package com.amazon.deequ.analyzers
 
 import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNumeric}
 import org.apache.spark.sql.{Column, Row}
-import org.apache.spark.sql.functions.{col, min}
+import org.apache.spark.sql.functions.{col, element_at, min}
 import org.apache.spark.sql.types.{DoubleType, StructType}
 import Analyzers._
 import com.amazon.deequ.metrics.FullColumn
 import com.google.common.annotations.VisibleForTesting
-import org.apache.spark.sql.functions.expr
-import org.apache.spark.sql.functions.not
 
 case class MinState(minValue: Double, override val fullColumn: Option[Column] = None)
   extends DoubleValuedState[MinState] with FullColumn {
@@ -43,12 +41,12 @@ case class Minimum(column: String, where: Option[String] = None, analyzerOptions
   with FilterableAnalyzer {
 
   override def aggregationFunctions(): Seq[Column] = {
-    min(criterion) :: Nil
+    min(element_at(criterion, 2).cast(DoubleType)) :: Nil
   }
 
   override def fromAggregationResult(result: Row, offset: Int): Option[MinState] = {
     ifNoNullsIn(result, offset) { _ =>
-      MinState(result.getDouble(offset), Some(rowLevelResults))
+      MinState(result.getDouble(offset), Some(criterion))
     }
   }
 
@@ -59,19 +57,5 @@ case class Minimum(column: String, where: Option[String] = None, analyzerOptions
   override def filterCondition: Option[String] = where
 
   @VisibleForTesting
-  private def criterion: Column = {
-    conditionalSelection(column, where).cast(DoubleType)
-  }
-
-  private[deequ] def rowLevelResults: Column = {
-    val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions)
-    val whereNotCondition = where.map { expression => not(expr(expression)) }
-
-    filteredRowOutcome match {
-      case FilteredRowOutcome.TRUE =>
-        conditionSelectionGivenColumn(col(column), whereNotCondition, replaceWith = Double.MaxValue).cast(DoubleType)
-      case _ =>
-        criterion
-    }
-  }
+  private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where, Double.MaxValue)
 }
diff --git a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala
index 7e7ea5a3..a28b6f2e 100644
--- a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala
+++ b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala
@@ -629,7 +629,9 @@ object Constraint {
 
     val constraint = AnalysisBasedConstraint[MinState, Double, Double](minimum, assertion, hint = hint)
 
-    val sparkAssertion = org.apache.spark.sql.functions.udf(assertion)
+    val updatedAssertion = getUpdatedRowLevelAssertion(assertion, minimum.analyzerOptions)
+    val sparkAssertion = org.apache.spark.sql.functions.udf(updatedAssertion)
+
     new RowLevelAssertedConstraint(
       constraint,
       s"MinimumConstraint($minimum)",
@@ -663,7 +665,9 @@ object Constraint {
     val constraint = AnalysisBasedConstraint[MaxState, Double, Double](maximum, assertion, hint = hint)
 
-    val sparkAssertion = org.apache.spark.sql.functions.udf(assertion)
+    val updatedAssertion = getUpdatedRowLevelAssertion(assertion, maximum.analyzerOptions)
+    val sparkAssertion = org.apache.spark.sql.functions.udf(updatedAssertion)
+
     new RowLevelAssertedConstraint(
       constraint,
       s"MaximumConstraint($maximum)",
@@ -916,6 +920,59 @@ object Constraint {
       .getOrElse(0.0)
   }
 
+
+  /*
+   * Used by the Min and Max constraints to derive a row-level assertion from the provided metric-level assertion.
+   * Each value in the outcome column is an array of two elements:
+   * - The first element is a string that denotes whether the row belongs to the in-scope data or the filtered data.
+   * - The second element is the actual value of the constraint's target column.
+   * The resulting row-level assertion yields one of three states: true, false or null.
+   * These outcomes can be tuned through the analyzer options.
+   * A null outcome lets the consumer decide how to treat filtered rows and rows that were originally null.
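+   * For example, a row excluded by the where clause arrives here as Seq(FilteredData.name, <replacement value>),
+   * while an in-scope row whose column value is null arrives as Seq(InScopeData.name, null).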
+   */
+  private[this] def getUpdatedRowLevelAssertion(assertion: Double => Boolean,
+                                                analyzerOptions: Option[AnalyzerOptions])
+    : Seq[String] => java.lang.Boolean = {
+    (d: Seq[String]) => {
+      val (scope, value) = (d.head, Option(d.last).map(_.toDouble))
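+      // scope is the source marker (InScopeData or FilteredData); value is the column value,
+      // which is absent when the original value was null.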
+
+      def inScopeRowOutcome(value: Option[Double]): java.lang.Boolean = {
+        if (value.isDefined) {
+          // If value is defined, run it through the assertion.
+          assertion(value.get)
+        } else {
+          // If value is not defined (value is null), apply NullBehavior.
+          analyzerOptions match {
+            case Some(opts) =>
+              opts.nullBehavior match {
+                case NullBehavior.Fail => false
+                case NullBehavior.Ignore | NullBehavior.EmptyString => null
+              }
+            case None => null
+          }
+        }
+      }
+
+      def filteredRowOutcome: java.lang.Boolean = {
+        analyzerOptions match {
+          case Some(opts) =>
+            opts.filteredRow match {
+              case FilteredRowOutcome.TRUE => true
+              case FilteredRowOutcome.NULL => null
+            }
+          // https://github.com/awslabs/deequ/issues/530
+          // Filtered rows should be marked as true by default.
+          // They can be set to null using the FilteredRowOutcome option.
+          case None => true
+        }
+      }
+
+      scope match {
+        case FilteredData.name => filteredRowOutcome
+        case InScopeData.name => inScopeRowOutcome(value)
+      }
+    }
+  }
 }
 
 /**
diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala
index 9da41562..f7684e96 100644
--- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala
+++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala
@@ -41,13 +41,10 @@ import org.scalamock.scalatest.MockFactory
 import org.scalatest.Matchers
 import org.scalatest.WordSpec
 
-
-
 class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec with FixtureSupport
   with MockFactory {
 
   "Verification Suite" should {
-
     "return the correct verification status regardless of the order of checks" in withSparkSession {
       sparkSession =>
 
@@ -1609,6 +1606,189 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
     }
   }
 
+  "Verification Suite with == based Min/Max checks and filtered row behavior" should {
+    val col1 = "att1"
+    val col2 = "att2"
+    val col3 = "att3"
+
+    val check1Description = "equality-check-1"
+    val check2Description = "equality-check-2"
+    val check3Description = "equality-check-3"
+
+    val check1WhereClause = "att1 > 3"
+    val check2WhereClause = "att2 > 4"
+    val check3WhereClause = "att3 = 0"
+
+    def mkEqualityCheck1(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, check1Description)
+      .hasMin(col1, _ == 4, analyzerOptions = Some(analyzerOptions)).where(check1WhereClause)
+      .hasMax(col1, _ == 4, analyzerOptions = Some(analyzerOptions)).where(check1WhereClause)
+
+    def mkEqualityCheck2(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, check2Description)
+      .hasMin(col2, _ == 7, analyzerOptions = Some(analyzerOptions)).where(check2WhereClause)
+      .hasMax(col2, _ == 7, analyzerOptions = Some(analyzerOptions)).where(check2WhereClause)
+
+    def mkEqualityCheck3(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, check3Description)
+      .hasMin(col3, _ == 0, analyzerOptions = Some(analyzerOptions)).where(check3WhereClause)
+      .hasMax(col3, _ == 0, analyzerOptions = Some(analyzerOptions)).where(check3WhereClause)
+
+    def getRowLevelResults(df: DataFrame): Seq[java.lang.Boolean] =
+      df.collect().map { r => r.getAs[java.lang.Boolean](0) }.toSeq
+
+    def assertCheckResults(verificationResult: VerificationResult): Unit = {
+      val passResult = verificationResult.checkResults
+
+      val equalityCheck1Result = passResult.values.find(_.check.description == check1Description)
+      val equalityCheck2Result = passResult.values.find(_.check.description == check2Description)
+      val equalityCheck3Result = passResult.values.find(_.check.description == check3Description)
+
+      assert(equalityCheck1Result.isDefined && equalityCheck1Result.get.status == CheckStatus.Error)
+      assert(equalityCheck2Result.isDefined && equalityCheck2Result.get.status == CheckStatus.Error)
+      assert(equalityCheck3Result.isDefined && equalityCheck3Result.get.status == CheckStatus.Success)
+    }
+
+    def assertRowLevelResults(rowLevelResults: DataFrame,
+                              analyzerOptions: AnalyzerOptions): Unit = {
+      val equalityCheck1Results = getRowLevelResults(rowLevelResults.select(check1Description))
+      val equalityCheck2Results = getRowLevelResults(rowLevelResults.select(check2Description))
+      val equalityCheck3Results = getRowLevelResults(rowLevelResults.select(check3Description))
+
+      val filteredOutcome: java.lang.Boolean = analyzerOptions.filteredRow match {
+        case FilteredRowOutcome.TRUE => true
+        case FilteredRowOutcome.NULL => null
+      }
+
+      assert(equalityCheck1Results == Seq(filteredOutcome, filteredOutcome, filteredOutcome, true, false, false))
+      assert(equalityCheck2Results == Seq(filteredOutcome, filteredOutcome, filteredOutcome, false, false, true))
+      assert(equalityCheck3Results == Seq(true, true, true, filteredOutcome, filteredOutcome, filteredOutcome))
+    }
+
+    def assertMetrics(metricsDF: DataFrame): Unit = {
+      val metricsMap = getMetricsAsMap(metricsDF)
+      assert(metricsMap(s"$col1|Minimum (where: $check1WhereClause)") == 4.0)
+      assert(metricsMap(s"$col1|Maximum (where: $check1WhereClause)") == 6.0)
+      assert(metricsMap(s"$col2|Minimum (where: $check2WhereClause)") == 5.0)
+      assert(metricsMap(s"$col2|Maximum (where: $check2WhereClause)") == 7.0)
+      assert(metricsMap(s"$col3|Minimum (where: $check3WhereClause)") == 0.0)
+      assert(metricsMap(s"$col3|Maximum (where: $check3WhereClause)") == 0.0)
+    }
+
+    "mark filtered rows as null" in withSparkSession {
+      sparkSession =>
+        val df = getDfWithNumericValues(sparkSession)
+        val analyzerOptions = AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)
+
+        val equalityCheck1 = mkEqualityCheck1(analyzerOptions)
+        val equalityCheck2 = mkEqualityCheck2(analyzerOptions)
+        val equalityCheck3 = mkEqualityCheck3(analyzerOptions)
+
+        val verificationResult = VerificationSuite()
+          .onData(df)
+          .addChecks(Seq(equalityCheck1, equalityCheck2, equalityCheck3))
+          .run()
+
+        val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df)
+        val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult)
+
+        assertCheckResults(verificationResult)
+        assertRowLevelResults(rowLevelResultsDF, analyzerOptions)
+        assertMetrics(metricsDF)
+    }
+
+    "mark filtered rows as true" in withSparkSession {
+      sparkSession =>
+        val df = getDfWithNumericValues(sparkSession)
+        val analyzerOptions = AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE)
+
+        val equalityCheck1 = mkEqualityCheck1(analyzerOptions)
+        val equalityCheck2 = mkEqualityCheck2(analyzerOptions)
+        val equalityCheck3 = mkEqualityCheck3(analyzerOptions)
+
+        val verificationResult = VerificationSuite()
+          .onData(df)
+          .addChecks(Seq(equalityCheck1, equalityCheck2, equalityCheck3))
+          .run()
+
+        val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df)
+        val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult)
+
+        assertCheckResults(verificationResult)
+        assertRowLevelResults(rowLevelResultsDF, analyzerOptions)
+        assertMetrics(metricsDF)
+    }
+  }
+
+  "Verification Suite with == based Min/Max checks and null row behavior" should {
+    val col = "attNull"
+    val checkDescription = "equality-check"
+    def mkEqualityCheck(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, checkDescription)
+      .hasMin(col, _ == 6, analyzerOptions = Some(analyzerOptions))
+      .hasMax(col, _ == 6, analyzerOptions = Some(analyzerOptions))
+
+    def assertCheckResults(verificationResult: VerificationResult, checkStatus: CheckStatus.Value): Unit = {
+      val passResult = verificationResult.checkResults
+      val equalityCheckResult = passResult.values.find(_.check.description == checkDescription)
+      assert(equalityCheckResult.isDefined && equalityCheckResult.get.status == checkStatus)
+    }
+
+    def getRowLevelResults(df: DataFrame): Seq[java.lang.Boolean] =
+      df.collect().map { r => r.getAs[java.lang.Boolean](0) }.toSeq
+
+    def assertRowLevelResults(rowLevelResults: DataFrame,
+                              analyzerOptions: AnalyzerOptions): Unit = {
+      val equalityCheckResults = getRowLevelResults(rowLevelResults.select(checkDescription))
+      val nullOutcome: java.lang.Boolean = analyzerOptions.nullBehavior match {
+        case NullBehavior.Fail => false
+        case NullBehavior.Ignore => null
+      }
+
+      assert(equalityCheckResults == Seq(nullOutcome, nullOutcome, nullOutcome, false, true, false))
+    }
+
+    def assertMetrics(metricsDF: DataFrame): Unit = {
+      val metricsMap = getMetricsAsMap(metricsDF)
+      assert(metricsMap(s"$col|Minimum") == 5.0)
+      assert(metricsMap(s"$col|Maximum") == 7.0)
+    }
+
+    "keep non-filtered null rows as null" in withSparkSession {
+      sparkSession =>
+        val df = getDfWithNumericValues(sparkSession)
+        val analyzerOptions = AnalyzerOptions(nullBehavior = NullBehavior.Ignore)
+        val verificationResult = VerificationSuite()
+          .onData(df)
+          .addChecks(Seq(mkEqualityCheck(analyzerOptions)))
+          .run()
+
+        val passResult = verificationResult.checkResults
+        assertCheckResults(verificationResult, CheckStatus.Error)
+
+        val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df)
+        assertRowLevelResults(rowLevelResultsDF, analyzerOptions)
+
+        val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult)
+        assertMetrics(metricsDF)
+    }
+
+    "mark non-filtered null rows as false" in withSparkSession {
+      sparkSession =>
+        val df = getDfWithNumericValues(sparkSession)
+        val analyzerOptions = AnalyzerOptions(nullBehavior = NullBehavior.Fail)
+        val verificationResult = VerificationSuite()
+          .onData(df)
+          .addChecks(Seq(mkEqualityCheck(analyzerOptions)))
+          .run()
+
+        val passResult = verificationResult.checkResults
+        assertCheckResults(verificationResult, CheckStatus.Error)
+
+        val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df)
+        assertRowLevelResults(rowLevelResultsDF, analyzerOptions)
+
+        val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult)
+        assertMetrics(metricsDF)
+    }
+  }
+
   /** Run anomaly detection using a repository with some previous analysis results for testing */
   private[this] def evaluateWithRepositoryWithHistory(test: MetricsRepository => Unit): Unit = {
 
@@ -1633,4 +1813,13 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
   private[this] def assertSameRows(dataframeA: DataFrame, dataframeB: DataFrame): Unit = {
     assert(dataframeA.collect().toSet == dataframeB.collect().toSet)
   }
+
+  private[this] def getMetricsAsMap(metricsDF: DataFrame): Map[String, Double] = {
+    metricsDF.collect().map { r =>
+      val colName = r.getAs[String]("instance")
+      val metricName = r.getAs[String]("name")
+      val metricValue = r.getAs[Double]("value")
+      s"$colName|$metricName" -> metricValue
+    }.toMap
+  }
 }
diff --git a/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala b/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala
index 6ac90f73..1d13a8df 100644
--- a/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala
+++ b/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala
@@ -21,10 +21,20 @@ import com.amazon.deequ.SparkContextSpec
 import com.amazon.deequ.metrics.DoubleMetric
 import com.amazon.deequ.metrics.FullColumn
 import com.amazon.deequ.utils.FixtureSupport
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.functions.element_at
+import org.apache.spark.sql.types.DoubleType
 import org.scalatest.matchers.should.Matchers
 import org.scalatest.wordspec.AnyWordSpec
 
 class MaximumTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport {
+  private val tempColName = "new"
+
+  private def getValuesDF(df: DataFrame, outcomeColumn: Column): Seq[Row] = {
+    df.withColumn(tempColName, element_at(outcomeColumn, 2).cast(DoubleType)).collect()
+  }
 
   "Max" should {
     "return row-level results for columns" in withSparkSession { session =>
@@ -35,8 +45,8 @@ class MaximumTest extends AnyWordSpec with Matchers with SparkContextSpec with F
       val state: Option[MaxState] = att1Maximum.computeStateFrom(data)
       val metric: DoubleMetric with FullColumn = att1Maximum.computeMetricFrom(state)
 
-      data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Double]("new")) shouldBe
-        Seq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
+      val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName))
+      values shouldBe Seq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
     }
 
     "return row-level results for columns with null" in withSparkSession { session =>
@@ -47,40 +57,9 @@ class MaximumTest extends AnyWordSpec with Matchers with SparkContextSpec with F
       val state: Option[MaxState] = att1Maximum.computeStateFrom(data)
       val metric: DoubleMetric with FullColumn = att1Maximum.computeMetricFrom(state)
 
-      data.withColumn("new", metric.fullColumn.get).collect().map(r =>
-        if (r == null) null else r.getAs[Double]("new")) shouldBe
-        Seq(null, null, null, 5.0, 6.0, 7.0)
-    }
-
-    "return row-level results for columns with where clause filtered as true" in withSparkSession { session =>
-
-      val data = getDfWithNumericValues(session)
-
-      val att1Maximum = Maximum("att1", Option("item < 4"))
-      val state: Option[MaxState] = att1Maximum.computeStateFrom(data, Option("item < 4"))
-      val metric: DoubleMetric with FullColumn = att1Maximum.computeMetricFrom(state)
-
-      val result = data.withColumn("new", metric.fullColumn.get)
-      result.show(false)
-      result.collect().map(r =>
-        if (r == null) null else r.getAs[Double]("new")) shouldBe
-        Seq(1.0, 2.0, 3.0, Double.MinValue, Double.MinValue, Double.MinValue)
-    }
-
-    "return row-level results for columns with where clause filtered as null" in withSparkSession { session =>
-
-      val data = getDfWithNumericValues(session)
-
-      val att1Maximum = Maximum("att1", Option("item < 4"),
-        Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)))
-      val state: Option[MaxState] = att1Maximum.computeStateFrom(data, Option("item < 4"))
-      val metric: DoubleMetric with FullColumn = att1Maximum.computeMetricFrom(state)
-
-      val result = data.withColumn("new", metric.fullColumn.get)
-      result.show(false)
-      result.collect().map(r =>
-        if (r == null) null else r.getAs[Double]("new")) shouldBe
-        Seq(1.0, 2.0, 3.0, null, null, null)
+      val values = getValuesDF(data, metric.fullColumn.get)
+        .map(r => if (r == null) null else r.getAs[Double](tempColName))
+      values shouldBe Seq(null, null, null, 5.0, 6.0, 7.0)
     }
   }
 }
diff --git a/src/test/scala/com/amazon/deequ/analyzers/MinimumTest.scala b/src/test/scala/com/amazon/deequ/analyzers/MinimumTest.scala
index 435542e8..8d1d2dd6 100644
--- a/src/test/scala/com/amazon/deequ/analyzers/MinimumTest.scala
+++ b/src/test/scala/com/amazon/deequ/analyzers/MinimumTest.scala
@@ -14,77 +14,49 @@
  *
  */
 
-
 package com.amazon.deequ.analyzers
 
 import com.amazon.deequ.SparkContextSpec
 import com.amazon.deequ.metrics.DoubleMetric
 import com.amazon.deequ.metrics.FullColumn
 import com.amazon.deequ.utils.FixtureSupport
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.functions.element_at
+import org.apache.spark.sql.types.DoubleType
 import org.scalatest.matchers.should.Matchers
 import org.scalatest.wordspec.AnyWordSpec
 
 class MinimumTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport {
+  private val tempColName = "new"
+
+  private def getValuesDF(df: DataFrame, outcomeColumn: Column): Seq[Row] = {
+    df.withColumn(tempColName, element_at(outcomeColumn, 2).cast(DoubleType)).collect()
+  }
 
   "Min" should {
     "return row-level results for columns" in withSparkSession { session =>
-
       val data = getDfWithNumericValues(session)
 
       val att1Minimum = Minimum("att1")
       val state: Option[MinState] = att1Minimum.computeStateFrom(data)
       val metric: DoubleMetric with FullColumn = att1Minimum.computeMetricFrom(state)
-
-      data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Double]("new")) shouldBe
-        Seq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
+      val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName))
+      values shouldBe Seq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
     }
 
     "return row-level results for columns with null" in withSparkSession { session =>
-
      val data = getDfWithNumericValues(session)
 
      val att1Minimum = Minimum("attNull")
      val state: Option[MinState] = att1Minimum.computeStateFrom(data)
      val metric: DoubleMetric with FullColumn = att1Minimum.computeMetricFrom(state)
 
-      data.withColumn("new", metric.fullColumn.get).collect().map(r =>
-        if (r == null) null else r.getAs[Double]("new")) shouldBe
-        Seq(null, null, null, 5.0, 6.0, 7.0)
-    }
-
-    "return row-level results for columns with where clause filtered as true" in withSparkSession { session =>
-
-      val data = getDfWithNumericValues(session)
-
-      val att1Minimum = Minimum("att1", Option("item < 4"))
-      val state: Option[MinState] = att1Minimum.computeStateFrom(data, Option("item < 4"))
-      print(state)
-      val metric: DoubleMetric with FullColumn = att1Minimum.computeMetricFrom(state)
-
-      val result = data.withColumn("new", metric.fullColumn.get)
-      result.show(false)
-      result.collect().map(r =>
-        if (r == null) null else r.getAs[Double]("new")) shouldBe
-        Seq(1.0, 2.0, 3.0, Double.MaxValue, Double.MaxValue, Double.MaxValue)
-    }
-
-    "return row-level results for columns with where clause filtered as null" in withSparkSession { session =>
-
-      val data = getDfWithNumericValues(session)
-
-      val att1Minimum = Minimum("att1", Option("item < 4"),
Minimum("att1", Option("item < 4"), - Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))) - val state: Option[MinState] = att1Minimum.computeStateFrom(data, Option("item < 4")) - print(state) - val metric: DoubleMetric with FullColumn = att1Minimum.computeMetricFrom(state) - - val result = data.withColumn("new", metric.fullColumn.get) - result.show(false) - result.collect().map(r => - if (r == null) null else r.getAs[Double]("new")) shouldBe - Seq(1.0, 2.0, 3.0, null, null, null) + val values = getValuesDF(data, metric.fullColumn.get) + .map(r => if (r == null) null else r.getAs[Double](tempColName)) + values shouldBe Seq(null, null, null, 5.0, 6.0, 7.0) } } - }