Skip to content

Commit

Permalink
Implemented JsonDiscriminator feature which allowes to derive JsonReader
Browse files Browse the repository at this point in the history
  • Loading branch information
Георгий Ковалев committed Apr 25, 2024
1 parent d7b1494 commit 8caac38
Show file tree
Hide file tree
Showing 6 changed files with 219 additions and 14 deletions.
Binary file added modules/core/.DS_Store
Binary file not shown.
9 changes: 9 additions & 0 deletions modules/core/src/main/scala-3/tethys/JsonDiscriminator.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package tethys


trait JsonDiscriminator[A, B]:
def choose: A => B

object JsonDiscriminator:
def by[A, B](f: A => B): JsonDiscriminator[A, B] = new JsonDiscriminator[A, B]:
override def choose: A => B = f
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package tethys.derivation

private[tethys]
object Discriminator:

inline def getLabel[Type, Discriminator]: String =
${ DiscriminatorMacro.getLabel[Type, Discriminator] }

inline def getValue[Type, SubType, Discriminator](label: String): Any =
${ DiscriminatorMacro.getValue[Type, SubType, Discriminator]('{ label }) }

private[derivation]
object DiscriminatorMacro:
import scala.quoted.*

def getLabel[T: Type, D: Type](using quotes: Quotes): Expr[String] =
import quotes.reflect.*
val tpe = TypeRepr.of[T]
val selectorTpe = TypeRepr.of[D]
val symbol = tpe.typeSymbol.fieldMembers
.find(tpe.memberType(_) =:= selectorTpe)
.getOrElse(report.errorAndAbort(s"Selector of type ${selectorTpe.show(using Printer.TypeReprShortCode)} not found in ${tpe.show(using Printer.TypeReprShortCode)}"))

tpe.typeSymbol.children
.find(child => child.caseFields.contains(symbol.overridingSymbol(child)))
.foreach { child =>
report.errorAndAbort(s"Overriding discriminator field '${symbol.name}' in ${child.typeRef.show(using Printer.TypeReprShortCode)} is prohibited")
}

Expr(symbol.name)


def getValue[T: Type, ST: Type, D: Type](label: Expr[String])(using quotes: Quotes): Expr[Any] =
import quotes.reflect.*
val tpe = TypeRepr.of[T]
val selectorTpe = TypeRepr.of[D]
val symbol = tpe.typeSymbol.fieldMembers
.find(tpe.memberType(_) =:= selectorTpe)
.getOrElse(report.errorAndAbort(s"Selector of type ${selectorTpe.show(using Printer.TypeReprShortCode)} not found in ${tpe.show(using Printer.TypeReprShortCode)}"))

Select(stub[ST].asTerm, symbol).asExprOf[Any]


def stub[T: Type](using quotes: Quotes): Expr[T] =
import quotes.reflect.*
val tpe = TypeRepr.of[T]
val symbol = tpe.typeSymbol
val constructorFieldsFilledWithNulls: List[List[Term]] =
symbol.primaryConstructor.paramSymss
.filterNot(_.exists(_.isType))
.map(_.map(_.typeRef.widen match {
case t@AppliedType(inner, applied) =>
Select.unique('{ null }.asTerm, "asInstanceOf").appliedToTypes(List(inner.appliedTo(tpe.typeArgs)))
case other =>
Select.unique('{ null }.asTerm, "asInstanceOf").appliedToTypes(List(other))
}))

New(TypeTree.ref(symbol)).select(symbol.primaryConstructor)
.appliedToTypes(symbol.typeRef.typeArgs.map(_ => TypeRepr.of[Null]))
.appliedToArgss(constructorFieldsFilledWithNulls)
.asExprOf[T]
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,27 @@ package tethys.derivation

import tethys.readers.{FieldName, ReaderError}
import tethys.readers.tokens.TokenIterator
import tethys.JsonReader
import tethys.{JsonDiscriminator, JsonReader}

import scala.deriving.Mirror
import scala.compiletime.{erasedValue, summonInline, constValue, constValueTuple, summonFrom}
import scala.compiletime.{constValue, constValueTuple, erasedValue, summonFrom, summonInline}


private [tethys] trait JsonReaderDerivation:
inline def derived[A](using mirror: Mirror.ProductOf[A]): JsonReader[A] =
inline def derived[A](using mirror: Mirror.Of[A]): JsonReader[A] =
inline mirror match
case mirror: Mirror.ProductOf[A] => deriveProductJsonReader[A](using mirror)
case mirror: Mirror.SumOf[A] => deriveSumJsonReader[A](using mirror)

private inline def deriveProductJsonReader[A](using mirror: Mirror.ProductOf[A]) =
new JsonReader[A]:
override def read(it: TokenIterator)(implicit fieldName: FieldName) =
if !it.currentToken().isObjectStart then
ReaderError.wrongJson("Expected object start but found: " + it.currentToken().toString)
else
it.nextToken()
val labels = constValueTuple[mirror.MirroredElemLabels].toArray.collect { case s: String => s }
val readersByLabels = labels.zip(summonJsonReaders[A, mirror.MirroredElemTypes].zipWithIndex).toMap
val readersByLabels = labels.zip(summonJsonReadersForProduct[A, mirror.MirroredElemTypes].zipWithIndex).toMap
val defaults = getOptionsByIndex[mirror.MirroredElemTypes]().toMap ++ Defaults.collectFrom[A]
val optionalLabels = defaults.keys.map(labels(_))

Expand All @@ -33,7 +38,7 @@ private [tethys] trait JsonReaderDerivation:
collectedValues += idx -> value
missingFields -= jsonName
}

it.nextToken()

if (missingFields.nonEmpty)
Expand All @@ -48,28 +53,70 @@ private [tethys] trait JsonReaderDerivation:
case that: Product if that.productArity == productArity => true
case _ => false

private inline def summonJsonReaders[T, Elems <: Tuple]: List[JsonReader[?]] =

private inline def deriveSumJsonReader[A](using mirror: Mirror.SumOf[A]): JsonReader[A] =
summonFrom[JsonDiscriminator[A, _]] {
case discriminator: JsonDiscriminator[A, discriminator] =>
val label = Discriminator.getLabel[A, discriminator]
val readersByDiscriminator = summonDiscriminators[A, discriminator, mirror.MirroredElemTypes](label)
.zip(summonJsonReadersForSum[A, mirror.MirroredElemTypes])
.toMap

JsonReader.builder
.addField[discriminator](label, summonInline[JsonReader[discriminator]])
.selectReader[A] { discriminator =>
readersByDiscriminator.getOrElse(
discriminator,
ReaderError.wrongJson(s"Unexpected discriminator found: $discriminator")(using FieldName(label))
).asInstanceOf[JsonReader[A]]
}
case _ =>
scala.compiletime.error("JsonDiscriminator is required to derive JsonReader for sum type")
}

private inline def summonDiscriminators[T, Discriminator, Elems <: Tuple](label: String, idx: Int = 0): List[Any] =
inline erasedValue[Elems] match
case _: EmptyTuple =>
Nil
case _: (elem *: elems) =>
Discriminator.getValue[T, elem, Discriminator](label) :: summonDiscriminators[T, Discriminator, elems](label, idx + 1)


private inline def summonJsonReadersForProduct[T, Elems <: Tuple]: List[JsonReader[?]] =
inline erasedValue[Elems] match
case _: EmptyTuple =>
Nil
case _: (elem *: elems) =>
deriveOrSummon[T, elem] :: summonJsonReaders[T, elems]
summonOrDeriveJsonReaderForProduct[T, elem] :: summonJsonReadersForProduct[T, elems]


private inline def deriveOrSummon[T, Elem]: JsonReader[Elem] =
private inline def summonJsonReadersForSum[T, Elems <: Tuple]: List[JsonReader[?]] =
inline erasedValue[Elems] match
case _: EmptyTuple =>
Nil
case _: (elem *: elems) =>
summonOrDeriveJsonReaderForSum[T, elem] :: summonJsonReadersForSum[T, elems]

private inline def summonOrDeriveJsonReaderForProduct[T, Elem]: JsonReader[Elem] =
inline erasedValue[Elem] match
case _: T =>
deriveRec[T, Elem]
case _ =>
summonInline[JsonReader[Elem]]

private inline def summonOrDeriveJsonReaderForSum[T, Elem]: JsonReader[Elem] =
summonFrom[JsonReader[Elem]] {
case reader: JsonReader[Elem] => reader
case _ => deriveRec[T, Elem]
}

private inline def deriveRec[T, Elem]: JsonReader[Elem] =
inline erasedValue[T] match
case _: Elem =>
scala.compiletime.error("Recursive derivation is not possible")
case _ =>
JsonReader.derived[Elem](using summonInline[Mirror.ProductOf[Elem]])


private inline def getOptionsByIndex[Elems <: Tuple](idx: Int = 0): List[(Int, None.type)] =
inline erasedValue[Elems] match
case _: EmptyTuple =>
Expand All @@ -78,5 +125,3 @@ private [tethys] trait JsonReaderDerivation:
idx -> None :: getOptionsByIndex[elems](idx + 1)
case _: (_ *: elems) =>
getOptionsByIndex[elems](idx + 1)


Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package tethys.derivation

import tethys.{JsonObjectWriter, JsonWriter}
import tethys.{JsonDiscriminator, JsonObjectWriter, JsonWriter}
import tethys.writers.tokens.TokenWriter

import scala.deriving.Mirror
import scala.compiletime.{summonInline, erasedValue, summonFrom}
import scala.compiletime.{erasedValue, summonFrom, summonInline, constValueTuple}

private[tethys] trait JsonWriterDerivation:
inline def derived[A](using mirror: Mirror.Of[A]): JsonObjectWriter[A] =
Expand All @@ -21,9 +21,24 @@ private[tethys] trait JsonWriterDerivation:
}

case m: Mirror.SumOf[A] =>
writeDiscriminatorIfProvided[A, m.MirroredElemTypes, m.MirroredElemLabels](value, tokenWriter)

summonJsonWritersForSum[A, m.MirroredElemTypes](m.ordinal(value))
.writeValues(value.asInstanceOf, tokenWriter)

private inline def writeDiscriminatorIfProvided[T, Elems <: Tuple, Labels <: Tuple](value: T, tokenWriter: TokenWriter): Unit =
summonFrom[JsonDiscriminator[T, _]] {
case discriminator: JsonDiscriminator[T, t] =>
summonInline[JsonWriter[t]]
.write(
name = Discriminator.getLabel[T, t],
value = discriminator.choose(value).asInstanceOf[t],
tokenWriter = tokenWriter
)

case _ =>
}

private inline def summonJsonWritersForSum[T, Elems <: Tuple]: List[JsonObjectWriter[?]] =
inline erasedValue[Elems] match
case _: EmptyTuple =>
Expand Down
77 changes: 76 additions & 1 deletion modules/core/src/test/scala-3/tethys/DerivationSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import tethys.commons.TokenNode.obj
import tethys.commons.{Token, TokenNode}
import tethys.readers.tokens.QueueIterator
import tethys.writers.tokens.SimpleTokenWriter.SimpleTokenWriterOps
import tethys.derivation.Defaults

class DerivationSpec extends AnyFlatSpec with Matchers {
def read[A: JsonReader](nodes: List[TokenNode]): A = {
Expand Down Expand Up @@ -128,4 +127,80 @@ class DerivationSpec extends AnyFlatSpec with Matchers {
read[WithArg[Int]](obj("x" -> 5)) shouldBe WithArg[Int](5)
read[WithArg[String]](obj("x" -> 5, "y" -> "lool")) shouldBe WithArg[String](5, Some("lool"))
}

it should "write/read sum types with provided json discriminator" in {
enum Disc derives StringEnumJsonWriter, StringEnumJsonReader:
case A, B

sealed trait Choose(val discriminator: Disc) derives JsonWriter, JsonReader

object Choose:
given JsonDiscriminator[Choose, Disc] = JsonDiscriminator.by(_.discriminator)

case class AA() extends Choose(Disc.A)
case class BB() extends Choose(Disc.B)

(Choose.AA(): Choose).asTokenList shouldBe obj("discriminator" -> "A")
(Choose.BB(): Choose).asTokenList shouldBe obj("discriminator" -> "B")

read[Choose](obj("discriminator" -> "A")) shouldBe Choose.AA()
read[Choose](obj("discriminator" -> "B")) shouldBe Choose.BB()
}

it should "write/read sum types with provided json discriminator of simple type" in {
sealed trait Choose(val discriminator: Int) derives JsonWriter, JsonReader

object Choose:
given JsonDiscriminator[Choose, Int] = JsonDiscriminator.by(_.discriminator)

case class AA() extends Choose(0)

case class BB() extends Choose(1)

(Choose.AA(): Choose).asTokenList shouldBe obj("discriminator" -> 0)
(Choose.BB(): Choose).asTokenList shouldBe obj("discriminator" -> 1)

read[Choose](obj("discriminator" -> 0)) shouldBe Choose.AA()
read[Choose](obj("discriminator" -> 1)) shouldBe Choose.BB()
}

it should "write/read json for generic discriminators" in {
enum Disc1 derives StringEnumJsonWriter, StringEnumJsonReader:
case A, B

enum Disc2 derives StringEnumJsonWriter, StringEnumJsonReader:
case AA, BB

sealed trait Choose[A](val discriminator: A) derives JsonWriter, JsonReader

object Choose:
given [A]: JsonDiscriminator[Choose[A], A] = JsonDiscriminator.by(_.discriminator)

case class ChooseA() extends Choose[Disc1](Disc1.A)
case class ChooseB() extends Choose[Disc2](Disc2.BB)

(ChooseA(): Choose[Disc1]).asTokenList shouldBe obj("discriminator" -> "A")
(ChooseB(): Choose[Disc2]).asTokenList shouldBe obj("discriminator" -> "BB")

read[Choose[Disc1]](obj("discriminator" -> "A")) shouldBe ChooseA()
read[Choose[Disc2]](obj("discriminator" -> "BB")) shouldBe ChooseB()
}

it should "not compile derivation when discriminator override found" in {

"""
|
| sealed trait Foo(val x: Int) derives JsonReader, JsonWriter
|
| object Foo:
| given JsonDiscriminator[Foo, Int] = JsonDiscriminator.by(_.x)
|
| case class Bar(override val x: Int) extends Foo(x)
|
| case class Baz() extends Foo(0)
|
|""" shouldNot compile


}
}

0 comments on commit 8caac38

Please sign in to comment.