org.apache.mahout.math.WeightedVector类的使用及代码示例

x33g5p2x  于2022-02-03 转载在 其他  
字(6.8k)|赞(0)|评价(0)|浏览(148)

本文整理了Java中org.apache.mahout.math.WeightedVector类的一些代码示例,展示了WeightedVector类的具体用法。这些代码示例主要来源于GitHub/Stack Overflow/Maven等平台,是从一些精选项目中提取出来的代码,具有较强的参考意义,能在一定程度上帮助到你。WeightedVector类的具体详情如下:
包路径:org.apache.mahout.math.WeightedVector
类名称:WeightedVector

WeightedVector介绍

[英]Decorates a vector with a floating point weight and an index.
[中]用浮点权重和索引装饰向量。

代码示例

代码示例来源:origin: apache/mahout

  1. public static WeightedVector project(Vector v, Vector projection, int index) {
  2. return new WeightedVector(v, projection, index);
  3. }

代码示例来源:origin: apache/mahout

  /**
   * Creates a centroid from {@code original}. The centroid holds its own vector, built with
   * {@code like()} and filled with {@code original}'s values. It also takes over
   * {@code original}'s weight and index.
   *
   * @param original the weighted vector to copy
   */
  public Centroid(WeightedVector original) {
    super(copyValuesOf(original), original.getWeight(), original.getIndex());
  }

  /** Returns {@code source.getVector().like()} after assigning {@code source}'s values into it. */
  private static Vector copyValuesOf(WeightedVector source) {
    Vector target = source.getVector().like();
    return target.assign(source);
  }

代码示例来源:origin: apache/mahout

  1. public static WeightedVector project(Vector v, Vector projection) {
  2. return project(v, projection, INVALID_INDEX);
  3. }

代码示例来源:origin: apache/mahout

  1. @Override
  2. public Vector like() {
  3. return new WeightedVector(getVector().like(), weight, index);
  4. }

代码示例来源:origin: apache/mahout

  1. @Override
  2. public int compare(WeightedVector a, WeightedVector b) {
  3. if (a == b) {
  4. return 0;
  5. }
  6. double aWeight = a.getWeight();
  7. double bWeight = b.getWeight();
  8. int r = Double.compare(aWeight, bWeight);
  9. if (r != 0 && Math.abs(aWeight - bWeight) >= DOUBLE_EQUALITY_ERROR) {
  10. return r;
  11. }
  12. double diff = a.minus(b).norm(1);
  13. if (diff < 1.0e-12) {
  14. return 0;
  15. }
  16. for (Vector.Element element : a.all()) {
  17. r = Double.compare(element.get(), b.get(element.index()));
  18. if (r != 0) {
  19. return r;
  20. }
  21. }
  22. return 0;
  23. }

代码示例来源:origin: apache/mahout

  1. @Test
  2. public void testProjection() {
  3. Vector v1 = new DenseVector(10).assign(Functions.random());
  4. WeightedVector v2 = new WeightedVector(v1, v1, 31);
  5. assertEquals(v1.dot(v1), v2.getWeight(), 1.0e-13);
  6. assertEquals(31, v2.getIndex());
  7. Matrix y = new DenseMatrix(10, 4).assign(Functions.random());
  8. Matrix q = new QRDecomposition(y.viewPart(0, 10, 0, 3)).getQ();
  9. Vector nullSpace = y.viewColumn(3).minus(q.times(q.transpose().times(y.viewColumn(3))));
  10. WeightedVector v3 = new WeightedVector(q.viewColumn(0).plus(q.viewColumn(1)), nullSpace, 1);
  11. assertEquals(0, v3.getWeight(), 1.0e-13);
  12. Vector qx = q.viewColumn(0).plus(q.viewColumn(1)).normalize();
  13. WeightedVector v4 = new WeightedVector(qx, q.viewColumn(0), 2);
  14. assertEquals(Math.sqrt(0.5), v4.getWeight(), 1.0e-13);
  15. WeightedVector v5 = WeightedVector.project(q.viewColumn(0), qx);
  16. assertEquals(Math.sqrt(0.5), v5.getWeight(), 1.0e-13);
  17. }

代码示例来源:origin: org.apache.mahout/mahout-core

  // Excerpt from a larger method. `closestPair`, `closestCentroid` and `datapoint` are declared outside this listing.
  // The search result is unpacked into the index of the matched WeightedVector and the pair's weight.
  // The pair's weight is presumably the distance to it, judging by the variable name; confirm against the searcher.
  // The datapoint's weight is then added to `closestCentroid`, presumably the centroid at closestIndex.
  1. int closestIndex = ((WeightedVector) closestPair.getValue()).getIndex();
  2. double closestDistance = closestPair.getWeight();
  3. closestCentroid.setWeight(closestCentroid.getWeight() + datapoint.getWeight());

代码示例来源:origin: apache/mahout

  1. public static Centroid create(int key, Vector initialValue) {
  2. if (initialValue instanceof WeightedVector) {
  3. return new Centroid(key, new DenseVector(initialValue), ((WeightedVector) initialValue).getWeight());
  4. } else {
  5. return new Centroid(key, new DenseVector(initialValue), 1);
  6. }
  7. }

代码示例来源:origin: tdunning/anomaly-detection

  // NOTE(review): this listing splices separate excerpts together.
  // The first `for` loop is never closed, and `i`/`row` are redeclared, so it does not compile as shown.
  // Excerpt 1: build SAMPLES rows. Each row is a WINDOW-long slice of `trace` starting at i * STEP.
  // The slice is multiplied element-wise by `window`, scaled to unit L2 norm, and added to `r`.
  // Each row is a WeightedVector with weight 1 and index i.
  1. for (int i = 0; i < SAMPLES; i++) {
  2. int offset = i * STEP;
  3. WeightedVector row = new WeightedVector(new DenseVector(WINDOW), 1, i);
  4. row.assign(trace.viewPart(offset, WINDOW));
  5. row.assign(window, Functions.MULT);
  // NOTE(review): an all-zero slice has norm 0, so this yields Infinity/NaN entries.
  6. row.assign(Functions.mult(1 / row.norm(2)));
  7. r.add(row);
  // Excerpt 2: the same per-row processing, but slices start every WINDOW / 2 samples (overlapping by half a window).
  // NOTE(review): the strict `<` skips a final window that would end exactly at trace.size().
  8. for (int i = 0; i + WINDOW < trace.size(); i += WINDOW / 2) {
  9. WeightedVector row = new WeightedVector(new DenseVector(WINDOW), 1, i);
  10. row.assign(trace.viewPart(i, WINDOW));
  11. row.assign(window, Functions.MULT);
  12. double scale = row.norm(2);
  13. row.assign(Functions.mult(1 / scale));
  // Excerpt 3: print columns 0 and 1 of row j of `rx`, followed by the index of the WeightedVector held by `cluster`.
  14. out.format("%.3f\t%.3f\t%d\n", rx.get(j, 0), rx.get(j, 1), ((WeightedVector) cluster.getValue()).getIndex());

代码示例来源:origin: org.apache.mahout/mahout-mrlegacy

  // Excerpt from seeding code; `datapoints`, `selected`, `row`, `i`, `distanceMeasure` and `seedSelector` are declared outside this listing.
  // A Centroid is built from a clone of the datapoint at `selected` and given index 0.
  // Entry i of seedSelector is then set to distance(c_1, row) * 2 * ln(1 + row weight).
  // Presumably `row` is datapoint i and this is a k-means++-style weighting; confirm against the full source.
  1. Centroid c_1 = new Centroid(datapoints.get(selected).clone());
  2. c_1.setIndex(0);
  3. double w = distanceMeasure.distance(c_1, row) * 2 * Math.log(1 + row.getWeight());
  4. seedSelector.set(i, w);

代码示例来源:origin: apache/mahout

  1. @Test
  2. public void testOrdering() {
  3. WeightedVector v1 = new WeightedVector(new DenseVector(new double[]{1, 2, 3}), 5.41, 31);
  4. WeightedVector v2 = new WeightedVector(new DenseVector(new double[]{1, 2, 3}), 5.00, 31);
  5. WeightedVector v3 = new WeightedVector(new DenseVector(new double[]{1, 3, 3}), 5.00, 31);
  6. WeightedVector v4 = v1.clone();
  7. WeightedVectorComparator comparator = new WeightedVectorComparator();
  8. assertTrue(comparator.compare(v1, v2) > 0);
  9. assertTrue(comparator.compare(v3, v1) < 0);
  10. assertTrue(comparator.compare(v3, v2) > 0);
  11. assertEquals(0, comparator.compare(v4, v1));
  12. assertEquals(0, comparator.compare(v1, v1));
  13. }

代码示例来源:origin: org.apache.mahout/mahout-core

  // Excerpt; the loop body continues beyond this listing.
  // Each datapoint from trainTestSplit.getSecond() is looked up with centroids.searchFirst(testDatapoint, false).
  // Its weight is then added, in place, to the returned centroid (presumably the nearest one).
  // NOTE(review): the meaning of the `false` flag is not visible here; check the searcher's API.
  1. for (WeightedVector testDatapoint : trainTestSplit.getSecond()) {
  2. WeightedVector closest = (WeightedVector) centroids.searchFirst(testDatapoint, false).getValue();
  3. closest.setWeight(closest.getWeight() + testDatapoint.getWeight());

代码示例来源:origin: org.apache.mahout/mahout-mrlegacy

  1. private static OnlineSummarizer evaluateStrategy(Matrix testData, BruteSearch ref,
  2. LocalitySensitiveHashSearch cut) {
  3. OnlineSummarizer t1 = new OnlineSummarizer();
  4. for (int i = 0; i < 100; i++) {
  5. final Vector q = testData.viewRow(i);
  6. List<WeightedThing<Vector>> v1 = cut.search(q, 150);
  7. BitSet b1 = new BitSet();
  8. for (WeightedThing<Vector> v : v1) {
  9. b1.set(((WeightedVector)v.getValue()).getIndex());
  10. }
  11. List<WeightedThing<Vector>> v2 = ref.search(q, 100);
  12. BitSet b2 = new BitSet();
  13. for (WeightedThing<Vector> v : v2) {
  14. b2.set(((WeightedVector)v.getValue()).getIndex());
  15. }
  16. b1.and(b2);
  17. t1.add(b1.cardinality());
  18. }
  19. return t1;
  20. }

代码示例来源:origin: apache/mahout

  1. @Override
  2. public String toString() {
  3. return String.format("index=%d, weight=%.2f, v=%s", index, weight, getVector());
  4. }

代码示例来源:origin: org.apache.mahout/mahout-mrlegacy

  /** Combines the superclass hash with the {@code hash} field using the usual 31-multiplier scheme. */
  @Override
  public int hashCode() {
    int inherited = super.hashCode();
    // Fold the field's high and low 32 bits together (assumes `hash` is a long; verify in the class).
    int folded = (int) (hash ^ (hash >>> 32));
    return 31 * inherited + folded;
  }
  } // closes the enclosing class, which is declared outside this excerpt

代码示例来源:origin: org.apache.mahout/mahout-mr

  // Excerpt from a larger method. `closestPair`, `closestCentroid` and `datapoint` are declared outside this listing.
  // The search result is unpacked into the index of the matched WeightedVector and the pair's weight.
  // The pair's weight is presumably the distance to it, judging by the variable name; confirm against the searcher.
  // The datapoint's weight is then added to `closestCentroid`, presumably the centroid at closestIndex.
  1. int closestIndex = ((WeightedVector) closestPair.getValue()).getIndex();
  2. double closestDistance = closestPair.getWeight();
  3. closestCentroid.setWeight(closestCentroid.getWeight() + datapoint.getWeight());

代码示例来源:origin: org.apache.mahout/mahout-mr

  /**
   * Computes the total weight of the points in the given Vector iterable.
   * A {@link WeightedVector} contributes its own weight; any other vector contributes 1.
   *
   * @param data iterable of points; null elements are rejected via Preconditions.checkNotNull
   * @return total weight
   */
  public static double totalWeight(Iterable<? extends Vector> data) {
    double total = 0;
    for (Vector point : data) {
      Preconditions.checkNotNull(point);
      total += point instanceof WeightedVector ? ((WeightedVector) point).getWeight() : 1;
    }
    return total;
  }
  } // closes the enclosing class, which is declared outside this excerpt

代码示例来源:origin: org.apache.mahout/mahout-math

  1. @Override
  2. public Vector like() {
  3. return new WeightedVector(getVector().like(), weight, index);
  4. }

代码示例来源:origin: org.apache.mahout/mahout-mr

  // Excerpt from seeding code; `datapoints`, `selected`, `row`, `i`, `distanceMeasure` and `seedSelector` are declared outside this listing.
  // A Centroid is built from a clone of the datapoint at `selected` and given index 0.
  // Entry i of seedSelector is then set to distance(c_1, row) * 2 * ln(1 + row weight).
  // Presumably `row` is datapoint i and this is a k-means++-style weighting; confirm against the full source.
  1. Centroid c_1 = new Centroid(datapoints.get(selected).clone());
  2. c_1.setIndex(0);
  3. double w = distanceMeasure.distance(c_1, row) * 2 * Math.log(1 + row.getWeight());
  4. seedSelector.set(i, w);

代码示例来源:origin: org.apache.mahout/mahout-math

  1. @Override
  2. public int compare(WeightedVector a, WeightedVector b) {
  3. if (a == b) {
  4. return 0;
  5. }
  6. double aWeight = a.getWeight();
  7. double bWeight = b.getWeight();
  8. int r = Double.compare(aWeight, bWeight);
  9. if (r != 0 && Math.abs(aWeight - bWeight) >= DOUBLE_EQUALITY_ERROR) {
  10. return r;
  11. }
  12. double diff = a.minus(b).norm(1);
  13. if (diff < 1.0e-12) {
  14. return 0;
  15. }
  16. for (Vector.Element element : a.all()) {
  17. r = Double.compare(element.get(), b.get(element.index()));
  18. if (r != 0) {
  19. return r;
  20. }
  21. }
  22. return 0;
  23. }

相关文章