/**
 * NOTE(review): the next line is extraction residue from a Doxygen source
 * listing. The original header's line numbers (4, 5, ..., 23) are fused into
 * the text, and several #include directives are collapsed onto one line. The
 * trailing "TrainingSet(" begins the constructor declaration.
 *
 * @brief Constructs a TrainingSet from its training hyper-parameters.
 *
 * Each parameter shares its name with a TrainingSet member. The values are
 * presumably stored in those members; confirm against the implementation.
 *
 * @param learningRepetitions Number of times the agent gets trained.
 * @param epochs              Number of epochs.
 * @param numbEpisodes        The number of episodes.
 * @param batchSize           The size of a batch.
 * @param learningRate        The learning rate (alpha).
 */
4 #include "../../alignment/State.h" 5 #include "../../alignment/Graph.h" 6 #include "../../alignment/Node.h" 7 #include "../../alignment/Edge.h" 8 #include "../SimpleAgent/Agent.h" 9 #include "../SimpleAgent/Episode.h" 10 #include "valueMLmodel.h" 12 #include "LinearNet.h" 13 #include "RLDataset.h" 14 #include <torch/torch.h> 23 TrainingSet(
unsigned int learningRepetitions,
unsigned int epochs,
unsigned int numbEpisodes,
unsigned int batchSize,
float learningRate);
/**
 * @brief Setter for the number of training repetitions.
 *
 * NOTE(review): the leading "31" on the next line is a fused Doxygen listing
 * line number, not code.
 *
 * @param learningRepetitions Number of times the agent gets trained. Presumably
 *        assigned to the like-named member; confirm in the implementation.
 */
31 void setLearningRepetitions(
unsigned int learningRepetitions);
/**
 * @brief Setter for the number of epochs.
 *
 * NOTE(review): the leading "33" on the next line is a fused Doxygen listing
 * line number, not code.
 *
 * @param epochs Number of epochs. Presumably assigned to the like-named member;
 *        confirm in the implementation.
 */
33 void setEpochs(
unsigned int epochs);
/**
 * @brief Setter for the number of episodes.
 *
 * NOTE(review): the leading "35" on the next line is a fused Doxygen listing
 * line number, not code.
 *
 * @param numbEpisodes The number of episodes. Presumably assigned to the
 *        like-named member; confirm in the implementation.
 */
35 void setNumbEpisodes(
unsigned int numbEpisodes);
/**
 * @brief Setter for the batch size.
 *
 * NOTE(review): the leading "37" on the next line is a fused Doxygen listing
 * line number, not code.
 *
 * @param batchSize The size of a batch. Presumably assigned to the like-named
 *        member; confirm in the implementation.
 */
37 void setBatchSize(
unsigned int batchSize);
/**
 * @brief Setter for the learning rate.
 *
 * NOTE(review): the leading "39" on the next line is a fused Doxygen listing
 * line number, not code.
 *
 * @param learningRate The learning rate (alpha). Presumably assigned to the
 *        like-named member; confirm in the implementation.
 */
39 void setLearningRate(
float learningRate);
/**
 * @brief Trains the given agent.
 *
 * This declaration does not show how the configured hyper-parameters
 * (repetitions, epochs, episodes, batch size, learning rate) are used. See the
 * implementation.
 *
 * NOTE(review): the leading "44" on the next line is a fused Doxygen listing
 * line number, not code.
 *
 * @param agent The Agent to train (an Agent selects edges according to a
 *        policy). It is passed as a raw pointer, and ownership is not stated
 *        here; presumably it is borrowed, not taken over. Confirm against the
 *        implementation.
 */
44 void train(
Agent* agent);
Definition: TrainingSet.h:20
unsigned int epochs
Number of epochs.
Definition: TrainingSet.h:26
unsigned int learningRepetitions
Number of times the agent gets trained.
Definition: TrainingSet.h:25
This Agent class selects edges according to a policy.
Definition: Agent.h:20
unsigned int numbEpisodes
The number of episodes.
Definition: TrainingSet.h:29
float learningRate
The learning rate (alpha).
Definition: TrainingSet.h:28
unsigned int batchSize
The size of a batch.
Definition: TrainingSet.h:27