// Banner text for the command line: main() hands this string to
// CmdLine::info() before declaring the arguments and options.
const char *help = "\
TorchMLP\n\
\n\
This program will train a MLP with tanh outputs for\n\
classification and linear outputs for regression\n";

#include "ConnectedMachine.h"
#include "Linear.h"
#include "FileDataSet.h"
#include "MseCriterion.h"
#include "Tanh.h"
#include "MseMeasurer.h"
#include "ClassMeasurer.h"
#include "TwoClassFormat.h"
#include "OneHotClassFormat.h"
#include "StochasticGradient.h"
#include "GMTrainer.h"
#include "CmdLine.h"

int main(int argc, char **argv)
{
  char *model_file, *test_model_file;
  char *valid_file;
  char *file;

  int n_inputs;
  int n_targets;
  int n_hu;

  int max_load;
  real accuracy;
  real learning_rate;
  real decay;
  int max_iter;
  bool regression;
  int k_fold;
  int the_seed;

  //=================== The command-line ==========================

  // Construct the command line
  CmdLine cmd;

  // Put the help line at the beginning
  cmd.info(help);

  // Ask for arguments
  cmd.addText("\nArguments:");
  cmd.addSCmdArg("file", &file, "the train or test file");
  cmd.addICmdArg("n_inputs", &n_inputs, "input dimension of the data");
  cmd.addICmdArg("n_targets", &n_targets, "output dimension of the data");

  // Propose some options
  cmd.addText("\nModel Options:");
  cmd.addICmdOption("-nhu", &n_hu, 25, "number of hidden units");
  cmd.addBCmdOption("-rm", &regression, false, "regression mode");

  cmd.addText("\nLearning Options:");
  cmd.addICmdOption("-iter", &max_iter, 25, "max number of iterations");
  cmd.addRCmdOption("-lr", &learning_rate, 0.01, "learning rate");
  cmd.addRCmdOption("-e", &accuracy, 0.00001, "end accuracy");
  cmd.addRCmdOption("-lrd", &decay, 0, "learning rate decay");

  cmd.addText("\nMisc Options:");
  cmd.addICmdOption("-seed", &the_seed, -1, "the random seed");
  cmd.addICmdOption("-Kfold", &k_fold, -1, "number of subsets for K-fold cross-validation");
  cmd.addICmdOption("-load", &max_load, -1, "max number of examples to load");
  cmd.addSCmdOption("-valid", &valid_file, "", "validation file, if you want it");
  cmd.addSCmdOption("-sm", &model_file, "", "file to save the model");
  cmd.addSCmdOption("-test", &test_model_file, "", "model file to test");

  // Read the command line
  cmd.read(argc, argv);

  // If the user didn't give any random seed,
  // generate a random random seed...
  if(the_seed == -1)
    seed();
  else
    manual_seed((long)the_seed);

  //=================== Create the MLP... =========================
  ConnectedMachine MLP;

  // Create the layers of the MLP
  Linear hidden_linear(n_inputs, n_hu);
  Tanh hidden_nlinear(n_hu);
  Linear output_linear(n_hu, n_targets);
  Tanh output_nlinear(n_targets);

  // Initialize the layers
  hidden_linear.init();
  hidden_nlinear.init();
  output_linear.init();
  output_nlinear.init();

  // Add the layers (Full Connected Layers) to the MLP
  MLP.addFCL(&hidden_linear);
  MLP.addFCL(&hidden_nlinear);
  MLP.addFCL(&output_linear);

  // If regression, don't add the tanh output layer
  if(!regression)
    MLP.addFCL(&output_nlinear);

  // Initialize the MLP
  MLP.init();
 

  //=================== DataSets & Measurers... ===================

  // Create the training dataset (normalize inputs)
  FileDataSet data(file, n_inputs, n_targets, false, max_load);
  data.setBOption("normalize inputs", true);
  data.init();

  // The list of measurers...
  List *measurers = NULL;

  // The class format
  ClassFormat *class_format = NULL;
  if(!regression)
  {
    if(n_targets == 1)
      class_format = new TwoClassFormat(&data);
    else
      class_format = new OneHotClassFormat(&data);
  }

  // The validation set...
  FileDataSet *valid_data = NULL;
  MseMeasurer *valid_mse_meas = NULL;
  ClassMeasurer *valid_class_meas = NULL;

  // Create a validation set, if any
  if(strcmp(valid_file, ""))
  {
    // Load the validation set and normalize it with the
    // values in the train dataset
    valid_data = new FileDataSet(valid_file, n_inputs, n_targets);
    valid_data->init();
    valid_data->normalizeUsingDataSet(&data);

    // Create a MSE measurer and an error class measurer
    // on the validation dataset (if we are not in regression)
    valid_mse_meas = new MseMeasurer(MLP.outputs, valid_data, "the_valid_mse");
    valid_mse_meas->init();
    addToList(&measurers, 1, valid_mse_meas);

    if(!regression)
    {
      valid_class_meas = new ClassMeasurer(MLP.outputs, valid_data, class_format, "the_valid_class_err");
      valid_class_meas->init();
      addToList(&measurers, 1, valid_class_meas);
    }
  }

  // Measurers on the training dataset
  MseMeasurer *mse_meas = new MseMeasurer(MLP.outputs, &data, "the_mse");
  mse_meas->init();
  addToList(&measurers, 1, mse_meas);

  ClassMeasurer *class_meas = NULL;
  if(!regression)
  {
    class_meas = new ClassMeasurer(MLP.outputs, &data, class_format, "the_class_err");
    class_meas->init();
    addToList(&measurers, 1, class_meas);
  }

  //=================== The Trainer ===============================
 
  // The criterion for the GMTrainer (MSE criterion)
  MseCriterion mse(n_targets);
  mse.init();

  // The optimizer for the GMTrainer
  StochasticGradient opt;
  opt.setIOption("max iter", max_iter);
  opt.setROption("end accuracy", accuracy);
  opt.setROption("learning rate", learning_rate);
  opt.setROption("learning rate decay", decay);

  // The Gradient Machine Trainer
  GMTrainer trainer(&MLP, &data, &mse, &opt);

  //=================== Let's go... ===============================

  // Print the number of parameter of the MLP (just for fun)
  message("Number of parameters: %d", MLP.n_params);

  // If the user provides a previously trained model,
  // test it...
  if( strcmp(test_model_file, "") )
  {
    trainer.load(test_model_file);
    trainer.test(measurers);
  }

  // ...else...
  else
  {
    // If the user provides a number for the K-fold validation,
    // do a K-fold validation
    if(k_fold > 0)
      trainer.crossValidate(k_fold, NULL, measurers);

    // Else, train the model
    else
      trainer.train(measurers);

    // Save the model if the user provides a name for that
    if( strcmp(model_file, "") )
      trainer.save(model_file);
  }

  //=================== Quit... ===================================
  if(strcmp(valid_file, ""))
  {
    delete valid_data;
    delete valid_mse_meas;
    if(!regression)
      delete valid_class_meas;
  }

  delete mse_meas;
  if(!regression)
  {
    delete class_meas;
    delete class_format;
  }

  freeList(&measurers);

  return(0);
}