我有一个函数,我希望它是通用的,但限制它采取某些子类型。为了简单起见,我希望我的函数只处理long、int、float和double。所以我想到的是:
def covariance[A](xElems: Seq[A], yElems: Seq[A]): A = {
val (meanX, meanY) = (mean(xElems), mean(yElems))
val (meanDiffX, meanDiffY) = (meanDiff(meanX, xElems), meanDiff(meanY, yElems))
((meanDiffX zip meanDiffY).map { case (x, y) => x * y }.sum) / xElems.size - 1
}
def mean[A](elems: Seq[A]): A = {
(elems.fold(_ + _)) / elems.length
}
def meanDiff[A](mean: A, elems: Seq[A]) = {
elems.map(elem => elem - mean)
}
以下是我将用于检查上述类型的方法:
import scala.reflect.{ClassTag, classTag}
def matchList2[A : ClassTag](list: List[A]) = list match {
case intlist: List[Int @unchecked] if classTag[A] == classTag[Int] => println("A List of ints!")
case longlist: List[Long @unchecked] if classTag[A] == classTag[Long] => println("A list of longs!")
}
注意,我使用的是classtag。我也可以使用一个typetag,甚至可以使用这个不成形的库。
我想知道这是不是一个好办法?或者我应该使用有界类型来解决我想要的问题?
编辑:基于对使用fractional typeclass的评论和建议,下面是我认为它是如何工作的!
def covariance[A: Fractional](xElems: Seq[A], yElems: Seq[A]): A = {
val (meanX, meanY) = (mean(xElems), mean(yElems))
val (meanDiffX, meanDiffY) = (meanDiff(meanX, xElems), meanDiff(meanY, yElems))
((meanDiffX zip meanDiffY).map { case (x, y) => x * y }.sum) / xElems.size - 1
}
def mean[A](elems: Seq[A]): A = {
(elems.fold(_ + _)) / elems.length
}
def meanDiff[A](mean: A, elems: Seq[A]) = {
elems.map(elem => elem - mean)
}
1条答案
按热度按时间kpbwa7wx1#
根据评论和输入,这里是我想出的!