5 #include <torch/torch.h> 14 lin_mod = register_module(
"fc", torch::nn::Linear(dim_state, 1));
22 torch::Tensor forward(torch::Tensor& x) {
23 return lin_mod->forward(x.reshape({x.size(0), dim_state}));
// The model's only layer. Initialised with nullptr, so it is an empty module
// holder until the constructor assigns it via
// register_module("fc", torch::nn::Linear(dim_state, 1)); forward() calls
// lin_mod->forward(...), so it must not be used before construction completes.
28 torch::nn::Linear lin_mod{
nullptr};
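This class defines a new Module consisting of a single fully connected layer ("fc") that maps a dim_state-dimensional input to one output value; forward() first reshapes the input to (x.size(0), dim_state).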
Definition: LinearNet.h:8