如何在flink中使用我的模型进行分组

6mzjoqzu  于 2021-06-25  发布在  Flink
关注(0)|答案(2)|浏览(268)

我正在用 Flink 和 Java，结合我们自己的业务逻辑来构建推荐系统。
所以我有一个数据集:

[user] [item]
100      1
100      2
100      3
100      4
100      5
200      1
200      2
200      3
200      6
300      1
300      6
400      7

于是我先把每一行映射（map）成元组，并统计每个（客户, 商品）组合出现的次数：

// Collapse duplicate (customer, item) lines into one (customer, item, occurrences) tuple.
DataSet<Tuple3<Long, Long, Integer>> csv = text
        .flatMap(new LineSplitter())
        .groupBy(0, 1)
        .reduceGroup(new GroupReduceFunction<Tuple2<Long, Long>, Tuple3<Long, Long, Integer>>() {
            @Override
            public void reduce(Iterable<Tuple2<Long, Long>> pairs, Collector<Tuple3<Long, Long, Integer>> out) throws Exception {
                Integer occurrences = 0;
                Long product = 0L;
                Long customer = 0L;

                // Every pair in the group has the same key; remember it and count the group size.
                for (Tuple2<Long, Long> pair : pairs) {
                    customer = pair.f0;
                    product = pair.f1;
                    occurrences++;
                }

                out.collect(new Tuple3<>(customer, product, occurrences));
            }
        });

然后，我把每个客户及其所有商品收集到一个 ArrayList 中：

// Fold every (customer, item, count) tuple of one customer into a single CustomerItems record.
DataSet<CustomerItems> customerItems = csv
        .groupBy(0)
        .reduceGroup(new GroupReduceFunction<Tuple3<Long, Long, Integer>, CustomerItems>() {
            @Override
            public void reduce(Iterable<Tuple3<Long, Long, Integer>> rows, Collector<CustomerItems> out) throws Exception {
                Long customer = 0L;
                ArrayList<Long> owned = new ArrayList<>();

                for (Tuple3<Long, Long, Integer> row : rows) {
                    owned.add(row.f1);
                    customer = row.f0;
                }

                out.collect(new CustomerItems(customer, owned));
            }
        });

现在我需要找出所有“相似”的客户（即与某个客户至少有一件相同商品的其他客户），但尝试了很多办法都没有成功。
逻辑是:

for ci : CustomerItems
  c1 = ci.customerId

    for ci2 : CustomerItems
      c2 = ci2.customerId

      if c1 != c2
        if ci2.items contains any item of ci.items
          collector.collect(new Tuple2<>(c1, c2))

我尝试过用 reduce，但无法对同一个迭代器遍历两次（即在循环里再嵌套一层循环）。
有人能帮我吗?

oxcyiej7

oxcyiej71#

我解决了这个问题，但在“cross”（交叉）之后还需要再做一次 groupBy 和 reduce。我不确定这是不是最好的方法，有人能给些建议吗？
结果如下:

package org.myorg.quickstart;

import org.apache.flink.api.common.functions.CrossFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.util.Collector;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Set;

public class UserRecommendation {

    public static void main(String[] args) throws Exception {
        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        // le o arquivo cm o dataset
        DataSet<String> text = env.readTextFile("/Users/paulo/Downloads/dataset.csv");

        // cria tuple com: customer | item | count
        DataSet<Tuple3<Long, Long, Integer>> csv = text.flatMap(new LineFieldSplitter()).groupBy(0, 1).reduceGroup(new GroupReduceFunction<Tuple2<Long, Long>, Tuple3<Long, Long, Integer>>() {
            @Override
            public void reduce(Iterable<Tuple2<Long, Long>> iterable, Collector<Tuple3<Long, Long, Integer>> collector) throws Exception {
                Long customerId = 0L;
                Long itemId = 0L;
                Integer count = 0;

                for (Tuple2<Long, Long> item : iterable) {
                    customerId = item.f0;
                    itemId = item.f1;
                    count = count + 1;
                }

                collector.collect(new Tuple3<>(customerId, itemId, count));
            }
        });

        // agrupa os items do customer dentro do customer
        final DataSet<CustomerItems> customerItems = csv.groupBy(0).reduceGroup(new GroupReduceFunction<Tuple3<Long, Long, Integer>, CustomerItems>() {
            @Override
            public void reduce(Iterable<Tuple3<Long, Long, Integer>> iterable, Collector<CustomerItems> collector) throws Exception {
                ArrayList<Long> newItems = new ArrayList<>();
                Long customerId = 0L;

                for (Tuple3<Long, Long, Integer> item : iterable) {
                    customerId = item.f0;
                    newItems.add(item.f1);
                }

                collector.collect(new CustomerItems(customerId, newItems));
            }
        });

        // obtém todos os itens do customer que pertence a um usuário parecido
        DataSet<CustomerItems> ci = customerItems.cross(customerItems).with(new CrossFunction<CustomerItems, CustomerItems, CustomerItems>() {

            @Override
            public CustomerItems cross(CustomerItems customerItems, CustomerItems customerItems2) throws Exception {
                if (!customerItems.customerId.equals(customerItems2.customerId)) {
                    boolean has = false;

                    for (Long item : customerItems2.items) {
                        if (customerItems.items.contains(item)) {
                            has = true;
                            break;
                        }
                    }

                    if (has) {
                        for (Long item : customerItems2.items) {
                            if (!customerItems.items.contains(item)) {
                                customerItems.ritems.add(item);
                            }
                        }
                    }
                }

                return customerItems;
            }

        }).groupBy(new KeySelector<CustomerItems, Long>() {

            @Override
            public Long getKey(CustomerItems customerItems) throws Exception {
                return customerItems.customerId;
            }

        }).reduceGroup(new GroupReduceFunction<CustomerItems, CustomerItems>() {

            @Override
            public void reduce(Iterable<CustomerItems> iterable, Collector<CustomerItems> collector) throws Exception {
                CustomerItems c = new CustomerItems();

                for (CustomerItems current : iterable) {
                    c.customerId = current.customerId;

                    for (Long item : current.ritems) {
                        if (!c.ritems.contains(item)) {
                            c.ritems.add(item);
                        }
                    }
                }

                collector.collect(c);
            }

        });

        ci.first(100).print();
        System.out.println(ci.count());
    }

    public static class CustomerItems implements Serializable {

        public Long customerId;
        public ArrayList<Long> items = new ArrayList<>();
        public ArrayList<Long> ritems = new ArrayList<>();

        public CustomerItems() {

        }

        public CustomerItems(Long customerId, ArrayList<Long> items) {
            this.customerId = customerId;
            this.items = items;
        }

        @Override
        public String toString() {
            StringBuilder itemsData = new StringBuilder();

            if (items != null) {
                for (Long item : items) {
                    if (itemsData.length() == 0) {
                        itemsData.append(item);
                    } else {
                        itemsData.append(", ").append(item);
                    }
                }
            }

            StringBuilder ritemsData = new StringBuilder();

            if (ritems != null) {
                for (Long item : ritems) {
                    if (ritemsData.length() == 0) {
                        ritemsData.append(item);
                    } else {
                        ritemsData.append(", ").append(item);
                    }
                }
            }

            return String.format("[ID: %d, Items: %s, RItems: %s]", customerId, itemsData, ritemsData);
        }
    }

    public static final class LineFieldSplitter implements FlatMapFunction<String, Tuple2<Long, Long>> {

        @Override
        public void flatMap(String value, Collector<Tuple2<Long, Long>> out) {
            // normalize and split the line
            String[] tokens = value.split("\t");

            if (tokens.length > 1) {
                out.collect(new Tuple2<>(Long.valueOf(tokens[0]), Long.valueOf(tokens[1])));
            }
        }
    }

}

Gist 链接：https://gist.github.com/prsolucoes/b406ae98ea24120436954967e37103f6

v9tzhpje

v9tzhpje2#

你可以把数据集与它自身做 cross（两两配对），然后基本上把你的逻辑原样放进 CrossFunction 里（去掉那两层循环，因为 cross 已经替你完成了两两配对）。

相关问题