The agent is playing a game of dots and boxes. I've been fiddling with the hyperparameters, but nothing changes, and I'm completely stuck. I'd appreciate it if someone could look over the code and check whether anything is set up wrong. The part I'm least sure about is the observation space in the MDP.
Here is the class that runs the training:
import java.io.IOException;

// RL4J imports (assuming 1.0.0-beta7; adjust the packages to your version)
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteDense;
import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.nd4j.linalg.learning.config.RmsProp;

public class testNeural {
    public static void main(String[] args) throws IOException, InterruptedException {
        GameBoard r = new GameBoard(3, 3); // sets up the 3x3 board (and the static Graph state)
        DQNPolicy<testState> t = dots();
    }

    private static DQNPolicy<testState> dots() throws IOException {
        QLearningConfiguration DOTS_QL = QLearningConfiguration.builder()
                .seed(Long.valueOf(132))               // random seed (for reproducibility)
                .maxEpochStep(500)                     // max steps per epoch
                .maxStep(1000)                         // max steps overall
                .expRepMaxSize(15000)                  // max size of experience replay
                .batchSize(Graph.getEdgeList().size()) // batch size = number of possible edges
                .targetDqnUpdateFreq(100)              // target network update frequency (hard update)
                .updateStart(10)                       // number of no-op warm-up steps
                .rewardFactor(0.1)                     // reward scaling
                .gamma(0.95)                           // discount factor
                .errorClamp(1.0)                       // TD-error clipping
                .minEpsilon(0.3f)                      // minimum epsilon
                .epsilonNbStep(10)                     // steps over which epsilon is annealed
                .doubleDQN(false)                      // double DQN disabled
                .build();

        DQNDenseNetworkConfiguration DOTS_NET = DQNDenseNetworkConfiguration.builder()
                .l2(0)
                .updater(new RmsProp(0.000025))
                .numHiddenNodes(50)
                .numLayers(10)
                .build();

        // The neural network used by the agent. There is no need to specify the number of
        // inputs/outputs; these are read from the environment at the start of training.
        testEnv env = new testEnv();
        QLearningDiscreteDense<testState> dql = new QLearningDiscreteDense<>(env, DOTS_NET, DOTS_QL);
        System.out.println(dql);
        dql.train();
        return dql.getPolicy();
    }
}
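From what I understand, RL4J infers the network's input size from the observation space declared in the MDP, and that has to match the length of the vector the state encodes to. A quick sanity check along those lines, using only my own Graph and testState classes from below (this is just a sketch of the check I have in mind, not something from RL4J):

// Hypothetical sanity check: the declared observation shape should equal
// the length of the encoded state vector actually fed to the network.
int declared = Graph.getEdgeList().size();                            // shape given to ArrayObservationSpace
int encoded  = new testState(Graph.getMatrix(), 0).toArray().length;  // matrix rows * columns
System.out.println("declared=" + declared + ", encoded=" + encoded);
// If these two numbers differ, the observation space is set up wrong.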
The MDP environment:
import java.io.IOException;

// RL4J imports (assuming 1.0.0-beta7; adjust the packages to your version)
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.NeuralNetFetchable;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.ObservationSpace;

public class testEnv implements MDP<testState, Integer, DiscreteSpace> {
    // one action per possible edge on the board
    DiscreteSpace actionSpace = new DiscreteSpace(Graph.getEdgeList().size());
    ObservationSpace<testState> observationSpace =
            new ArrayObservationSpace(new int[] {Graph.getEdgeList().size()});
    private testState state = new testState(Graph.getMatrix(), 0);
    private NeuralNetFetchable<IDQN> fetchable;
    boolean illegal = false;

    public testEnv() {}

    @Override
    public ObservationSpace<testState> getObservationSpace() {
        return observationSpace;
    }

    @Override
    public DiscreteSpace getActionSpace() {
        return actionSpace;
    }

    @Override
    public testState reset() {
        try {
            GameBoard r = new GameBoard(3, 3); // rebuild the board (and the static Graph state)
        } catch (IOException | InterruptedException e) {
            e.printStackTrace();
        }
        return new testState(Graph.getMatrix(), 0);
    }

    @Override
    public void close() { }

    @Override
    public StepReply<testState> step(Integer action) {
        int reward = 0;
        try {
            placeEdge(action);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        // change getPlayer1Score() to whichever player the network plays as
        if (!illegal) {
            System.out.println("Not Illegal");
            if (isDone()) {
                if (Graph.getPlayer1Score() > Graph.getPlayer2Score()) {
                    reward = 5;
                } else {
                    reward = -5;
                }
            } else {
                if (Graph.numOfMoves < 1) {
                    Graph.player1Turn = !Graph.player1Turn; // hand the turn to the opponent
                    Graph.setNumOfMoves(1);
                    // let the random bot play until its moves run out or the game ends
                    while (Graph.numOfMoves > 0) {
                        if (!isDone()) {
                            Graph.getRandomBot().placeRandomEdge();
                        } else {
                            Graph.numOfMoves = 0;
                            if (Graph.getPlayer1Score() > Graph.getPlayer2Score()) {
                                reward = 5;
                            } else {
                                reward = -5;
                            }
                        }
                    }
                    if (!isDone()) {
                        Graph.player1Turn = !Graph.player1Turn; // hand the turn back to the agent
                        Graph.setNumOfMoves(1);
                    }
                }
            }
        } else {
            reward = -100000; // heavy penalty for an illegal move
            illegal = false;
        }
        testState t = new testState(Graph.getMatrix(), state.step + 1);
        state = t;
        return new StepReply<>(t, reward, isDone(), null);
    }

    @Override
    public boolean isDone() {
        return gameThread.checkFinished();
    }

    @Override
    public MDP<testState, Integer, DiscreteSpace> newInstance() {
        testEnv test = new testEnv();
        test.setFetchable(fetchable);
        return test;
    }

    public void setFetchable(NeuralNetFetchable<IDQN> fetchable) {
        this.fetchable = fetchable;
    }
}
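In case it matters: if I understand rewardFactor correctly, every reward above is multiplied by 0.1 before training, so the win/loss signal and the illegal-move penalty end up on wildly different scales (this is just arithmetic, no RL4J calls):

double rewardFactor = 0.1;                   // from the QLearningConfiguration above
System.out.println( 5      * rewardFactor);  //      0.5 (win)
System.out.println(-5      * rewardFactor);  //     -0.5 (loss)
System.out.println(-100000 * rewardFactor);  // -10000.0 (illegal move)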
The state class:
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;

public class testState implements Encodable {
    int[][] matrix;
    int step;

    public testState(int[][] m, int step) {
        matrix = m;
        this.step = step;
    }

    @Override
    public double[] toArray() {
        // flatten the board matrix row by row into a single vector
        double[] array = new double[matrix.length * matrix[0].length];
        int i = 0;
        for (int a = 0; a < matrix.length; a++) {
            for (int b = 0; b < matrix[0].length; b++) {
                array[i++] = matrix[a][b];
            }
        }
        return array;
    }

    @Override
    public boolean isSkipped() {
        return false;
    }

    @Override
    public INDArray getData() {
        return null;
    }

    @Override
    public Encodable dup() {
        return null;
    }
}
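I also wonder whether returning null from getData() and dup() is part of the problem. In case it is, here is a sketch of how I could fill them in (assuming Nd4j.createFromArray is available via import org.nd4j.linalg.factory.Nd4j, and assuming this RL4J version actually calls these methods rather than toArray()):

@Override
public INDArray getData() {
    // hypothetical: wrap the flattened board in an INDArray via Nd4j
    return Nd4j.createFromArray(toArray());
}

@Override
public Encodable dup() {
    // deep-copy the matrix so the duplicate is independent of this state
    int[][] copy = new int[matrix.length][];
    for (int a = 0; a < matrix.length; a++) {
        copy[a] = matrix[a].clone();
    }
    return new testState(copy, step);
}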