使用子类型进行Scala 3集合分区

mfuanj7w  于 2022-11-09  发布在  Scala

在 Scala 3 中，假设我有一个 `List[Try[String]]`。能否将它划分为成功和失败两部分，并让每个列表都具有相应的子类型（即 `List[Success[String]]` 和 `List[Failure[String]]`）？

import scala.util.{Try, Success, Failure}
val tries = List(Success("1"), Failure(Exception("2")))
// Split by constructor. `partition` keeps the static element type, so both halves are
// still typed List[Try[String]] — which is exactly the limitation the question is about.
val (successes, failures) = tries.partition {
  case Success(_) => true
  case Failure(_) => false
}


// Keep only the successful tries. `filter(_.isInstanceOf[Success[String]])` relies on a type
// test whose type argument is erased, and it leaves the result typed List[Try[String]].
// `collect` with a constructor pattern narrows each kept element to Success[String] instead.
val successes = tries.collect { case s @ Success(_) => s }




@LuisMiguelMejíasuárez 好的，这里的诀窍在于 `Try` 有一个 `toEither` 方法，可以借助它拆分出正确的类型（例如 `tries.partitionMap(_.toEither)`）。但如果是一个普通的密封特质（sealed trait）呢？
在 Scala 2 中，我会这样做：

import shapeless.{:+:, ::, CNil, Coproduct, Generic, HList, HNil, Inl, Inr, Poly0}
import shapeless.ops.coproduct.ToHList
import shapeless.ops.hlist.{FillWith, Mapped, Tupler}

/** Type class that inserts a coproduct value `c` into the matching `List` of the HList `l`.
  *
  * NOTE(review): the closing braces of this trait/object appear to have been lost when the
  * snippet was extracted.
  */
trait Loop[C <: Coproduct, L <: HList] {
  def apply(c: C, l: L): L
object Loop {
  /** Inductive case for `H :+: CT` paired with `List[H] :: HT`:
    * `Inl(h)` is prepended to the head list; `Inr(ct)` keeps the head list and recurses
    * into the tail with the instance for `(CT, HT)`.
    */
  implicit def recur[H, CT <: Coproduct, HT <: HList](implicit
    loop: Loop[CT, HT]
  ): Loop[H :+: CT, List[H] :: HT] = {
    case (Inl(h), hs :: ht) => (h :: hs) :: ht
    case (Inr(ct), hs :: ht) => hs :: loop(ct, ht)

  /** Base case: `CNil` with `HNil` returns the HList unchanged. */
  implicit val base: Loop[CNil, HNil] = (_, l) => l

/** Polymorphic constant yielding an empty `List[A]` for any `A`.
  * `partition` passes it to `FillWith` to build the initial HList of empty lists.
  * NOTE(review): the closing brace appears to be missing in this snippet.
  */
object nilPoly extends Poly0 {
  implicit def cse[A]: Case0[List[A]] = at(Nil)

/** Splits `as` into one list per constructor of the sealed type `A`, returned as a tuple.
  *
  * Implicit evidence:
  *  - `generic` maps `A` to its coproduct representation `C`;
  *  - `toHList` and `mapped` turn C's member types into `L1`, an HList of `List`s;
  *  - `fillWith` (with `nilPoly`) builds the starting `L1` made only of empty lists;
  *  - `loop` routes each element into its list;
  *  - `tupler` converts the resulting HList into a tuple (`tupler.Out`).
  *
  * Folding from the right while `loop` prepends keeps the input order within each list.
  *
  * NOTE(review): the snippet looks truncated. It binds `partitionHList`, but the line that
  * returns it (presumably `tupler(partitionHList)`) and the closing braces/parenthesis are
  * missing.
  */
def partition[A, C <: Coproduct, L <: HList, L1 <: HList](as: List[A])(implicit
  generic: Generic.Aux[A, C],
  toHList: ToHList.Aux[C, L],
  mapped: Mapped.Aux[L, List, L1],
  loop: Loop[C, L1],
  fillWith: FillWith[nilPoly.type, L1],
  tupler: Tupler[L1]
): tupler.Out = {
  val partitionHList: L1 = as.foldRight(fillWith())((a, l1) =>
    loop(generic.to(a), l1)


/** Sample ADT used to demonstrate `partition`; each case class becomes one output list.
  * The case classes are `final` so the hierarchy cannot be extended outside this sealed set.
  */
sealed trait A
final case class B(i: Int) extends A
final case class C(i: Int) extends A
final case class D(i: Int) extends A
case class D(i: Int) extends A

// Partition a mixed List[A]; the expected result below has one precisely typed list per case class.
partition(List[A](B(1), B(2), C(1), C(2), D(1), D(2), B(3), C(3))) 
// (List(B(1), B(2), B(3)),List(C(1), C(2), C(3)),List(D(1), D(2))): (List[B], List[C], List[D])


import scala.annotation.tailrec
import scala.deriving.Mirror

/** Scala 3 building blocks for the `partition` example: a shapeless-style `Generic`
  * (case class <-> tuple, sealed trait <-> Coproduct) and a minimal `Coproduct` encoding.
  *
  * NOTE(review): this snippet was extracted from a web page. It has lost its closing braces
  * and parts of some expressions, and does not compile as shown.
  */
object App1 {
  // ============= Generic =====================
  /** Two-way conversion between `T` and its generic representation `Repr`. */
  trait Generic[T] {
    type Repr
    def to(t: T): Repr
    def from(r: Repr): T
  object Generic {
    // Aux pattern: exposes the `Repr` type member as a type parameter so it can be inferred.
    type Aux[T, Repr0] = Generic[T] { type Repr = Repr0 }
    /** Builds a `Generic` from a pair of conversion functions. */
    def instance[T, Repr0](f: T => Repr0, g: Repr0 => T): Aux[T, Repr0] =
      new Generic[T] {
        override type Repr = Repr0
        override def to(t: T): Repr0 = f(t)
        override def from(r: Repr0): T = g(r)

    /** Extension syntax: `a.toRepr` converts to the representation, `repr.to[A]` converts back. */
    object ops {
      extension [A](a: A) {
        def toRepr(using g: Generic[A]): g.Repr = g.to(a)

      extension [Repr](a: Repr) {
        def to[A](using g: Generic.Aux[A, Repr]): A = g.from(a)

    /** Product case: the representation is the Mirror's `MirroredElemTypes` tuple.
      * NOTE(review): most of the arguments passed to `instance(` are missing here; only a
      * `.foldRight[Tuple](EmptyTuple)(_ *: _)` fragment survives.
      */
    given [T <: Product](using
      m: Mirror.ProductOf[T]
    ): Aux[T, m.MirroredElemTypes] = instance(
       .foldRight[Tuple](EmptyTuple)(_ *: _)

    /** Sum case: the representation is the coproduct `C` obtained from `MirroredElemTypes`
      * via `Coproduct.ToCoproduct` (`ev` pins `C` to it). The `to` direction delegates to the
      * `matchExpr` macro and casts the result to `C`.
      * NOTE(review): the call (presumably `instance(`) wrapping these arguments, and the
      * `from` argument, appear to be missing.
      */
    inline given [T, C <: Coproduct](using
      m: Mirror.SumOf[T],
      ev: Coproduct.ToCoproduct[m.MirroredElemTypes] =:= C
    ): Generic.Aux[T, C] =
        matchExpr[T, C](_).asInstanceOf[C],

    import scala.quoted.*

    /** Macro entry point: expands at the call site into a `match` on `ident` built by
      * `matchExprImpl`, producing a `Coproduct` value.
      */
    inline def matchExpr[T, C <: Coproduct](ident: T): Coproduct =
      ${matchExprImpl[T, C]('ident)}

    /** Builds the `match` expression for `matchExpr`.
      *
      * It lists the member types of coproduct `C` and creates one case per member: a type
      * test on that member type, paired with the arguments `(i, ident)`, where `i` is the
      * member's position.
      * NOTE(review): the lines that wrap these pieces into `CaseDef`/`Apply(methodIdent, ...)`
      * and the method's final expression are not visible in this snippet.
      */
    def matchExprImpl[T: Type, C <: Coproduct : Type](
      ident: Expr[T]
    )(using Quotes): Expr[Coproduct] = {
      import quotes.reflect.*

      // Flattens `H1 +: (H2 +: (... +: CNil))` into List(H1, H2, ...). It stops at any type
      // that is not a two-argument applied type, e.g. CNil.
      def unwrapCoproduct(typeRepr: TypeRepr): List[TypeRepr] = typeRepr match {
        case AppliedType(_, List(typ1, typ2)) => typ1 :: unwrapCoproduct(typ2)
        case _  => Nil

      val typeReprs = unwrapCoproduct(TypeRepr.of[C])

      // Reference to Coproduct.unsafeToCoproduct; presumably applied in each generated case.
      val methodIdent =
        Ident(TermRef(TypeRepr.of[Coproduct.type], "unsafeToCoproduct"))

      // One case per coproduct member, indexed by its position `i`.
      def caseDefs(ident: Term): List[CaseDef] =
        typeReprs.zipWithIndex.map { (typeRepr, i) =>
            Typed(ident, Inferred(typeRepr) /*TypeIdent(typeRepr.typeSymbol)*/),
                List(Literal(IntConstant(i)), ident)

      def matchTerm(ident: Term): Term = Match(ident, caseDefs(ident))


  // ============= Coproduct =====================
  /** Minimal coproduct encoding. A value of `H +: T` is either `Inl(head)` (the value is an
    * `H`) or `Inr(tail)` (the value lies further along `T`). `CNil` terminates the chain and
    * has no subclasses in this snippet.
    */
  sealed trait Coproduct extends Product with Serializable
  sealed trait +:[+H, +T <: Coproduct] extends Coproduct
  final case class Inl[+H, +T <: Coproduct](head: H) extends (H +: T)
  final case class Inr[+H, +T <: Coproduct](tail: T) extends (H +: T)
  sealed trait CNil extends Coproduct

  /** Helpers and type-level conversions for `Coproduct`.
    * NOTE(review): the closing braces of this object and its members appear to be missing.
    */
  object Coproduct {
    /** Wraps `value` in `length` layers of `Inr` around an `Inl`, i.e. places it at zero-based
      * position `length`. Untyped: the result is only a `Coproduct`, so callers cast it.
      */
    def unsafeToCoproduct(length: Int, value: Any): Coproduct =
      (0 until length).foldLeft[Coproduct](Inl(value))((c, _) => Inr(c))

    /** Strips `Inr` layers and returns the payload of the first `Inl`.
      * NOTE(review): the recursive call is in tail position; `@tailrec` (already imported
      * above) could be added.
      */
    def unsafeFromCoproduct(c: Coproduct): Any = c match {
      case Inl(h) => h
      case Inr(c) => unsafeFromCoproduct(c)
      case _: CNil => sys.error("impossible")

    /** Match type turning a tuple of types (e.g. `MirroredElemTypes`) into the nested
      * coproduct `H1 +: H2 +: ... +: CNil`.
      */
    type ToCoproduct[T <: Tuple] <: Coproduct = T match {
      case EmptyTuple => CNil
      case h *: t => h +: ToCoproduct[t]

    // A match-type alternative for ToTuple, left commented out in favour of the type class below.
//    type ToTuple[C <: Coproduct] <: Tuple = C match {
//      case CNil => EmptyTuple
//      case h +: t => h *: ToTuple[t]
//    }

    /** Type class (Aux pattern) computing `Out`, the tuple of a coproduct's member types. */
    trait ToTuple[C <: Coproduct] {
      type Out <: Tuple
    object ToTuple {
      type Aux[C <: Coproduct, Out0 <: Tuple] = ToTuple[C] { type Out = Out0 }
      def instance[C <: Coproduct, Out0 <: Tuple]: Aux[C, Out0] =
        new ToTuple[C] { override type Out = Out0 }

      // Inductive case: H +: T maps to H *: (tuple computed for T).
      given [H, T <: Coproduct](using 
        toTuple: ToTuple[T]
      ): Aux[H +: T, H *: toTuple.Out] = instance
      // Base case: CNil maps to EmptyTuple.
      given Aux[CNil, EmptyTuple] = instance

// different file
import App1.{+:, CNil, Coproduct, Generic, Inl, Inr}

/** Scala 3 `partition` built on the `App1` Generic/Coproduct machinery.
  * NOTE(review): closing braces appear to have been lost throughout this extracted snippet.
  */
object App2 {    
  /** Type class that inserts a coproduct value `c` into the matching `List` of the tuple `l`. */
  trait Loop[C <: Coproduct, L <: Tuple] {
    def apply(c: C, l: L): L
  object Loop {
    /** Inductive case for `H +: CT` paired with `List[H] *: HT`: `Inl(h)` is prepended to the
      * head list; `Inr(ct)` keeps the head list and recurses into the tail.
      */
    given [H, CT <: Coproduct, HT <: Tuple](using 
      loop: Loop[CT, HT]
    ): Loop[H +: CT, List[H] *: HT] = {
      case (Inl(h), hs *: ht) => (h :: hs) *: ht
      case (Inr(ct), hs *: ht) => hs *: loop(ct, ht)

    // Base case: CNil with EmptyTuple returns the tuple unchanged.
    given Loop[CNil, EmptyTuple] = (_, l) => l

  /** Type class producing a tuple of type `L` made only of empty lists
    * (the role played by shapeless `FillWith` + `nilPoly` in the Scala 2 version).
    */
  trait FillWithNil[L <: Tuple] {
    def apply(): L
  object FillWithNil {
    // Inductive case: List[H] *: T starts with Nil followed by the fill for T.
    given [H, T <: Tuple](using 
      fillWithNil: FillWithNil[T]
    ): FillWithNil[List[H] *: T] = () => Nil *: fillWithNil()
    // Base case.
    given FillWithNil[EmptyTuple] = () => EmptyTuple

  /** Splits `as` into one list per case of the sealed type `A` and returns them as the tuple `L1`.
    *
    * `generic` maps each `A` to its coproduct representation, and `toTuple` lists the
    * coproduct's member types. `ev` fixes `L1` to that tuple with every member wrapped in `List`.
    * `fillWith` supplies the all-empty starting tuple, and `loop` routes each element into its
    * list. Folding from the right while `loop` prepends keeps the input order within each list.
    *
    * NOTE(review): `ev` is not used in the body; it only drives inference of `L1`. The
    * commented-out `ev0` is annotated in the original as causing a compile-time NPE.
    */
  def partition[A, /*L <: Tuple,*/ L1 <: Tuple](as: List[A])(using
    generic: Generic.Aux[A, _ <: Coproduct],
    toTuple: Coproduct.ToTuple[generic.Repr],
    //ev0: Coproduct.ToTuple[generic.Repr] =:= L, // compile-time NPE
    ev: Tuple.Map[toTuple.Out/*L*/, List] =:= L1,
    loop: Loop[generic.Repr, L1],
    fillWith: FillWithNil[L1]
  ): L1 = as.foldRight(fillWith())((a, l1) => loop(generic.to(a), l1))

  /** Sample ADT for the demo in `main`; each case class becomes one output list.
    * The case classes are `final` so the hierarchy cannot be extended outside this sealed set.
    */
  sealed trait A
  final case class B(i: Int) extends A
  final case class C(i: Int) extends A
  final case class D(i: Int) extends A

  /** Demo entry point: prints the partition of a mixed `List[A]`; the expected output is the
    * trailing comment.
    * NOTE(review): the closing braces of `main` and `App2` appear to be missing.
    */
  def main(args: Array[String]): Unit = {
    println(partition(List[A](B(1), B(2), C(1), C(2), D(1), D(2), B(3), C(3))))
  // (List(B(1), B(2), B(3)),List(C(1), C(2), C(3)),List(D(1), D(2)))

Scala 3.0.2

import scala.deriving.Mirror
import scala.util.NotGiven

/** Two-way conversion between `T` and its generic representation `Repr`
  * (a tuple for case classes and a `Coproduct` for sealed traits, per the givens below).
  * NOTE(review): the closing brace appears to be missing in this extracted snippet.
  */
trait Generic[T] {
  type Repr
  def to(t: T): Repr
  def from(r: Repr): T

/** Constructors and derived instances for `Generic`.
  * NOTE(review): closing braces and some argument lines appear to be missing below.
  */
object Generic {
  // Aux pattern: exposes the `Repr` type member as a type parameter so it can be inferred.
  type Aux[T, Repr0] = Generic[T] {type Repr = Repr0}

  /** Builds a `Generic` from a pair of conversion functions. */
  def instance[T, Repr0](f: T => Repr0, g: Repr0 => T): Aux[T, Repr0] =
    new Generic[T] {
      override type Repr = Repr0
      override def to(t: T): Repr0 = f(t)
      override def from(r: Repr0): T = g(r)

  /** Extension syntax: `a.toRepr` converts to the representation, `repr.to[A]` converts back. */
  object ops {
    extension[A] (a: A) {
      def toRepr(using g: Generic[A]): g.Repr = g.to(a)

    extension[Repr] (a: Repr) {
      def to[A](using g: Generic.Aux[A, Repr]): A = g.from(a)

  /** Product case: the representation is the `MirroredElemTypes` tuple. It also requests a
    * `Mirror.ProductOf` for that tuple (`m1`). The `NotGiven` guards are left commented out.
    * NOTE(review): the arguments to `instance(` are missing from this snippet.
    */
  given [T <: Product](using
    // ev: NotGiven[T <:< Tuple],
    // ev1: NotGiven[T <:< Coproduct],
    m: Mirror.ProductOf[T],
    m1: Mirror.ProductOf[m.MirroredElemTypes]
  ): Aux[T, m.MirroredElemTypes] = instance(

  /** Sum case: the representation is `C`, pinned by `ev2` to `ToCoproduct[MirroredElemTypes]`.
    * A value is wrapped at the position given by the Mirror's `ordinal`, then cast to `C`.
    * NOTE(review): the `instance(` call, the `from` direction and the closing brace appear
    * to be missing.
    */
  given[T, C <: Coproduct](using
    // ev: NotGiven[T <:< Tuple],
    // ev1: NotGiven[T <:< Coproduct],
    m: Mirror.SumOf[T],
    ev2: Coproduct.ToCoproduct[m.MirroredElemTypes] =:= C
  ): Generic.Aux[T, C/*Coproduct.ToCoproduct[m.MirroredElemTypes]*/] = {
      t => Coproduct.unsafeToCoproduct(m.ordinal(t), t).asInstanceOf[C],
/** Minimal coproduct encoding. A value of `H +: T` is either `Inl(head)` (the value is an `H`)
  * or `Inr(tail)` (the value lies further along `T`). `CNil` terminates the chain and has no
  * subclasses in this snippet.
  */
sealed trait Coproduct extends Product with Serializable
sealed trait +:[+H, +T <: Coproduct] extends Coproduct
final case class Inl[+H, +T <: Coproduct](head: H) extends (H +: T)
final case class Inr[+H, +T <: Coproduct](tail: T) extends (H +: T)
sealed trait CNil extends Coproduct

/** Helpers and match types for `Coproduct`.
  * NOTE(review): the closing braces of this object and its members appear to be missing.
  */
object Coproduct {
  /** Wraps `value` in `length` layers of `Inr` around an `Inl`, i.e. places it at zero-based
    * position `length`. Untyped: the result is only a `Coproduct`, so callers cast it.
    */
  def unsafeToCoproduct(length: Int, value: Any): Coproduct =
    (0 until length).foldLeft[Coproduct](Inl(value))((c, _) => Inr(c))

  /** Strips `Inr` layers and returns the payload of the first `Inl`. */
  def unsafeFromCoproduct(c: Coproduct): Any = c match {
    case Inl(h) => h
    case Inr(c) => unsafeFromCoproduct(c)
    case _: CNil => sys.error("impossible")

  /** Tuple of types -> nested coproduct `H1 +: H2 +: ... +: CNil`. */
  type ToCoproduct[T <: Tuple] <: Coproduct = T match {
    case EmptyTuple => CNil
    case h *: t => h +: ToCoproduct[t]

  /** Nested coproduct -> tuple of its member types (the inverse of `ToCoproduct`). */
  type ToTuple[C <: Coproduct] <: Tuple = C match {
    case CNil => EmptyTuple
    case h +: t => h *: ToTuple[t]


import scala.compiletime.erasedValue

/** Inserts the coproduct value `c` into the matching list of the tuple of lists `l`, with the
  * recursion unrolled at compile time.
  *
  * The `inline erasedValue[...] match` expressions dispatch on the static types `C` and `L`:
  * `CNil` must pair with `EmptyTuple`, and `h +: ct` requires `L` to start with `List[h]`.
  * The runtime `(c, l) match` then either prepends to the head list (`Inl`) or keeps the
  * head list and recurses into the tails (`Inr`).
  *
  * NOTE(review): the typed patterns on `h`/`ct`/`ht` are presumably unchecked at runtime
  * because of erasure. The closing braces and whatever closes the leading `(` appear to be
  * missing from this snippet.
  */
inline def loop[C <: Coproduct, L <: Tuple](c: C, l: L): L = (inline erasedValue[C] match {
  case _: CNil => inline erasedValue[L] match {
    case _: EmptyTuple => EmptyTuple
  case _: (h +: ct) => inline erasedValue[L] match {
    case _: (List[`h`] *: ht) => (c, l) match {
      case (Inl(h_v: `h`), (hs_v: List[`h`]) *: (ht_v: `ht`)) => 
        (h_v :: hs_v) *: ht_v
      case (Inr(ct_v: `ct`), (hs_v: List[`h`]) *: (ht_v: `ht`)) => 
        hs_v *: loop[ct, ht](ct_v, ht_v)

/** Builds, at compile time, a value of tuple type `L` made only of `Nil`s,
  * e.g. `(Nil, Nil)` for `(List[B], List[C])`.
  * NOTE(review): the closing brace and whatever closes the leading `(` appear to be missing.
  */
inline def fillWithNil[L <: Tuple]: L = (inline erasedValue[L] match {
  case _: EmptyTuple => EmptyTuple
  case _: (List[h] *: t) => Nil *: fillWithNil[t]

// Tuple with one List per coproduct member, e.g. B +: C +: CNil  ->  (List[B], List[C]).
type TupleList[C <: Coproduct] = Tuple.Map[Coproduct.ToTuple[C], List]

/** Splits `as` into one list per case of the sealed type `A`, typed as `TupleList[generic.Repr]`.
  * It starts from a tuple of empty lists and folds from the right; `loop` prepends, so each
  * list keeps the input order.
  */
inline def partition[A](as: List[A])(using
  generic: Generic.Aux[A, _ <: Coproduct]
): TupleList[generic.Repr] =
  as.foldRight(fillWithNil[TupleList[generic.Repr]])((a, l1) => loop(generic.to(a), l1))
/** Sample ADT for the `test` demo; each case class becomes one output list.
  * The case classes are `final` so the hierarchy cannot be extended outside this sealed set.
  */
sealed trait A
final case class B(i: Int) extends A
final case class C(i: Int) extends A
final case class D(i: Int) extends A

// Demo: prints the partition of a mixed List[A]; the expected output is the trailing comment.
// NOTE(review): the closing brace appears to be missing in this extracted snippet.
@main def test = {
  println(partition(List[A](B(1), B(2), C(1), C(2), D(1), D(2), B(3), C(3))))
  // (List(B(1), B(2), B(3)),List(C(1), C(2), C(3)),List(D(1), D(2)))

已在 Scala 3.2.0 中测试：https://scastie.scala-lang.org/DmytroMitin/940QaiqDQQ2QegCyxTbEIQ/1
How to access parameter list of case class in a dotty macro

// Type-level counterpart of `loop`. For well-formed inputs (CNil with EmptyTuple, or
// h +: ct with List[h] *: ht) it reduces back to L, as the note below states.
// NOTE(review): closing braces appear to be missing from these match types.
//Loop[C, L] = L
type Loop[C <: Coproduct, L <: Tuple] <: Tuple = C match {
  case CNil    => CNilLoop[L]
  case h +: ct => CConsLoop[h, ct, L]
// match types seem not to support nested type matching
// (so the match on L is split out into the two helper aliases below)
type CNilLoop[L <: Tuple] <: Tuple = L match {
  case EmptyTuple => EmptyTuple
type CConsLoop[H, CT <: Coproduct, L <: Tuple] <: Tuple = L match {
  case List[H] *: ht => List[H] *: Loop[CT, ht]
/** Non-inline variant of `loop`, typed with the `Loop` match type.
  *
  * It matches on the runtime values, with type patterns on `c` and `l`. `Inl` prepends the
  * value to the head list; `Inr` keeps the head list and recurses. In the `Inl` branch the
  * unchanged tail is cast to `Loop[ct, ht]`, presumably because the compiler cannot prove
  * that it reduces to `ht`.
  * NOTE(review): the type patterns on `h`/`ct`/`ht` are presumably unchecked because of
  * erasure, and the closing braces appear to be missing from this snippet.
  */
/*inline*/ def loop0[C <: Coproduct, L <: Tuple](c: C, l: L): Loop[C, L] = /*inline*/ c match {
  case _: CNil => /*inline*/ l match {
    case _: EmptyTuple => EmptyTuple
  case c: (h +: ct) => /*inline*/ l match {
    case l: (List[`h`] *: ht) => (c, l) match {
      case (Inl(h_v/*: `h`*/), (hs_v/*: List[`h`]*/) *: (ht_v/*: `ht`*/)) =>
        (h_v :: hs_v) *: ht_v.asInstanceOf[Loop[ct, ht]]
      case (Inr(ct_v/*: `ct`*/), (hs_v/*: List[`h`]*/) *: (ht_v/*: `ht`*/)) => 
        hs_v *: loop0[ct, ht](ct_v, ht_v)
// Same as loop0 but typed as L. The cast relies on Loop[C, L] reducing to L for well-formed inputs.
/*inline*/ def loop[C <: Coproduct, L <: Tuple](c: C, l: L): L = loop0(c, l).asInstanceOf[L]

Scala 2 的另一种实现，参见：Split list of algebraic data type to lists of branches?
