Maximum Acyclic Subgraph - Multiple Sequence Alignment
MaximumAcyclicSubgraph
RLDataset.h
1 #include "../SimpleAgent/Agent.h"
2 #include <torch/torch.h>
3 #include <iostream>
4 #include "../SimpleAgent/Episode.h"
5 using std::vector;
6 
7 #ifndef RLDATASET_H
8 #define RLDATASET_H
9 
10 class RLDataset : public torch::data::Dataset<RLDataset>
11 {
12  private:
13  torch::Tensor states, scores, actions;
14 
15  public:
16  unsigned int numbEpisodes;
17  vector<Episode> episodes; // vector of episodes
23  RLDataset(vector<Episode>& episodes);
24  RLDataset(){};
25 
26  void set(vector<Episode>& episodes);
27 
28  torch::data::Example<> get(size_t index) override {
29  return {states[index], scores[index]};
30  }
31 
32  torch::optional<size_t> size() const override {
33  return states.size(0);
34  }
35 };
36 #endif