/* libtest_binclassify.cpp
 *     - classification of binary input values
 * ---------------------------------------------------------------------------
 * how to use:
 *    $> ./libtest_binclassify
 * process and view the output:
 *    $> perl process_dumps.pl
 *    now you can view the converted output with your favorite viewer
 * ___________________________________________________________________________*/
 10
 11  #include <iostream>
 12  #include "snnl_network.h"
 13  #include "snnl_teacher.h"
 14
 15  #define TRAIN_SIZE 10
 16  #define LO -1.0
 17  #define HI  1.0
 18
 19  Network * net;
 20  Teacher * teacher;
 21
 22  vector <TrainSet*> trainset (TRAIN_SIZE);
 23  vector <double> output;
 24
 25  static void genTrainData ();
 26  static void trainNetwork ();
 27  static void useNetwork   ();
 28
 29  int
 30  main (int argc, char *argv[])
 31  {
 32      /* the network description and the network */
 33      vector <unsigned int> net_descr;
 34      net_descr.push_back (4);    // input  layer (4 neurons)
 35      net_descr.push_back (9);    // output layer (9 neurons)
 36
 37      /* build the network and do initializing */
 38      net = new Network (net_descr, "binary classification network");
 39      net->setActivationFunction (Output, &ActivationFunctions::fact_binary);
 40      net->dumpGraph ("dump/libtest_binclassify_before.dot");
 41
 42      /* create and initialize the teacher */
 43      teacher = new Teacher (net);
 44      teacher->setLearningParameter (0.3);
 45
 46      /* initial randomization of the network */
 47      teacher->randomizeParameters ();
 48
 49      /* generate train dataset */
 50      genTrainData ();
 51      teacher->setTrainSet (trainset);
 52      teacher->saveTrainSet ("trainsets/bin_classify.train");
 53
 54      /* train the network */
 55      trainNetwork ();
 56
 57      /* use the network */
 58      useNetwork   ();
 59      net->dumpGraph ("dump/libtest_binclassify_after.dot");
 60
 61      /* save the network to disc */
 62      net->save ("dump/binnet.net");
 63
 64      /* clean up */
 65      delete net;
 66      delete teacher;
 67      /*NOTE: the Teacher cleans up automatically the trainset */
 68  }
 69
 70
 71  /* Generate train data set
 72   * */
 73  static void
 74  genTrainData ()
 75  {
 76      for(unsigned int n=0; n<TRAIN_SIZE; n++) {
 77          trainset[n] = new TrainSet (4,9);
 78
 79          for (unsigned int bit=0; bit<4; bit++)
 80              trainset[n]->input->at(bit)=LO;
 81          for (unsigned int bit=0; bit<9; bit++)
 82              trainset[n]->optout->at(bit)=LO;
 83          for (unsigned int bit=0; bit<n; bit++)
 84              trainset[n]->optout->at(bit)=HI;
 85      }
 86
 87      trainset[1]->input->at(3)=HI;
 88      trainset[2]->input->at(2)=HI;
 89      trainset[3]->input->at(3)=HI;
 90      trainset[3]->input->at(2)=HI;
 91      trainset[4]->input->at(1)=HI;
 92      trainset[5]->input->at(3)=HI;
 93      trainset[5]->input->at(1)=HI;
 94      trainset[6]->input->at(2)=HI;
 95      trainset[6]->input->at(1)=HI;
 96      trainset[7]->input->at(1)=HI;
 97      trainset[7]->input->at(2)=HI;
 98      trainset[7]->input->at(3)=HI;
 99      trainset[8]->input->at(0)=HI;
100      trainset[9]->input->at(3)=HI;
101      trainset[9]->input->at(0)=HI;
102  }
103
104  /*
105   * Train the network using online backpropagation
106   * */
107  static void
108  trainNetwork ()
109  {
110      double eterm_sum;
111      unsigned int   epochs;
112
113      /* train the network until it is perfect :) */
114      cout << "train network ..." << endl;
115      /* activate per epoch errorTerm dumping */
116      if (teacher->setErrorFile ("dump/libtest_binclassify_error.dat"))
117          cout << "   errorfile activated (dump/libtest_binclassify_error.dat)\n";
118      else
119          cerr << "can't open errorfile, deactivated\n";
120
121      teacher->shuffleTrainSet ();
122      for (epochs=1; ;epochs++) {
123          /* NOTE: you can use online or batch backpropagation */
124          eterm_sum = teacher->onlineBackProp ();
125          //eterm_sum = teacher->batchBackProp ();
126          if (eterm_sum == 0.0)
127              break;
128      }
129      cout << "done, need " << epochs << " epoch's" << endl;
130  }
131
132  static void
133  useNetwork ()
134  {
135      cout << "network output:" << endl;
136      cout << "dec\tinput\topt_output\tnet_output" << endl;
137      for (unsigned int n=0; n < trainset.size (); n++) {
138          cout << n << "\t";
139          for (unsigned int bit=0; bit<4; bit++)
140              cout << ((trainset[n]->input->at(bit)>0.0) ? 1 : 0);
141          cout << "\t";
142          for (unsigned int bit=0; bit<9; bit++)
143              cout << ((trainset[n]->optout->at(bit)>0.0) ? 1 : 0);
144          cout << "\t";
145
146          net->setInput (*trainset[n]->input);
147          net->propagate ();
148          output=net->getOutput ();
149
150          for (unsigned int bit=0; bit<9; bit++)
151              cout << ((output[bit]>0.0) ? 1 : 0);
152          cout << "\n";
153      }
154  }
155