Maximum Acyclic Subgraph - Multiple Sequence Alignment
MaximumAcyclicSubgraph
LinearNet.h
1 #ifndef LINEARNET_H
2 #define LINEARNET_H
3 
4 #include <iostream>
5 #include <torch/torch.h>
6 
8 class LinearNet : public torch::nn::Module {
9  public:
10  LinearNet(){};
11 
12  LinearNet(unsigned ds) : dim_state(ds) {
13  // Construct and register a Linear submodule
14  lin_mod = register_module("fc", torch::nn::Linear(dim_state, 1));
15  // for (auto& p : this->parameters()) {
16  //torch::nn::init::constant_(p, 0);
17  //p.uniform_(0,0);
18  //}
19  }
20 
21  // Implement the Net's algorithm.
22  torch::Tensor forward(torch::Tensor& x) {
23  return lin_mod->forward(x.reshape({x.size(0), dim_state}));
24  }
25 
26  // Use one of many "standard library" modules.
27  unsigned dim_state;
28  torch::nn::Linear lin_mod{nullptr};
29 };
30 #endif
This class defines a new Module.
Definition: LinearNet.h:8