diff --git a/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala index edd4f8e9..8e2e351b 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala @@ -37,7 +37,7 @@ case class CustomSqlState(stateOrError: Either[Double, String]) extends DoubleVa override def metricValue(): Double = state } -case class CustomSql(expression: String) extends Analyzer[CustomSqlState, DoubleMetric] { +case class CustomSql(expression: String, disambiguator: String = "*") extends Analyzer[CustomSqlState, DoubleMetric] { /** * Compute the state (sufficient statistics) from the data * @@ -76,15 +76,19 @@ case class CustomSql(expression: String) extends Analyzer[CustomSqlState, Double state match { // The returned state may case Some(theState) => theState.stateOrError match { - case Left(value) => DoubleMetric(Entity.Dataset, "CustomSQL", "*", Success(value)) - case Right(error) => DoubleMetric(Entity.Dataset, "CustomSQL", "*", Failure(new RuntimeException(error))) + case Left(value) => DoubleMetric(Entity.Dataset, "CustomSQL", disambiguator, + Success(value)) + case Right(error) => DoubleMetric(Entity.Dataset, "CustomSQL", disambiguator, + Failure(new RuntimeException(error))) } case None => - DoubleMetric(Entity.Dataset, "CustomSQL", "*", Failure(new RuntimeException("CustomSql Failed To Run"))) + DoubleMetric(Entity.Dataset, "CustomSQL", disambiguator, + Failure(new RuntimeException("CustomSql Failed To Run"))) } } override private[deequ] def toFailureMetric(failure: Exception) = { - DoubleMetric(Entity.Dataset, "CustomSQL", "*", Failure(new RuntimeException("CustomSql Failed To Run"))) + DoubleMetric(Entity.Dataset, "CustomSQL", disambiguator, + Failure(new RuntimeException("CustomSql Failed To Run"))) } } diff --git a/src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala b/src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala index 7e6e96c3..e6e23c40 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala @@ -5,7 +5,7 @@ * use this file except in compliance with the License. A copy of the License * is located at * - * http://aws.amazon.com/apache2.0/ + * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either @@ -17,6 +17,7 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.SparkContextSpec import com.amazon.deequ.metrics.DoubleMetric +import com.amazon.deequ.metrics.Entity import com.amazon.deequ.utils.FixtureSupport import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec @@ -84,5 +85,21 @@ class CustomSqlTest extends AnyWordSpec with Matchers with SparkContextSpec with case Failure(exception) => exception.getMessage should include("foo") } } + + "apply metric disambiguation string to returned metric" in withSparkSession { session => + val data = getDfWithStringColumns(session) + data.createOrReplaceTempView("primary") + + val disambiguator = "statement1" + val sql = CustomSql("SELECT COUNT(*) FROM primary WHERE `Address Line 2` IS NOT NULL", disambiguator) + val state = sql.computeStateFrom(data) + val metric: DoubleMetric = sql.computeMetricFrom(state) + + metric.value.isSuccess shouldBe true + metric.value.get shouldBe 6.0 + metric.name shouldBe "CustomSQL" + metric.entity shouldBe Entity.Dataset + metric.instance shouldBe "statement1" + } } }