/* libtest_encoder-decoder.cpp
 *     - This network learns to encode 8 bits into 3 bits and to decode
 *       those 3 bits back into 8 bits.
 * ---------------------------------------------------------------------------
 * How to use (run from the directory that contains trainsets/ and dump/,
 * since both paths are relative):
 *    $> ./libtest_encoder_decoder
 * To see the error-term curve:
 *    $> gnuplot
 *       gnuplot> plot "dump/libtest_encoder_error.dat" with lines
 * ___________________________________________________________________________*/
11
12  #include <iostream>
13  #include "snnl_network.h"
14  #include "snnl_teacher.h"
15
16
17  int
18  main (int argc, char *argv[])
19  {
20      vector <TrainSet*> trainset;
21      vector <double> output;
22
23      /* the network description and the network */
24      vector <unsigned int> net_descr;
25      Network * net;
26      Teacher * teacher;
27
28      net_descr.push_back (8);    // Input  - 8 Bits
29      net_descr.push_back (3);    // Hidden - 3 Bits
30      net_descr.push_back (8);    // Output - 8 Bits
31
32      /* build the network and initialize it  */
33      net = new Network (net_descr);
34      net->setActivationFunction (Output, &ActivationFunctions::fact_binary);
35
36      /* build and initialize the teacher */
37      teacher = new Teacher (net);
38      teacher->setLearningParameter (0.1);
39      //teacher->setMomentumTermParameter (0.5);
40      teacher->setWeightDecayParameter (0.001);
41      //teacher->setOptimalTolerance (0.05);
42
43      /* randomize weights and theta-values */
44      teacher->randomizeParameters ();
45
46      /* load trainset from file */
47      if (! teacher->loadTrainSet ("trainsets/encoder_8to8.train")) {
48          cerr << "Can't load trainset\n";
49          return EXIT_FAILURE;
50      }
51
52      /* set errorfile and activate per epoch error output */
53      if (teacher->setErrorFile ("dump/libtest_encoder_error.dat"))
54          cout << "   errorfile activated (dump/libtest_encoder_error.dat)\n";
55      else
56          cerr << "can't open errorfile, deactivated\n";
57
58      /* train the trainset data */
59      cout << "train network ..." << endl;
60      unsigned int epoch;
61      for (epoch=1; epoch<2000; epoch++) {
62          teacher->shuffleTrainSet ();
63          if (teacher->onlineBackProp () == 0.0)
64          //if (teacher->batchBackProp () == 0.0)
65              break;
66      }
67      cout << "need " << epoch << " epoch's" << endl;
68
69      /* use the trained network */
70      trainset = teacher->getTrainSet ();
71      cout << "network output:" << endl;
72      cout << "--------+-------+--------" << endl;
73      for (unsigned int n=0; n < trainset.size (); n++) {
74          for (unsigned int bit=0; bit<8; bit++)
75              cout << ((trainset[n]->input->at(bit)>0.0) ? 1 : 0);
76          cout << "|->-3->-|";
77
78          net->setInput (*trainset[n]->input);
79          net->propagate ();
80          output=net->getOutput ();
81
82          for (unsigned int bit=0; bit<8; bit++)
83              cout << ((output.at(bit)>0.0) ? 1 : 0);
84          cout << endl;
85      }
86
87      /* clean up after usage */
88      delete net;
89      delete teacher;
90
91      return EXIT_SUCCESS;
92  }