1  /* libtest_sin.cpp
  2   *     - Shows the libsnnl usage by learning the sine function
  3   * --------------------------------------------------------------------------
  4   * how to use:
  5   * to generate a default network
  6   *    $> ./libtest_sin
  7   * to generate a network with non defaults
  8   *    $> ./libtest_sin <MomentumTerm> <OptimalTolerance> <WeightDecay>
  9   * to load a network from file
 10   *    $> ./libtest_sin f
 11   * process and view the output
 12   *    $> perl process_dumps.pl
 13   *    $> gnuplot sin_plot.dem
 14   * or all in one
 15   *    $> ./libtest_sin 0.1 0.0 && ./process_dumps.pl && gnuplot sin_plot.dem
 16   * __________________________________________________________________________*/
 17
#include <math.h>

#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include "snnl_network.h"
#include "snnl_teacher.h"
 24
 25
 26  static void testNetwork (Network *net, Teacher *teacher);
 27
 28
 29  int
 30  main (int argc, char *argv[])
 31  {
 32      Network * net;
 33      Teacher * teacher;
 34
 35      vector <unsigned int> gen;
 36      gen.push_back (1);
 37      gen.push_back (8);
 38      gen.push_back (6);
 39      gen.push_back (1);
 40
 41      if (argc==2 && argv[1][0]=='f') {
 42          /* load a network from file */
 43          net = new Network ();
 44          teacher = new Teacher (net);
 45          net->load ("dump/sin_network.net");
 46      } else {
 47          /* or generate a new network */
 48          net = new Network (gen, "snnl - 'sinus emulator'");
 49          teacher = new Teacher (net);
 50          teacher->randomizeParameters ();
 51          teacher->setLearningParameter (0.2);
 52
 53          if (argc>=2) {
 54              teacher->setMomentumTermParameter (atof(argv[1]));
 55          } else {
 56              teacher->setMomentumTermParameter (0.0);
 57          }
 58
 59          if (argc>=3) {
 60              teacher->setOptimalTolerance (atof(argv[2]));
 61          } else {
 62              teacher->setOptimalTolerance (0.0);
 63          }
 64
 65          if (argc>=4) {
 66              teacher->setWeightDecayParameter (atof(argv[3]));
 67          } else {
 68              teacher->setWeightDecayParameter (0.00004);
 69          }
 70      }
 71
 72      cout << teacher->getLearningParameter ()     << " "
 73           << teacher->getMomentumTermParameter () << " "
 74           << teacher->getOptimalTolerance ()      << " "
 75           << teacher->getWeightDecayParameter ()  << endl;
 76
 77      testNetwork (net, teacher);
 78
 79      delete net;
 80  }
 81
 82
 83  static void
 84  testNetwork (Network *net, Teacher *teacher)
 85  {
 86  #define TRAIN_LOW    (-3.14)
 87  #define TRAIN_HIGH   (3.14)
 88  #define TRAIN_SIZE   100
 89  #define TRAIN_CYLCES 10
 90  #define TRAIN_EPOCH  10
 91  #define TRAIN_LENGTH (TRAIN_EPOCHS * TRAIN_SIZE)
 92
 93      vector <TrainSet*> trainset (TRAIN_SIZE);
 94      vector <double> output;
 95
 96      /* generate train data */
 97      unsigned int n;
 98      fstream trainout ("dump/sin_train_data.dat", ios::out | ios::trunc);
 99      for (n = 0; n < trainset.size (); ++n) {
100          trainset[n] = new TrainSet (1,1);
101          trainset[n]->input->at(0)  = n * ((TRAIN_HIGH - TRAIN_LOW)/TRAIN_SIZE) + TRAIN_LOW;
102          trainset[n]->optout->at(0) = sin (trainset[n]->input->at(0));
103          trainout << trainset[n]->input->at(0) << "\t" << trainset[n]->optout->at(0) << endl;
104      }
105      trainout.close ();
106
107      teacher->setTrainSet (trainset);
108      teacher->saveTrainSet ("trainsets/sin.train");
109      //trainset = teacher->getTrainSet ();
110
111      if (teacher->setErrorFile ("dump/libtest_sin_error.dat"))
112          cout << "   errorfile activated (dump/libtest_sin_error.dat)\n";
113      else
114          cerr << "can't open errorfile, deactivated\n";
115
116      /* train the sin()-function and show output in each cycle*/
117      unsigned int cycle, ne, ni;
118      for (cycle=1; cycle<=TRAIN_CYLCES; cycle++) {
119
120          teacher->shuffleTrainSet ();
121
122          for (ne=0; ne<TRAIN_EPOCH; ne++) {
123              teacher->onlineBackProp ();
124          }
125
126          /* let the network show what is has learned */
127          string filename ("dump/sin_network_output_");
128          char buf[10]; sprintf(buf,"%02u",cycle);
129          filename += buf; filename += ".dat";
130          fstream netout (filename.c_str(), ios::out | ios::trunc);
131
132          double error_sum = 0.0;
133          for (ni = 0; ni < trainset.size (); ++ni) {
134              net->setInput (*trainset[ni]->input);
135              net->setOptimalOutput (*trainset[ni]->optout);
136              net->propagate ();
137              output = net->getOutput ();
138              error_sum += net->errorTerm ();
139
140              netout << trainset[ni]->input->at(0) << "\t" << output[0] << endl;
141          }
142          netout.close ();
143          cout << "overall errorterm of testset: " << error_sum << endl;
144      }
145
146      /* dump the network graph */
147      net->dumpGraph ("dump/sin_network_graph.dot");
148
149      /* save network to file */
150      if (net->save ("dump/sin_network.net"))
151          cout << "network saved to dump/sin_network.net" << endl;
152
153      /* clean up train data */
154      for (n = 0; n < trainset.size (); ++n) {
155          delete trainset[n];
156      }
157  }
158
159