Java - Q-learning in a custom MDP environment always chooses action 0, despite a heavy negative reward

ffscu2ro posted on 2021-07-07 in Java

It is playing the game Dots and Boxes. I have been messing around with the hyperparameters, but nothing changes. I am completely stuck and need help; if someone could look over the code and see whether anything is set up wrong, I would appreciate it. The part I am least sure about is the observation space in the MDP.
Here is the solver class:

public class testNeural {
    public static void main(String args[]) throws IOException, InterruptedException {
        GameBoard r = new GameBoard(3,3);
        DQNPolicy<testState> t = dots();
    }

    private static DQNPolicy<testState> dots() throws IOException {

        QLearningConfiguration DOTS_QL = QLearningConfiguration.builder()
                .seed(Long.valueOf(132))                //Random seed (for reproducibility)
                .maxEpochStep(500)        // Max steps per epoch
                .maxStep(1000)           // Max total training steps
                .expRepMaxSize(15000)    // Max size of experience replay
                .batchSize(Graph.getEdgeList().size())            // size of batches
                .targetDqnUpdateFreq(100) // target update (hard)
                .updateStart(10)          // num step noop warmup
                .rewardFactor(0.1)       // reward scaling
                .gamma(0.95)              // gamma
                .errorClamp(1.0)          // TD-error clipping
                .minEpsilon(0.3f)         // min epsilon
                .epsilonNbStep(10)      // num step for eps greedy anneal
                .doubleDQN(false)          // double DQN
                .build();
        DQNDenseNetworkConfiguration DOTS_NET =
                DQNDenseNetworkConfiguration.builder()
                        .l2(0)
                        .updater(new RmsProp(0.000025))
                        .numHiddenNodes(50)
                        .numLayers(10)
                        .build();

        // The neural network used by the agent. Note that there is no need to specify the number of inputs/outputs.
        // These are read from the environment at the start of training.

        testEnv env = new testEnv();
        QLearningDiscreteDense<testState> dql = new QLearningDiscreteDense<testState>(env, DOTS_NET, DOTS_QL);
        System.out.println(dql.toString());
        dql.train();
        return dql.getPolicy();
    }
}
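One thing worth noting about the configuration above (an observation, not a confirmed fix): with epsilonNbStep(10) the ε-greedy exploration anneals down to minEpsilon after only 10 of the 1000 training steps, so the agent explores very little. The sketch below uses a more gradual annealing schedule; it assumes the usual RL4J meaning of epsilonNbStep (the number of steps over which ε decays from 1.0 to minEpsilon), and every changed value is an illustrative assumption rather than a tuned setting.

        // Illustrative sketch only - the changed values are assumptions, not a verified fix.
        QLearningConfiguration DOTS_QL_SKETCH = QLearningConfiguration.builder()
                .seed(Long.valueOf(132))
                .maxEpochStep(500)
                .maxStep(10000)          // illustrative: enough steps for many complete games
                .expRepMaxSize(15000)
                .batchSize(32)           // illustrative: a fixed batch size
                .targetDqnUpdateFreq(100)
                .updateStart(10)
                .rewardFactor(0.1)
                .gamma(0.95)
                .errorClamp(1.0)
                .minEpsilon(0.05)        // illustrative: finish closer to greedy
                .epsilonNbStep(5000)     // illustrative: anneal ε over a large fraction of training
                .doubleDQN(false)
                .build();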

The MDP environment:

public class testEnv implements MDP<testState, Integer, DiscreteSpace> {
    DiscreteSpace actionSpace = new DiscreteSpace(Graph.getEdgeList().size());

    // the action space size is the number of possible edges
    ObservationSpace<testState> observationSpace = new ArrayObservationSpace(new int[] {Graph.getEdgeList().size()});
    private testState state = new testState(Graph.getMatrix(),0);
    private NeuralNetFetchable<IDQN> fetchable;
    boolean illegal=false;
    public testEnv(){}
    @Override
    public ObservationSpace<testState> getObservationSpace() {
        return observationSpace;
    }
    @Override
    public DiscreteSpace getActionSpace() {
        return actionSpace;
    }
    @Override
    public testState reset() {
       // System.out.println("RESET");
        try {
            GameBoard r = new GameBoard(3,3);
        } catch (IOException e) {
            e.printStackTrace();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        return new testState(Graph.getMatrix(),0);
    }
    @Override
    public void close() { }
    @Override
    public StepReply<testState> step(Integer action) {
      //  System.out.println("Action: "+action);
     //   System.out.println(Arrays.deepToString(Graph.getMatrix()));
        int reward=0;
        try {
            placeEdge(action);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        // change getPlayer1Score() to whichever player the neural network is playing as
     //   System.out.println("step: "+state.step);
        if(!illegal) {
            System.out.println("Not Illegal");
            if (isDone()) {
                if (Graph.getPlayer1Score() > Graph.getPlayer2Score()) {
                    reward = 5;
                } else {
                    reward = -5;
                }
            }else {
                if (Graph.numOfMoves < 1) {
                    if (Graph.player1Turn) {
                        Graph.player1Turn = false;
                    } else {
                        Graph.player1Turn = true;
                    }
                    Graph.setNumOfMoves(1);
                    while (Graph.numOfMoves > 0) {
                            //       System.out.println(Arrays.deepToString(Graph.getMatrix()));
                        if (!isDone()) {
                            Graph.getRandomBot().placeRandomEdge();
                        } else {
                            Graph.numOfMoves = 0;
                            if (Graph.getPlayer1Score() > Graph.getPlayer2Score()) {
                                reward = 5;
                            } else {
                                reward = -5;
                            }
                        }
                    }
                    if (!isDone()) {
                        if (Graph.player1Turn) {
                            Graph.player1Turn = false;
                        } else {
                            Graph.player1Turn = true;
                        }
                        Graph.setNumOfMoves(1);
                    }
                }
            }
        }else{
            reward=-100000;
            illegal=false;
        }
        testState t = new testState(Graph.getMatrix(), state.step + 1);
        state=t;
        return new StepReply<>(t, reward, isDone(), null);
    }

    @Override
    public boolean isDone() {
        return gameThread.checkFinished();
    }

    @Override
    public MDP<testState, Integer, DiscreteSpace> newInstance() {
        testEnv test = new testEnv();
        test.setFetchable(fetchable);
        return test;
    }
    public void setFetchable(NeuralNetFetchable<IDQN> fetchable) {
        this.fetchable = fetchable;
    }
}
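On the observation space the question asks about: ArrayObservationSpace is constructed here with Graph.getEdgeList().size() entries, while testState.toArray() (in the state class below) returns one value per cell of Graph.getMatrix(). If those two lengths differ, the declared observation shape will not match the encoded state. A minimal sketch of deriving the shape from the encoding itself, assuming Graph.getMatrix() is rectangular:

    // Sketch: derive the observation shape from the same matrix that testState encodes.
    int[][] m = Graph.getMatrix();
    int encodedLength = m.length * m[0].length;   // length of testState.toArray()
    ObservationSpace<testState> observationSpace =
            new ArrayObservationSpace<>(new int[] { encodedLength });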

The state class:

public class testState implements Encodable {
    int[][] matrix;
    int step;
    public testState(int[][] m,int step){
        matrix=m;
        this.step=step;
    }
    @Override
    public double[] toArray() {
        double[] array = new double[matrix.length*matrix[0].length];
        int i=0;
        for(int a=0;a< matrix.length;a++){
            for(int b=0;b<matrix[0].length;b++){
                array[i]= matrix[a][b];
                i++;
            }
        }
        return array;
    }

    @Override
    public boolean isSkipped() {
        return false;
    }

    @Override
    public INDArray getData() {
        return null;
    }

    @Override
    public Encodable dup() {
        return null;
    }
}
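getData() and dup() both return null here. Whether RL4J actually calls them depends on the version in use, but a sketch of backing them with the same encoding that toArray() already produces would look roughly like this (my assumption, not something from the post; it needs org.nd4j.linalg.factory.Nd4j on the classpath):

    @Override
    public INDArray getData() {
        // Wrap the flattened matrix in an INDArray instead of returning null.
        return Nd4j.create(toArray());
    }

    @Override
    public Encodable dup() {
        // Deep-copy the matrix so the duplicate is independent of the live game board.
        int[][] copy = new int[matrix.length][];
        for (int a = 0; a < matrix.length; a++) {
            copy[a] = matrix[a].clone();
        }
        return new testState(copy, step);
    }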

No answers yet.
