Skip to content

Commit

Permalink
Feature: Add Row Level Result Treatment Options for Uniqueness and Co…
Browse files Browse the repository at this point in the history
…mpleteness (#532)

* Modified Completeness analyzer to label filtered rows as null for row-level results

* Modified GroupingAnalyzers and Uniqueness analyzer to label filtered rows as null for row-level results

* Adjustments for modifying the calculate method to take in a filterCondition

* Add RowLevelFilterTreatement trait and object to determine how filtered rows will be labeled (default True)

* Modify VerificationRunBuilder to have RowLevelFilterTreatment as variable instead of extending, create RowLevelAnalyzer trait

* Do row-level filtering in AnalyzerOptions rather than with RowLevelFilterTreatment trait

* Modify computeStateFrom to take in optional filterCondition
  • Loading branch information
eycho-am authored Feb 15, 2024
1 parent 4a22e5b commit 5b818be
Show file tree
Hide file tree
Showing 25 changed files with 528 additions and 78 deletions.
3 changes: 1 addition & 2 deletions src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import com.amazon.deequ.repository._
import org.apache.spark.sql.{DataFrame, SparkSession}

/** A class to build a VerificationRun using a fluent API */
class VerificationRunBuilder(val data: DataFrame) {
class VerificationRunBuilder(val data: DataFrame) {

protected var requiredAnalyzers: Seq[Analyzer[_, Metric[_]]] = Seq.empty

Expand Down Expand Up @@ -159,7 +159,6 @@ class VerificationRunBuilder(val data: DataFrame) {
new VerificationRunBuilderWithSparkSession(this, Option(sparkSession))
}


def run(): VerificationResult = {
VerificationSuite().doVerificationRun(
data,
Expand Down
31 changes: 25 additions & 6 deletions src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Analyzers._
import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
import com.amazon.deequ.analyzers.NullBehavior.NullBehavior
import com.amazon.deequ.analyzers.runners._
import com.amazon.deequ.metrics.DoubleMetric
Expand Down Expand Up @@ -69,7 +70,7 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable {
* @param data data frame
* @return
*/
def computeStateFrom(data: DataFrame): Option[S]
def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None): Option[S]

/**
* Compute the metric from the state (sufficient statistics)
Expand Down Expand Up @@ -97,13 +98,14 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable {
def calculate(
data: DataFrame,
aggregateWith: Option[StateLoader] = None,
saveStatesWith: Option[StatePersister] = None)
saveStatesWith: Option[StatePersister] = None,
filterCondition: Option[String] = None)
: M = {

try {
preconditions.foreach { condition => condition(data.schema) }

val state = computeStateFrom(data)
val state = computeStateFrom(data, filterCondition)

calculateMetric(state, aggregateWith, saveStatesWith)
} catch {
Expand Down Expand Up @@ -170,7 +172,6 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable {
private[deequ] def copyStateTo(source: StateLoader, target: StatePersister): Unit = {
source.load[S](this).foreach { state => target.persist(this, state) }
}

}

/** An analyzer that runs a set of aggregation functions over the data,
Expand All @@ -184,7 +185,7 @@ trait ScanShareableAnalyzer[S <: State[_], +M <: Metric[_]] extends Analyzer[S,
private[deequ] def fromAggregationResult(result: Row, offset: Int): Option[S]

/** Runs aggregation functions directly, without scan sharing */
override def computeStateFrom(data: DataFrame): Option[S] = {
override def computeStateFrom(data: DataFrame, where: Option[String] = None): Option[S] = {
val aggregations = aggregationFunctions()
val result = data.agg(aggregations.head, aggregations.tail: _*).collect().head
fromAggregationResult(result, 0)
Expand Down Expand Up @@ -255,12 +256,18 @@ case class NumMatchesAndCount(numMatches: Long, count: Long, override val fullCo
}
}

case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore)
case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore,
filteredRow: FilteredRow = FilteredRow.TRUE)
object NullBehavior extends Enumeration {
type NullBehavior = Value
val Ignore, EmptyString, Fail = Value
}

object FilteredRow extends Enumeration {
type FilteredRow = Value
val NULL, TRUE = Value
}

/** Base class for analyzers that compute ratios of matching predicates */
abstract class PredicateMatchingAnalyzer(
name: String,
Expand Down Expand Up @@ -490,6 +497,18 @@ private[deequ] object Analyzers {
conditionalSelectionFromColumns(selection, conditionColumn)
}

def conditionalSelectionFilteredFromColumns(
selection: Column,
conditionColumn: Option[Column],
filterTreatment: String)
: Column = {
conditionColumn
.map { condition => {
when(not(condition), expr(filterTreatment)).when(condition, selection)
} }
.getOrElse(selection)
}

private[this] def conditionalSelectionFromColumns(
selection: Column,
conditionColumn: Option[Column])
Expand Down
22 changes: 18 additions & 4 deletions src/main/scala/com/amazon/deequ/analyzers/Completeness.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,21 @@ import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNotNested}
import org.apache.spark.sql.functions.sum
import org.apache.spark.sql.types.{IntegerType, StructType}
import Analyzers._
import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.{Column, Row}

/** Completeness is the fraction of non-null values in a column of a DataFrame. */
case class Completeness(column: String, where: Option[String] = None) extends
case class Completeness(column: String, where: Option[String] = None,
analyzerOptions: Option[AnalyzerOptions] = None) extends
StandardScanShareableAnalyzer[NumMatchesAndCount]("Completeness", column) with
FilterableAnalyzer {

override def fromAggregationResult(result: Row, offset: Int): Option[NumMatchesAndCount] = {

ifNoNullsIn(result, offset, howMany = 2) { _ =>
NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(criterion))
NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(rowLevelResults))
}
}

Expand All @@ -51,4 +53,16 @@ case class Completeness(column: String, where: Option[String] = None) extends

@VisibleForTesting // required by some tests that compare analyzer results to an expected state
private[deequ] def criterion: Column = conditionalSelection(column, where).isNotNull

@VisibleForTesting
private[deequ] def rowLevelResults: Column = {
val whereCondition = where.map { expression => expr(expression)}
conditionalSelectionFilteredFromColumns(col(column).isNotNull, whereCondition, getRowLevelFilterTreatment.toString)
}

private def getRowLevelFilterTreatment: FilteredRow = {
analyzerOptions
.map { options => options.filteredRow }
.getOrElse(FilteredRow.TRUE)
}
}
2 changes: 1 addition & 1 deletion src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ case class CustomSql(expression: String) extends Analyzer[CustomSqlState, Double
* @param data data frame
* @return
*/
override def computeStateFrom(data: DataFrame): Option[CustomSqlState] = {
override def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None): Option[CustomSqlState] = {

Try {
data.sqlContext.sql(expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ case class DatasetMatchAnalyzer(dfToCompare: DataFrame,
matchColumnMappings: Option[Map[String, String]] = None)
extends Analyzer[DatasetMatchState, DoubleMetric] {

override def computeStateFrom(data: DataFrame): Option[DatasetMatchState] = {
override def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None): Option[DatasetMatchState] = {

val result = if (matchColumnMappings.isDefined) {
DataSynchronization.columnMatch(data, dfToCompare, columnMappings, matchColumnMappings.get, assertion)
Expand Down
16 changes: 13 additions & 3 deletions src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,17 @@ import org.apache.spark.sql.functions.count
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.functions.when

/** Base class for all analyzers that operate the frequencies of groups in the data */
abstract class FrequencyBasedAnalyzer(columnsToGroupOn: Seq[String])
extends GroupingAnalyzer[FrequenciesAndNumRows, DoubleMetric] {

override def groupingColumns(): Seq[String] = { columnsToGroupOn }

override def computeStateFrom(data: DataFrame): Option[FrequenciesAndNumRows] = {
Some(FrequencyBasedAnalyzer.computeFrequencies(data, groupingColumns()))
override def computeStateFrom(data: DataFrame,
filterCondition: Option[String] = None): Option[FrequenciesAndNumRows] = {
Some(FrequencyBasedAnalyzer.computeFrequencies(data, groupingColumns(), filterCondition))
}

/** We need at least one grouping column, and all specified columns must exist */
Expand Down Expand Up @@ -88,7 +90,15 @@ object FrequencyBasedAnalyzer {
.count()

// Set rows with value count 1 to true, and otherwise false
val fullColumn: Column = count(UNIQUENESS_ID).over(Window.partitionBy(columnsToGroupBy: _*))
val fullColumn: Column = {
val window = Window.partitionBy(columnsToGroupBy: _*)
where.map {
condition => {
count(when(expr(condition), UNIQUENESS_ID)).over(window)
}
}.getOrElse(count(UNIQUENESS_ID).over(window))
}

FrequenciesAndNumRows(frequencies, numRows, Option(fullColumn))
}

Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/com/amazon/deequ/analyzers/Histogram.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ case class Histogram(
}
}

override def computeStateFrom(data: DataFrame): Option[FrequenciesAndNumRows] = {
override def computeStateFrom(data: DataFrame,
filterCondition: Option[String] = None): Option[FrequenciesAndNumRows] = {

// TODO figure out a way to pass this in if its known before hand
val totalCount = if (computeFrequenciesAsRatio) {
Expand Down
26 changes: 23 additions & 3 deletions src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@
package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Analyzers.COUNT_COL
import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
import com.amazon.deequ.metrics.DoubleMetric
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.not
import org.apache.spark.sql.functions.when
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.functions.{col, count, lit, sum}
import org.apache.spark.sql.types.DoubleType

case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None)
case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None,
analyzerOptions: Option[AnalyzerOptions] = None)
extends ScanShareableFrequencyBasedAnalyzer("UniqueValueRatio", columns)
with FilterableAnalyzer {

Expand All @@ -34,11 +38,27 @@ case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None)
override def fromAggregationResult(result: Row, offset: Int, fullColumn: Option[Column] = None): DoubleMetric = {
val numUniqueValues = result.getDouble(offset)
val numDistinctValues = result.getLong(offset + 1).toDouble
val fullColumnUniqueness = when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false)
toSuccessMetric(numUniqueValues / numDistinctValues, Option(fullColumnUniqueness))
val conditionColumn = where.map { expression => expr(expression) }
val fullColumnUniqueness = fullColumn.map {
rowLevelColumn => {
conditionColumn.map {
condition => {
when(not(condition), expr(getRowLevelFilterTreatment.toString))
.when(rowLevelColumn.equalTo(1), true).otherwise(false)
}
}.getOrElse(when(rowLevelColumn.equalTo(1), true).otherwise(false))
}
}
toSuccessMetric(numUniqueValues / numDistinctValues, fullColumnUniqueness)
}

override def filterCondition: Option[String] = where

private def getRowLevelFilterTreatment: FilteredRow = {
analyzerOptions
.map { options => options.filteredRow }
.getOrElse(FilteredRow.TRUE)
}
}

object UniqueValueRatio {
Expand Down
29 changes: 25 additions & 4 deletions src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,52 @@
package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Analyzers.COUNT_COL
import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
import com.amazon.deequ.metrics.DoubleMetric
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.Column
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.when
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.not
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.functions.sum
import org.apache.spark.sql.types.DoubleType

/** Uniqueness is the fraction of unique values of a column(s), i.e.,
* values that occur exactly once. */
case class Uniqueness(columns: Seq[String], where: Option[String] = None)
case class Uniqueness(columns: Seq[String], where: Option[String] = None,
analyzerOptions: Option[AnalyzerOptions] = None)
extends ScanShareableFrequencyBasedAnalyzer("Uniqueness", columns)
with FilterableAnalyzer {

override def aggregationFunctions(numRows: Long): Seq[Column] = {
(sum(col(COUNT_COL).equalTo(lit(1)).cast(DoubleType)) / numRows) :: Nil
(sum(col(COUNT_COL).equalTo(lit(1)).cast(DoubleType)) / numRows) :: Nil
}

override def fromAggregationResult(result: Row, offset: Int, fullColumn: Option[Column]): DoubleMetric = {
val fullColumnUniqueness = when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false)
super.fromAggregationResult(result, offset, Option(fullColumnUniqueness))
val conditionColumn = where.map { expression => expr(expression) }
val fullColumnUniqueness = fullColumn.map {
rowLevelColumn => {
conditionColumn.map {
condition => {
when(not(condition), expr(getRowLevelFilterTreatment.toString))
.when(rowLevelColumn.equalTo(1), true).otherwise(false)
}
}.getOrElse(when(rowLevelColumn.equalTo(1), true).otherwise(false))
}
}
super.fromAggregationResult(result, offset, fullColumnUniqueness)
}

override def filterCondition: Option[String] = where

private def getRowLevelFilterTreatment: FilteredRow = {
analyzerOptions
.map { options => options.filteredRow }
.getOrElse(FilteredRow.TRUE)
}
}

object Uniqueness {
Expand Down
Loading

0 comments on commit 5b818be

Please sign in to comment.