将单个java感知器转换为3类感知器

fhity93d  于 2021-07-03  发布在  Java
关注(0)|答案(0)|浏览(233)

我使用本教程构建了以下感知器:https://www.youtube.com/watch?v=4aksmtjhweq
然后我修改它来接受我自己的数据集beer\u modified.csv,其中几行如下所示:
41.72123894,0.503275756,2.628181818,淡啤酒,4.015384615,16.73,10.45278947,13.44,55.33714286 45.52654867,0.189464469,1.750909091,淡啤酒,3.889230769,16.36,6.581526316,11.76,54.91142857 40.88053097,0.292846105,3.71727,粗壮啤酒,4.356923077,17.51,4.064736842,7.44,69.30285714
但是,我需要根据style属性对数据进行分类,即ale、lager或stout。我意识到单个感知器不起作用,但我发现:https://datascience.stackexchange.com/questions/2255/single-layer-perceptron-with-three-classes
这表明我可以修改代码,将数据处理为3个类。然而,我不知道如何实现这一点。我希望能朝着正确的方向努力,因为我现在真的迷失了方向。谢谢你的阅读。
我现在拥有的代码:

import java.io.FileNotFoundException;
import java.text.DecimalFormat;
import java.util.Scanner;

class Perceptron

{

    private static final String delimiter = ",";
    private static int max = 100;
    private static double learnRate = 0.1;
    private static int instances = 154;
    private static int threshold = 0;

    public static void main (String [] args)

    {

        double calorific_value [] = new double [instances];
        double nitrogen [] = new double [instances];
        double turbidity [] = new double [instances];
        String style [] = new String [instances];
        double alcohol [] = new double [instances];
        double sugars [] = new double [instances];
        double bitterness [] = new double [instances];
        double colour [] = new double [instances];
        double degree_of_fermentation [] = new double [instances];

        int loopNum = 0;

        File myFile = new File ("resources/beer_modified.csv");

        Scanner fileScanner = null;

        try

        {

            fileScanner = new Scanner (myFile);

            fileScanner.useDelimiter (delimiter);

            while (fileScanner.hasNextLine ())

            {

                String line = fileScanner.nextLine ();

                Scanner lineScanner = new Scanner (line);

                lineScanner.useDelimiter (delimiter);

                while (lineScanner.hasNext ())

                {

                    calorific_value [loopNum] = Double.parseDouble (lineScanner.next ());
                    // System.out.println (calorific_value [loopNum]);

                    nitrogen [loopNum] = Double.parseDouble (lineScanner.next ());
                    // System.out.println (nitrogen [loopNum]);

                    turbidity [loopNum] = Double.parseDouble (lineScanner.next ());
                    // System.out.println (turbidity [loopNum]);

                    style [loopNum] = lineScanner.next ();
                    // System.out.println (style [loopNum]);

                    alcohol [loopNum] = Double.parseDouble (lineScanner.next ());
                    // System.out.println (alcohol [loopNum]);

                    sugars [loopNum] = Double.parseDouble (lineScanner.next ());
                    // System.out.println (sugars [loopNum]);

                    bitterness [loopNum] = Double.parseDouble (lineScanner.next ());
                    // System.out.println (bitterness [loopNum]);

                    colour [loopNum] = Double.parseDouble (lineScanner.next ());
                    // System.out.println (colour [loopNum]);

                    degree_of_fermentation [loopNum] = Double.parseDouble (lineScanner.next ());
                    // System.out.println (degree_of_fermentation [loopNum]);

                    loopNum ++;
                    // System.out.println ("Loopnumber : " + loopNum);

                }

                lineScanner.close ();

            }

        }

        catch (FileNotFoundException beer)

        {

            beer.printStackTrace ();

        }

        finally

        {

            fileScanner.close ();

        }

        int results [] = new int [instances];

        double randomWeights [] = new double [9];
        double error1;
        double error2;

        int iteration;
        int result;

        randomWeights [0] = randomWeight (0, 1);
        randomWeights [1] = randomWeight (0, 1);
        randomWeights [2] = randomWeight (0, 1);
        randomWeights [3] = randomWeight (0, 1);
        randomWeights [4] = randomWeight (0, 1);
        randomWeights [5] = randomWeight (0, 1);
        randomWeights [6] = randomWeight (0, 1);
        randomWeights [7] = randomWeight (0, 1);
        randomWeights [8] = randomWeight (0, 1);

        iteration = 0;

        do

        {

            iteration++;
            error2 = 0;

            for (int l = 0; l < instances; l++)

            {

                result = outputMethod (threshold, randomWeights, calorific_value [l], nitrogen [l], turbidity [l], alcohol [l], sugars [l],
                        bitterness [l], colour [l], degree_of_fermentation [l]);

                error1 = results [l] - result;

                randomWeights [0] += learnRate * error1 * calorific_value [l];
                randomWeights [1] += learnRate * error1 * nitrogen [l];
                randomWeights [2] += learnRate * error1 * turbidity [l];
                randomWeights [3] += learnRate * error1 * alcohol [l];
                randomWeights [4] += learnRate * error1 * sugars [l];
                randomWeights [5] += learnRate * error1 * bitterness [l];
                randomWeights [6] += learnRate * error1 * colour [l];
                randomWeights [7] += learnRate * error1 * degree_of_fermentation [l];
                randomWeights [8] += learnRate * error1;

                error2 += (error1 * error1);

            }

            System.out.println ("Iteration " + iteration + " : RMSE = " + Math.sqrt (error2 / instances));

        } while ((error2 != 0) && (iteration <= max));

        System.out.println ("\n=======\nDecision Boundary Equation:");
        System.out.println (randomWeights [0] + " * a + " + randomWeights [1] + " * b + " + randomWeights [2] + " * c + " + randomWeights [3] + " * d + " +
                randomWeights [4] + " * e + " + randomWeights [5] + " * f + " + randomWeights [6] + " * g + " + randomWeights [7] + " * h + " + randomWeights [8] + " = 0");

    }

    public static double randomWeight (int min , int max)

    {

        DecimalFormat df = new DecimalFormat ("#.####");
        double d = min + Math.random () * (max - min);
        String s = df.format (d);
        double x = Double.parseDouble (s);
        return x;

    }

    // Activation Function
    public static int outputMethod (int threshold, double weights [], double a, double b, double c, double d, double e, double f, double g, double h)

    {

        double sum = a * weights [0] + b * weights [1] + c * weights [2] + d * weights [3] + e * weights [4] + f * weights [5] + g * weights [6] + h * weights [7] + weights [8];

        return (sum >= threshold) ? 1 : 0;

    }

}

暂无答案!

目前还没有任何答案,快来回答吧!

相关问题