1 #include "../SimpleAgent/Agent.h" 2 #include <torch/torch.h> 4 #include "../SimpleAgent/Episode.h" 10 class RLDataset :
public torch::data::Dataset<RLDataset>
13 torch::Tensor states, scores, actions;
16 unsigned int numbEpisodes;
17 vector<Episode> episodes;
23 RLDataset(vector<Episode>& episodes);
26 void set(vector<Episode>& episodes);
28 torch::data::Example<>
get(
size_t index)
override {
29 return {states[index], scores[index]};
32 torch::optional<size_t> size()
const override {
33 return states.size(0);