Snake AI that learns with an MLP and a GA does not show intelligent behavior, even after thousands of generations

qcuzuvrc, posted 2021-07-06 in Java

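I'm a high school student working on a project for my CS research class (I'm very lucky to have the chance to take such a class)! The project is to make an AI learn the popular game Snake, using a multilayer perceptron (MLP) that learns through a genetic algorithm (GA). It was inspired by many videos I've seen on YouTube that do what I just described, which you can see here and here. I wrote the project using JavaFX and an AI library called Neuroph.
This is what my program currently looks like: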

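The name is irrelevant; I generate the names from a list of nouns and adjectives (I thought it would make things more interesting). The number in parentheses after the score is the best score in that generation, since only one snake is shown at a time.
When breeding, I set x% of the snakes to be parents (in this case, 20%). The children are then divided evenly among the pairs of parents. The "genes" in this case are the weights of the MLP. Since my library doesn't really support biases, I added a bias neuron to the input layer and connected it to all the other neurons in every layer so that its weights act as biases (as described in a thread here). For every gene, a child has a 50% chance of getting it from either parent. Each gene also has a 5% chance of mutating, in which case it is set to a random number between -1.0 and 1.0.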
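The breeding rule described above (uniform crossover plus a 5% per-gene mutation) can be sketched on plain weight arrays as follows. This is only a minimal illustration, not the code from the project: `CrossoverSketch` and `breedOne` are hypothetical names, and passing in the random generator is an assumption made so that the behavior can be reproduced with a seed.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical helper for illustration; not part of the original project. */
final class CrossoverSketch {

    /**
     * Builds one child genome. With probability mutationRate a gene is replaced
     * by a fresh random value in [-1.0, 1.0); otherwise it is copied from one of
     * the two parents, each chosen with probability 0.5.
     */
    static double[] breedOne(double[] parent1, double[] parent2,
                             double mutationRate, Random rng) {
        if (parent1.length != parent2.length) {
            throw new IllegalArgumentException("Parents must have the same number of genes");
        }
        double[] child = new double[parent1.length];
        for (int i = 0; i < child.length; i++) {
            if (rng.nextDouble() < mutationRate) {
                child[i] = rng.nextDouble() * 2.0 - 1.0;                // mutation
            } else {
                child[i] = rng.nextBoolean() ? parent1[i] : parent2[i]; // uniform crossover
            }
        }
        return child;
    }

    public static void main(String[] args) {
        double[] a = {0.1, 0.2, 0.3, 0.4};
        double[] b = {-0.1, -0.2, -0.3, -0.4};
        System.out.println(Arrays.toString(breedOne(a, b, 0.05, new Random(42))));
    }
}
```

Comparing `nextDouble()` against the rate directly gives exactly the stated probability, which is easier to reason about than comparing an integer in 0..99 against a scaled threshold.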
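Each snake's MLP has 3 layers: 18 input neurons, 14 hidden neurons, and 4 output neurons (one per direction). The inputs I feed it are the x of the head, the y of the head, the x of the food, the y of the food, and the number of steps left. It also looks in 4 directions and checks the distance to the food, the wall, and itself (set to -1.0 if nothing is seen). Adding the bias neuron I mentioned brings the total to 18.
I calculate a snake's score with my fitness function, which is (apples × 5 + seconds alive / 2).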
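Written out as code, the fitness function above looks roughly like this. The class and method names are hypothetical; the arithmetic, including rounding to two decimals, follows the expression used in the game loop of the post.

```java
/** Hypothetical helper for illustration; mirrors the score expression from the post. */
final class FitnessSketch {

    /** apples * 5 + secondsAlive / 2, rounded to two decimal places. */
    static double fitness(int applesEaten, int secondsAlive) {
        double raw = applesEaten * 5 + secondsAlive / 2.0;
        return Math.round(raw * 100.0) / 100.0;
    }

    public static void main(String[] args) {
        System.out.println(fitness(3, 7)); // prints 18.5
    }
}
```

Note that with this formula a snake that eats nothing still earns half a point per second survived, so survival time alone can separate snakes early on.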
Here is my GAMLPAgent.java, where all the MLP and GA stuff happens.

package agents;

import graphics.Snake;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Stream;
import javafx.scene.shape.Rectangle;
import org.neuroph.core.Layer;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.comp.neuron.BiasNeuron;
import org.neuroph.util.NeuralNetworkType;
import org.neuroph.util.TransferFunctionType;
import util.Direction;

/**
 *
 * @author Preston Tang
 *
 * GAMLPAgent stands for Genetic Algorithm Multi-Layer Perceptron Agent
 */
public class GAMLPAgent implements Comparable<GAMLPAgent> {

    public Snake mask;

    private final MultiLayerPerceptron mlp;

    private final int width;
    private final int height;
    private final double size;

    private final double mutationRate = 0.05;

    public GAMLPAgent(Snake mask, int width, int height, double size) {
        this.mask = mask;
        this.width = width;
        this.height = height;
        this.size = size;

        //Input: x of head, y of head, x of food, y of food, steps left
        //Input: 4 directions, check for distance to food, wall, and self  + 1 bias neuron (18 total)
        //14 hidden perceptrons (1 hidden layer)
        //Output: A direction, 4 possibilities
        mlp = new MultiLayerPerceptron(TransferFunctionType.SIGMOID, 18, 14, 4);
        //Adding connections
        List<Layer> layers = mlp.getLayers();

        for (int r = 0; r < layers.size(); r++) {
            for (int c = 0; c < layers.get(r).getNeuronsCount(); c++) {
                mlp.getInputNeurons().get(mlp.getInputsCount() - 1).addInputConnection(layers.get(r).getNeuronAt(c));
            }
        }

//        System.out.println(mlp.getInputNeurons().get(17).getInputConnections() + " " + mlp.getInputNeurons().get(17).getOutConnections());
        mlp.randomizeWeights();

//        System.out.println(Arrays.toString(mlp.getInputNeurons().get(17).getWeights()));
    }

    public void compute() {
        if (mask.isAlive()) {
            Rectangle head = mask.getSnakeParts().get(0);
            Rectangle food = mask.getFood();

            double headX = head.getX();
            double headY = head.getY();
            double foodX = mask.getFood().getX();
            double foodY = mask.getFood().getY();
            int stepsLeft = mask.getSteps();

            double foodL = -1.0, wallL, selfL = -1.0;
            double foodR = -1.0, wallR, selfR = -1.0;
            double foodU = -1.0, wallU, selfU = -1.0;
            double foodD = -1.0, wallD, selfD = -1.0;

            //The 4 directions
            //Left Direction
            if (head.getY() == food.getY() && head.getX() > food.getX()) {
                foodL = head.getX() - food.getX();
            }

            wallL = head.getX() - size;

            for (Rectangle part : mask.getSnakeParts()) {
                if (head.getY() == part.getY() && head.getX() > part.getX()) {
                    selfL = head.getX() - part.getX();
                    break;
                }
            }

            //Right Direction
            if (head.getY() == food.getY() && head.getX() < food.getX()) {
                foodR = food.getX() - head.getX();
            }

            wallR = size * width - head.getX();

            for (Rectangle part : mask.getSnakeParts()) {
                if (head.getY() == part.getY() && head.getX() < part.getX()) {
                    selfR = part.getX() - head.getX();
                    break;
                }
            }

            //Up Direction
            if (head.getX() == food.getX() && head.getY() < food.getY()) {
                foodU = food.getY() - head.getY();
            }

            wallU = size * height - head.getY();

            for (Rectangle part : mask.getSnakeParts()) {
                if (head.getX() == part.getX() && head.getY() < part.getY()) {
                    selfU = part.getY() - head.getY();
                    break;
                }
            }

            //Down Direction
            if (head.getX() == food.getX() && head.getY() > food.getY()) {
                foodD = head.getY() - food.getY();
            }

            wallD = head.getY() - size;

            for (Rectangle part : mask.getSnakeParts()) {
                if (head.getX() == part.getX() && head.getY() > part.getY()) {
                    selfD = head.getY() - part.getY();
                    break;
                }
            }

            mlp.setInput(
                    headX, headY, foodX, foodY, stepsLeft,
                    foodL, wallL, selfL,
                    foodR, wallR, selfR,
                    foodU, wallU, selfU,
                    foodD, wallD, selfD, 1);

            mlp.calculate();

            if (getIndexOfLargest(mlp.getOutput()) == 0) {
                mask.setDirection(Direction.UP);
            } else if (getIndexOfLargest(mlp.getOutput()) == 1) {
                mask.setDirection(Direction.DOWN);
            } else if (getIndexOfLargest(mlp.getOutput()) == 2) {
                mask.setDirection(Direction.LEFT);
            } else if (getIndexOfLargest(mlp.getOutput()) == 3) {
                mask.setDirection(Direction.RIGHT);
            }
        }
    }

    public double[][] breed(GAMLPAgent agent, int num) {
        //Converts Double[] to double[]
        //https://stackoverflow.com/questions/1109988/how-do-i-convert-double-to-double
        double[] parent1 = Stream.of(mlp.getWeights()).mapToDouble(Double::doubleValue).toArray();
        double[] parent2 = Stream.of(agent.getMLP().getWeights()).mapToDouble(Double::doubleValue).toArray();

        double[][] childGenes = new double[num][parent1.length];

        for (int r = 0; r < num; r++) {
            for (int c = 0; c < childGenes[r].length; c++) {
                if (ThreadLocalRandom.current().nextDouble() < mutationRate) {
                    childGenes[r][c] = ThreadLocalRandom.current().nextDouble(-1.0, 1.0);
//childGenes[r][c] += childGenes[r][c] * 0.1;
                } else {
                    childGenes[r][c] = new Random().nextDouble() < 0.5 ? parent1[c] : parent2[c];
                }
            }
        }

        return childGenes;
    }

    public MultiLayerPerceptron getMLP() {
        return mlp;
    }

    public void setMask(Snake mask) {
        this.mask = mask;
    }

    public Snake getMask() {
        return mask;
    }

    public int getIndexOfLargest(double[] array) {
        if (array == null || array.length == 0) {
            return -1; // null or empty
        }
        int largest = 0;
        for (int i = 1; i < array.length; i++) {
            if (array[i] > array[largest]) {
                largest = i;
            }
        }
        return largest; // position of the first largest found
    }

    @Override
    public int compareTo(GAMLPAgent t) {
        if (this.getMask().getScore() < t.getMask().getScore()) {
            return -1;
        } else if (t.getMask().getScore() < this.getMask().getScore()) {
            return 1;
        }
        return 0;
    }

    public void debugLocation() {
        Rectangle head = mask.getSnakeParts().get(0);
        Rectangle food = mask.getFood();
        System.out.println(head.getX() + " " + head.getY() + " " + food.getX() + " " + food.getY());
        System.out.println(mask.getName() + ": " + Arrays.toString(mlp.getOutput()));
    }

    public void debugInput() {
        String s = "";
        for (int i = 0; i < mlp.getInputNeurons().size(); i++) {
            s += mlp.getInputNeurons().get(i).getOutput() + " ";
        }
        System.out.println(s);
    }

    public double[] getOutput() {
        return mlp.getOutput();
    }
}

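Here is GeneticSnake2.java, the main class of my code. The game loop lives here, and this is where I assign the genes to the child snakes (I know it could be done more cleanly).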

package main;

import agents.GAMLPAgent;
import ui.InfoBar;
import graphics.Snake;
import graphics.SnakeGrid;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Random;
import java.util.Scanner;
import javafx.animation.AnimationTimer;
import javafx.application.Application;
import static javafx.application.Application.launch;
import javafx.scene.Scene;
import javafx.scene.control.Slider;
import javafx.scene.layout.Pane;
import javafx.scene.paint.Color;
import javafx.stage.Stage;

/**
 *
 * @author Preston Tang
 */
public class GeneticSnake2 extends Application {

    private final int width = 45;
    private final int height = 40;

    private final double displaySize = 120;
    private final double size = 12;

    private final Color pathColor = Color.rgb(120, 120, 120);
    private final Color wallColor = Color.rgb(50, 50, 50);

    private final int initSnakeLength = 2;

    private final int populationSize = 1000;

    private int generation = 0;

    private int initSteps = 100;
    private int stepsIncrease = 50;

    private double parentPercentage = 0.2;

    private final ArrayList<Color> snakeColors = new ArrayList<>(Arrays.asList(
            Color.GREEN, Color.RED, Color.YELLOW, Color.BLUE, Color.MAGENTA,
            Color.PINK, Color.ORANGERED, Color.BLACK, Color.GOLDENROD, Color.WHITE));

    private final ArrayList<Snake> snakes = new ArrayList<>();

    private final ArrayList<GAMLPAgent> agents = new ArrayList<>();

    private long initTime = System.nanoTime();

    @Override
    public void start(Stage stage) {
        Pane root = new Pane();
        Pane graphics = new Pane();
        graphics.setPrefHeight(height * size);
        graphics.setPrefWidth(width * size);
        graphics.setTranslateX(0);
        graphics.setTranslateY(displaySize);

        Pane display = new Pane();
        display.setStyle("-fx-background-color: BLACK");
        display.setPrefHeight(displaySize);
        display.setPrefWidth(width * size);
        display.setTranslateX(0);
        display.setTranslateY(0);

        root.getChildren().add(display);

        SnakeGrid sg = new SnakeGrid(pathColor, wallColor, width, height, size);

        //Parsing "adjectives.txt" and "nouns.txt" to form possible names
        ArrayList<String> adjectives = new ArrayList<>(Arrays.asList(readFile(new File(getClass().getClassLoader().getResource("resources/adjectives.txt").getFile())).split("\n")));
        ArrayList<String> nouns = new ArrayList<>(Arrays.asList(readFile(new File(getClass().getClassLoader().getResource("resources/nouns.txt").getFile())).split("\n")));

        //Initializing the population
        for (int i = 0; i < populationSize; i++) {
            //Get random String from lists and capitalize first letter
            String adj = adjectives.get(new Random().nextInt(adjectives.size()));
            adj = adj.substring(0, 1).toUpperCase() + adj.substring(1);

            String noun = nouns.get(new Random().nextInt(nouns.size()));
            noun = noun.substring(0, 1).toUpperCase() + noun.substring(1);

            Color color = snakeColors.get(new Random().nextInt(snakeColors.size()));

            //We want to see the first snake
            if (i == 0) {
                InfoBar bar = new InfoBar();
                bar.getStatusText().setText("Status: Alive");
                bar.getStatusText().setFill(Color.GREENYELLOW);
                bar.getSizeText().setText("Population Size: " + populationSize);

                Snake snake = new Snake(bar, adj + " " + noun, width, height, size, initSnakeLength, color, initSteps, stepsIncrease);
                bar.getNameText().setText("Name: " + snake.getName());

                snakes.add(snake);
                agents.add(new GAMLPAgent(snake, width, height, size));

            } else {
                Snake snake = new Snake(adj + " " + noun, width, height, size, initSnakeLength, color, initSteps, stepsIncrease);

                snakes.add(snake);
                agents.add(new GAMLPAgent(snake, width, height, size));
            }
        }

        //Focused on original snake
        display.getChildren().add(snakes.get(0).getInfoBar());

        graphics.getChildren().addAll(sg);

        graphics.getChildren().addAll(snakes.get(0));

        root.getChildren().add(graphics);

        //Add the speed controller (slider)
        Slider slider = new Slider(1, 10, 10);
        slider.setTranslateX(205);
        slider.setTranslateY(75);
        slider.setDisable(true);

        root.getChildren().add(slider);

        Scene scene = new Scene(root, width * size, height * size + displaySize);
        stage.setScene(scene);

        //Fixes the setResizable bug
        //https://stackoverflow.com/questions/20732100/javafx-why-does-stage-setresizablefalse-cause-additional-margins
        stage.setTitle("21-GeneticSnake2 Cause the First Version Got Deleted ;-; Started on 6/8/2020");
        stage.setResizable(false);
        stage.sizeToScene();
        stage.show();

        AnimationTimer timer = new AnimationTimer() {
            private long lastUpdate = 0;

            @Override
            public void handle(long now) {
                if (now - lastUpdate >= (10 - (int) slider.getValue()) * 50_000_000) {
                    lastUpdate = now;

                    int alive = populationSize;
                    for (int i = 0; i < snakes.size(); i++) {
                        Snake snake = snakes.get(i); //Current snake

                        if (i == 0) {
                            Collections.sort(agents);
                            snake.getInfoBar().getScoreText().setText("Score: " + snake.getScore() + " (" + agents.get(agents.size() - 1).getMask().getScore() + ")");
                        }

                        if (!snake.isAlive()) {
                            alive--;

                            //Update graphics for main snake
                            if (i == 0) {
                                snake.getInfoBar().getStatusText().setText("Status: Dead");
                                snake.getInfoBar().getStatusText().setFill(Color.RED);
                                graphics.getChildren().remove(snake);
                            }

                        } else {
                            //If out of steps
                            if (snake.getSteps() <= 0) {
                                snake.setAlive(false);
                            }

                            //Bounds Detection (left right up down)
                            if (snake.getSnakeParts().get(0).getX() >= width * size
                                    || snake.getSnakeParts().get(0).getX() <= 0
                                    || snake.getSnakeParts().get(0).getY() >= height * size
                                    || snake.getSnakeParts().get(0).getY() <= 0) {
                                snake.setAlive(false);
                            }

                            //Self-Collision Detection: compare this snake's head with each of its own body parts
                            for (int o = 1; o < snake.getSnakeParts().size(); o++) {
                                if (snake.getSnakeParts().get(0).getX() == snake.getSnakeParts().get(o).getX()
                                        && snake.getSnakeParts().get(0).getY() == snake.getSnakeParts().get(o).getY()) {
                                    snake.setAlive(false);
                                }
                            }

                            int rate = (int) slider.getValue();
                            int seconds = (int) ((System.nanoTime() - initTime) * rate / 1_000_000_000);

                            agents.get(i).compute();
                            snake.manageMovement();
                            snake.setSecondsAlive(seconds);

//                            agents.get(0);
//                            System.out.println(Arrays.toString(agents.get(0).getOutput()));
//                            
//                            System.out.println("\n\n\n\n\n\n\n");
                            //Expression to calculate score
                            double exp = (snake.getConsumed() * 5 + snake.getSecondsAlive() / 2.0D);
//double exp = snake.getSteps() + (Math.pow(2, snake.getConsumed()) + Math.pow(snake.getConsumed(), 2.1) * 500)
//        - (Math.pow(snake.getConsumed(), 1.2) * Math.pow(0.25 * snake.getSteps(), 1.3));

                            snake.setScore(Math.round(exp * 100.0) / 100.0);

                            //Update graphics for main snake
                            if (i == 0) {
                                snake.getInfoBar().getTimeText().setText("Time Survived: " + snake.getSecondsAlive() + "s");
                                snake.getInfoBar().getFoodText().setText("Food Consumed: " + snake.getConsumed());
                                snake.getInfoBar().getGenerationText().setText("Generation: " + generation);
                                snake.getInfoBar().getStepsText().setText("Steps Remaining: " + snake.getSteps());
                            }
                        }
                    }

                    //Reset and breed
                    if (alive == 0) {
                        //Ascending order
                        initTime = System.nanoTime();
                        generation++;
                        graphics.getChildren().clear();
                        graphics.getChildren().addAll(sg);
                        snakes.clear();

                        //x% of snakes are parents
                        int parentNum = (int) (populationSize * parentPercentage);

                        //Faster odd number check
                        if ((parentNum & 1) != 0) {
                            //If odd make even
                            parentNum += 1;
                        }

                        for (int i = 0; i < parentNum; i += 2) {
                            //Get the 2 parents, sorted by score
                            GAMLPAgent p1 = agents.get(populationSize - (i + 2));
                            GAMLPAgent p2 = agents.get(populationSize - (i + 1));

                            //Produce the next generation
                            double[][] childGenes = p1.breed(p2, ((populationSize - parentNum) / parentNum) * 2);

                            //Debugs Genes
//                            System.out.println(Arrays
//                                    .stream(childGenes)
//                                    .map(Arrays::toString)
//                                    .collect(Collectors.joining(System.lineSeparator())));
                            //Soft copy
                            ArrayList<GAMLPAgent> temp = new ArrayList<>(agents);

                            for (int o = 0; o < childGenes.length; o++) {
                                temp.get(o).getMLP().setWeights(childGenes[o]);
                            }

                            //Add the genes of every pair of parents to the children
                            for (int o = 0; o < childGenes.length; o++) {
                                //Useful debug message
//                                System.out.println("ParentNum: " + parentNum
//                                        + " ChildPerParent: " + (populationSize - parentNum) / parentNum
//                                        + " Index: " + (o + (i / 2 * childGenes.length))
//                                        + " ChildGenesNum: " + childGenes.length
//                                        + " Var O: " + o);

                                //Adds the genes of the temp to the agents
                                agents.set((o + (i / 2 * childGenes.length)), temp.get(o));
                            }
//                            System.out.println("\n\n\n\n\n\n");
                        }

                        //Debugging the snakes' genes to a file
//                        String str = "";
//                        for (int i = 0; i < agents.size(); i++) {
//                            str += "Index: " + i + "\t" + Arrays.toString(agents.get(i).getMLP().getWeights())+  "\n\n\n";
//                        }
//
//                        printToFile(str, "gen" + generation);

                        for (int i = 0; i < populationSize; i++) {
                            //Get random String from lists and capitalize first letter
                            String adj = adjectives.get(new Random().nextInt(adjectives.size()));
                            adj = adj.substring(0, 1).toUpperCase() + adj.substring(1);

                            String noun = nouns.get(new Random().nextInt(nouns.size()));
                            noun = noun.substring(0, 1).toUpperCase() + noun.substring(1);

                            Color color = snakeColors.get(new Random().nextInt(snakeColors.size()));

                            //We want to see the first snake
                            if (i == 0) {
                                InfoBar bar = new InfoBar();
                                bar.getStatusText().setText("Status: Alive");
                                bar.getStatusText().setFill(Color.GREENYELLOW);
                                bar.getSizeText().setText("Population Size: " + populationSize);

                                Snake snake = new Snake(bar, adj + " " + noun, width, height, size, initSnakeLength, color, initSteps, stepsIncrease);
                                bar.getNameText().setText("Name: " + snake.getName());
                                snakes.add(snake);
                                agents.get(i).setMask(snake);
                            } else {
                                Snake snake = new Snake(adj + " " + noun, width, height, size, initSnakeLength, color, initSteps, stepsIncrease);
                                snakes.add(snake);
                                agents.get(i).setMask(snake);
                            }
                        }

                        graphics.getChildren().add(snakes.get(0));
                        display.getChildren().clear();

                        //Focused on original snake at first
                        display.getChildren().add(snakes.get(0).getInfoBar());
                    }
                }
            }
        };
        //Starts the infinite loop
        timer.start();
    }

    public String readFile(File f) {
        String content = "";
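        //try-with-resources closes the Scanner (and the underlying file) when done
        try (Scanner scanner = new Scanner(f)) {
            content = scanner.useDelimiter("\\Z").next();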
        } catch (FileNotFoundException ex) {
            System.err.println("Error: Unable to read " + f.getName());
        }
        return content;
    }

    public void printToFile(String str, String name) {
        FileWriter fileWriter;
        try {
            fileWriter = new FileWriter(name + ".txt");
            try (BufferedWriter bufferedWriter = new BufferedWriter(fileWriter)) {
                bufferedWriter.write(str);
            }

        } catch (IOException ex) {
            ex.printStackTrace();
        }
    }

    public static void main(String[] args) {
        launch(args);
    }
}

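The main problem is that even after a few thousand generations, the snakes still just run into the walls and die. In the videos I linked above, the snakes were avoiding walls and getting food by around generation 5. I suspect the problem is in the main class, where I assign the genes to the newborn snakes.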
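One thing worth checking in that part of the main class: `new ArrayList<>(agents)` copies references, so `temp.get(o)` and `agents.get(o)` are the same agent objects, and `agents.set(...)` can end up placing one agent in several slots. Whether that is what breaks learning here is an assumption, but a way to rule it out is to build the next generation from snapshots of the parents' weights and give every child its own fresh array. The sketch below works on plain `double[]` genomes; `NextGenerationSketch` and `nextGeneration` are hypothetical names, and the pairing of adjacent parents is only one possible choice.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/** Hypothetical sketch; operates on weight arrays rather than on the agents from the post. */
final class NextGenerationSketch {

    /**
     * ranked is sorted by ascending score (best last). Returns a new list of the
     * same size: freshly allocated children first, then copies of the parents.
     */
    static List<double[]> nextGeneration(List<double[]> ranked, int parentCount,
                                         double mutationRate, Random rng) {
        int n = ranked.size();
        if (parentCount < 2 || parentCount % 2 != 0 || parentCount > n) {
            throw new IllegalArgumentException("parentCount must be even, >= 2 and <= population size");
        }
        // Snapshot the parents' weights before anything is overwritten.
        List<double[]> parents = new ArrayList<>();
        for (int i = n - parentCount; i < n; i++) {
            parents.add(ranked.get(i).clone());
        }
        List<double[]> next = new ArrayList<>(n);
        int pairCount = parentCount / 2;
        for (int k = 0; k < n - parentCount; k++) {
            // Cycle through the parent pairs (0,1), (2,3), ... so children are spread evenly.
            int first = (k % pairCount) * 2;
            double[] p1 = parents.get(first);
            double[] p2 = parents.get(first + 1);
            double[] child = new double[p1.length]; // a distinct array for every child
            for (int g = 0; g < child.length; g++) {
                child[g] = rng.nextDouble() < mutationRate
                        ? rng.nextDouble() * 2.0 - 1.0
                        : (rng.nextBoolean() ? p1[g] : p2[g]);
            }
            next.add(child);
        }
        next.addAll(parents); // parents survive unchanged
        return next;
    }

    public static void main(String[] args) {
        List<double[]> population = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            population.add(new double[] {i, i, i});
        }
        List<double[]> next = nextGeneration(population, 4, 0.05, new Random(1));
        System.out.println(next.size()); // prints 10
    }
}
```

Each array in the result would then be applied to a different agent, for example with `getMLP().setWeights(...)` as the post already does, so that no two slots in the population share one network.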
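Honestly, I've been stuck on this for a few weeks. Earlier, one of the problems I suspected was a lack of inputs, since I had far fewer back then. I no longer think that's the case. If needed, I can try looking in the 4 diagonal directions as well, which would add another 12 inputs to the snake's MLP. I've also asked for help on the AI Discord, but haven't really found a solution.
If needed, I'm willing to send my entire code so you can run the simulation yourself.
If you've read this far, thank you for taking the time to help me! I really appreciate it.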

jckbn6z7


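I'm not surprised your snakes are dying.
Let's take a step back. What exactly is AI? Well, it's a search problem. We search through some parameter space for the set of parameters that solves Snake given the current state of the game. You can imagine a parameter space with a global minimum: the best snake, the one that makes the fewest mistakes.
All learning algorithms start at some point in this parameter space and try to find that global minimum over time. First, consider MLPs. An MLP normally learns by trying a set of weights, computing a loss function, and then taking a step in the direction that reduces the loss further (gradient descent). It's fairly clear that an MLP will find a minimum, but whether it finds a good enough one is another question, and there are many training techniques to improve that chance.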
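To make the gradient descent idea concrete, here is a minimal sketch on a toy one-dimensional loss, (w - 3)^2, instead of a real network. The loss, the learning rate, and the names are all assumptions chosen for illustration; training an actual MLP applies the same update to every weight using gradients from backpropagation.

```java
/** Hypothetical toy example: gradient descent on the loss (w - 3)^2. */
final class GradientDescentSketch {

    static double minimize(double start, double learningRate, int steps) {
        double w = start;
        for (int i = 0; i < steps; i++) {
            double gradient = 2.0 * (w - 3.0); // derivative of (w - 3)^2 with respect to w
            w -= learningRate * gradient;      // step against the gradient
        }
        return w;
    }

    public static void main(String[] args) {
        // Starting from w = 0, the estimate moves toward the minimum at w = 3.
        System.out.println(minimize(0.0, 0.1, 200));
    }
}
```

The key contrast with the GA in the question is that every step here uses the slope of the loss to decide where to go next, whereas crossover and mutation only propose new points and rely on selection to keep the good ones.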
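Genetic algorithms, on the other hand, have very poor convergence properties. First of all, let's stop calling them genetic algorithms. Let's call them buffet algorithms instead. A buffet algorithm takes two sets of parameters from two parents, mixes them together, and produces a new buffet. What makes you think this buffet is better than either of the two? What are you actually doing here? How do you know it is getting closer to anything better? If you attach a loss function, how do you know the space you are in can actually be minimized?
The point I'm trying to make is that genetic algorithms are unprincipled, unlike nature. Nature doesn't just throw codons into a blender to make new DNA strands, but that is exactly what genetic algorithms do. There are techniques that add some hill climbing, but genetic algorithms still have a lot of problems.
The point is, don't get carried away by the name. Genetic algorithms are simply buffet algorithms. My view is that your approach doesn't work because GAs are not guaranteed to converge even after infinitely many iterations, and MLPs are not guaranteed to converge to a good global minimum.
What to do? Well, a better approach is to use a learning paradigm that fits your problem: reinforcement learning. Georgia Tech has a very good course on the subject.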
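For a flavor of what reinforcement learning looks like in code, here is a minimal sketch of the tabular Q-learning update, one basic RL method. The answer does not name a specific algorithm, so choosing Q-learning is an assumption, as are the class name, the table layout (`q[state][action]`), and the way states and actions would be encoded for Snake.

```java
/** Hypothetical sketch of one tabular Q-learning update; not tied to the project in the question. */
final class QLearningSketch {

    /**
     * Moves q[state][action] toward reward + gamma * max over a' of q[nextState][a'].
     * alpha is the learning rate and gamma the discount factor.
     */
    static void update(double[][] q, int state, int action, double reward,
                       int nextState, double alpha, double gamma) {
        double bestNext = q[nextState][0];
        for (double value : q[nextState]) {
            bestNext = Math.max(bestNext, value);
        }
        q[state][action] += alpha * (reward + gamma * bestNext - q[state][action]);
    }

    public static void main(String[] args) {
        double[][] q = new double[2][4]; // 2 states, 4 actions (one per direction)
        update(q, 0, 1, 1.0, 1, 0.5, 0.9);
        System.out.println(q[0][1]); // prints 0.5
    }
}
```

Unlike the fitness score in the question, which is only known once a snake has died, the reward here is given per step (for example, positive for eating and negative for hitting a wall), so each individual move receives feedback.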
