如何使用mapreduce避免knn程序溢出?

o8x7eapl  于 2021-06-02  发布在  Hadoop
关注(0)|答案(0)|浏览(246)

我在下面编写的程序产生了大量的溢出(溢出量高达数GB,而我的输入和输出数据只有20mb左右)。
我所做的只是把测试文件放进分布式缓存,然后每当一行训练数据传入 map() 函数时,计算它与每条测试数据的距离。我无法在这里设置 Combiner:每次 map() 调用会产生 N 条记录(N = 测试数据的条数),它们的键各不相同(我使用测试数据的索引作为键),因此合并这些输出对我的实现没有意义。

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;

import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;

/**
 * Map side of a brute-force KNN job.
 *
 * <p>{@link #setup} loads every test record from the first distributed-cache file; {@link #map}
 * then emits, for each training line, one {@code (testIndex, DistClassPair(distance, trainLabel))}
 * record per test record. Lines are tab-separated: user, gender, numeric feature columns, and the
 * class label as the last field.
 */
public class KnnMapper extends Mapper<LongWritable, Text, LongWritable, DistClassPair> {
    /** Parsed test records (user, gender, normalized features), in cache-file order. */
    private List<List<Object>> test;
    /** Class label of each test record, index-aligned with {@link #test}. */
    private List<String> testY;

    /**
     * Reused output key. Hadoop serializes the key during {@code context.write}, so one instance
     * can be reused instead of allocating a new object for every emitted record.
     */
    private final LongWritable outKey = new LongWritable();

    // Per-column bounds passed to HelperFunc.normalize: feature column i (counted after user and
    // gender) uses MIN[i] / MAX[i]. NOTE(review): a record with more than MIN.length feature
    // columns will throw ArrayIndexOutOfBoundsException.
    private static final double[] MIN = { 28, 1.58, 55, 22, -306, -271, -603, -494, -571, -616,
            -499, -506, -613, -700, -213, -251 };
    private static final double[] MAX = { 75, 1.71, 83, 28.6, 509, 533, 411, 69,
            128, 102, 351, 471, -20, -13, -39, -56 };

    /**
     * Emits the distance from one training line to every cached test record.
     *
     * @param key byte offset of the line (unused)
     * @param value one tab-separated training line
     * @param context receives one {@code (testIndex, DistClassPair)} record per test record
     */
    @Override
    public void map(LongWritable key, Text value, Context context)
            throws IOException, InterruptedException {
        String trainLine = value.toString();
        // A blank line carries no record and would otherwise fail in nextToken().
        if (trainLine.trim().isEmpty()) {
            return;
        }

        StringTokenizer st = new StringTokenizer(trainLine, "\t");
        List<Object> trainData = parseFeatures(st);
        String cls = st.nextToken();

        // One output record per test record, keyed by the test record's position. Map output is
        // therefore (#training lines) x (#test records) records.
        for (int index = 0; index < test.size(); index++) {
            double dist = HelperFunc.calcDistance(trainData, test.get(index));
            outKey.set(index);
            context.write(outKey, new DistClassPair(dist, cls));
        }
    }

    /**
     * Loads and normalizes the test records from the first file in the distributed cache.
     *
     * @throws IOException if no cache file is registered or the file cannot be read
     */
    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
        test = new ArrayList<>();
        testY = new ArrayList<>();

        java.net.URI[] cacheFiles = context.getCacheFiles();
        if (cacheFiles == null || cacheFiles.length == 0) {
            throw new IOException("No test file was added to the distributed cache");
        }

        // try-with-resources closes the reader even when a malformed line throws mid-parse.
        try (BufferedReader buff = new BufferedReader(new FileReader(cacheFiles[0].toString()))) {
            String line = buff.readLine();

            System.out.println(line);

            while (line != null) {
                // Skip blank lines instead of failing on nextToken().
                if (!line.trim().isEmpty()) {
                    StringTokenizer st = new StringTokenizer(line, "\t");
                    test.add(parseFeatures(st));
                    testY.add(st.nextToken());
                }
                line = buff.readLine();
            }
        }
    }

    /**
     * Parses user, gender and the normalized feature columns from {@code st}, leaving the final
     * token (the class label) unread for the caller.
     */
    private static List<Object> parseFeatures(StringTokenizer st) {
        List<Object> record = new ArrayList<>();
        record.add(st.nextToken()); // user
        record.add(st.nextToken()); // gender
        for (int index = 0; st.countTokens() > 1; index++) {
            record.add(HelperFunc.normalize(st.nextToken(), MIN[index], MAX[index]));
        }
        return record;
    }

}

在我的实现中,我使用了一个自定义类 DistClassPair,它只是把距离和类别信息一起封装为 value。

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import org.apache.hadoop.io.WritableComparable;

/**
 * Map-output value pairing a distance with the class label of the training record it was
 * measured against. Instances are ordered by distance only.
 */
public class DistClassPair implements WritableComparable<DistClassPair> {
    private double dist;
    private String cls;

    /**
     * No-arg constructor required by Hadoop. Hadoop creates value instances reflectively and then
     * fills them via {@link #readFields}; without this constructor deserialization fails.
     */
    public DistClassPair() {
        this.dist = 0.0;
        this.cls = "";
    }

    /**
     * @param dist distance; must not be null (it is stored unboxed)
     * @param cls class label; must not be null (it is serialized with {@code writeUTF})
     */
    public DistClassPair(Double dist, String cls) {
        this.dist = dist;
        this.cls = cls;
    }

    @Override
    public void readFields(DataInput in) throws IOException {
        dist = in.readDouble();
        // Length-prefixed counterpart of writeUTF. The label's end is explicit instead of relying
        // on the value's byte range ending right after it.
        cls = in.readUTF();
    }

    @Override
    public void write(DataOutput out) throws IOException {
        out.writeDouble(dist);
        // writeUTF keeps full Unicode and records the length; writeBytes would drop each char's
        // high byte and write no delimiter.
        out.writeUTF(cls);
    }

    /** Orders by distance only; pairs with equal distance compare as 0 regardless of label. */
    @Override
    public int compareTo(DistClassPair o) {
        return Double.compare(dist, o.dist);
    }

    /** @return the stored distance */
    public double getDist() {
        return dist;
    }

    /** @return the class label */
    public String getCls() {
        return cls;
    }
}

下面是我写的 KnnDriver,供需要时参考。

import java.net.URI;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;

/**
 * Configures and submits the KNN MapReduce job.
 *
 * <p>Arguments: {@code N test.csv train.csv outputpath}. {@code N} is stored in the job
 * configuration under key {@code "N"}. The test file is shipped via the distributed cache, and
 * the training path is the map input. The output directory is deleted before the job is submitted.
 */
public class KnnDriver extends Configured implements Tool {
    /*
     *  args = N, test.csv, train.csv, outputpath
     */
    public static void main(String[] args) throws Exception {
        int res = ToolRunner.run(new Configuration(), new KnnDriver(), args);
        System.exit(res);
    }

    /**
     * Builds and runs the job.
     *
     * @param args {@code N, test.csv, train.csv, outputpath}
     * @return 0 on success, -1 if the job fails, 2 if the argument count is wrong
     */
    @Override
    public int run(String[] args) throws Exception {
        // Validate before any args[i] access below.
        if (args.length != 4) {
            System.err.println("Number of parameter is not correct!");
            return 2;
        }

        Configuration conf = getConf();
        // NOTE(review): the meaning of "N" is defined by KnnReducer (not shown); presumably k.
        conf.set("N", args[0]);

        Job job = Job.getInstance(conf, "K-Nearest-Neighbor mapreduce");
        job.setJarByClass(KnnDriver.class);

        // KnnMapper.setup() reads the test set from the first cache file.
        job.addCacheFile(new URI(args[1]));

        job.setMapperClass(KnnMapper.class);
        job.setReducerClass(KnnReducer.class);

        // NOTE(review): KnnMapper writes one record per (training line, test record) pair, so map
        // output can be far larger than the input. If spilling is a problem, consider enabling
        // mapreduce.map.output.compress or raising mapreduce.task.io.sort.mb.
        job.setMapOutputKeyClass(LongWritable.class);
        job.setMapOutputValueClass(DistClassPair.class);
        job.setOutputKeyClass(LongWritable.class);
        job.setOutputValueClass(Text.class);

        job.setInputFormatClass(TextInputFormat.class);
        job.setOutputFormatClass(TextOutputFormat.class);

        FileInputFormat.setInputPaths(job, new Path(args[2]));

        Path outputPath = new Path(args[3]);
        // Resolve the FileSystem from the path itself, so a fully qualified output URI on a
        // non-default filesystem is deleted where it actually lives.
        outputPath.getFileSystem(conf).delete(outputPath, true);
        FileOutputFormat.setOutputPath(job, outputPath);

        return job.waitForCompletion(true) ? 0 : -1;
    }

}

非常感谢。

暂无答案!

目前还没有任何答案,快来回答吧!

相关问题