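/* libtest_spirals
 *     - The task is to learn to discriminate between two sets of training
 *       points which lie on two distinct spirals in the x-y plane.
 * How to use:
 *     $> ./libtest_spirals      train a new network on
 *                               trainsets/two-spiral.train, save it to
 *                               dump/spiral_network.net, then test it
 *     $> ./libtest_spirals f    load dump/spiral_network.net instead of
 *                               training, then test it
 *     In both cases the network is verified against trainsets/two-spiral.test.
 * While the network is learning, you can watch the error curve with gnuplot:
 *     $> gnuplot
 *        plot "dump/libtest_spiral_error.dat" with lines
 *        (hit 'a' to refresh)
 * __________________________________________________________________________*/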
 11
 12  #include <iostream>
 13  #include <iomanip>
 14  #include "snnl_network.h"
 15  #include "snnl_teacher.h"
 16
 17  void useNetwork (void);
 18
 19  Network * net;
 20  Teacher * teacher;
 21
 22  int
 23  main (int argc, char *argv[])
 24  {
 25      /* the network description */
 26      vector <unsigned int> net_descr;
 27      net_descr.push_back (2);    // Input
 28      net_descr.push_back (20);   // Hidden
 29      net_descr.push_back (10);   // Hidden
 30      net_descr.push_back (1);    // Output
 31
 32      /* load network from file or
 33       * build new network */
 34      if (argc==2 && argv[1][0]=='f') {
 35          net = new Network ();
 36          teacher = new Teacher (net);
 37          if (! net->load ("dump/spiral_network.net")) {
 38              cerr << "Can't load network\n";
 39              return EXIT_FAILURE;
 40          }
 41      } else {
 42          net = new Network (net_descr);
 43
 44          /* build and initialize the teacher */
 45          teacher = new Teacher (net);
 46          teacher->setLearningParameter (0.002);
 47          teacher->setMomentumTermParameter (0.7);
 48          teacher->setOptimalTolerance (0.16);
 49
 50          /* randomize weights and theta-values */
 51          teacher->randomizeParameters (-0.5, 0.5);
 52
 53          /* load trainset from file */
 54          if (! teacher->loadTrainSet ("trainsets/two-spiral.train")) {
 55              cerr << "Can't load trainset\n";
 56              return EXIT_FAILURE;
 57          }
 58
 59          /* set errorfile and activate per epoch error output */
 60          if (teacher->setErrorFile ("dump/libtest_spiral_error.dat"))
 61              cout << "   errorfile activated (dump/libtest_spiral_error.dat)\n";
 62          else
 63              cerr << "can't open errorfile, deactivated\n";
 64
 65          /* train the trainset data */
 66          cout << "train network ...";
 67          unsigned int epoch;
 68          for (epoch=1; epoch<20000; epoch++) {
 69              teacher->shuffleTrainSet ();
 70              if (teacher->onlineBackProp () <= 0.005)
 71                  break;
 72          }
 73          cout << epoch << " epoch's needed" << endl;
 74
 75          /* save network */
 76          if (net->save ("dump/spiral_network.net"))
 77              cout << "network saved to dump/spiral_network.net" << endl;
 78          delete teacher;
 79      }
 80
 81      teacher = new Teacher (net);
 82
 83      /* use the network */
 84      if (! teacher->loadTrainSet ("trainsets/two-spiral.test")) {
 85          cerr << "Can't load testset\n";
 86          return EXIT_FAILURE;
 87      }
 88      cout << "--- verifing with test-data-set ---" << endl;
 89      useNetwork ();
 90
 91      /* clean up */
 92      delete net;
 93      delete teacher;
 94
 95  }
 96
 97  void useNetwork (void)
 98  {
 99      vector <TrainSet*> trainset;
100      vector <double> output;
101      trainset = teacher->getTrainSet ();
102      cout << setprecision (2);
103      cout << fixed;
104      cout << "x\ty\toptout\tnetwork\t \tnetwork - analog value\n";
105      unsigned int good = 0;
106      for (unsigned int n=0; n < trainset.size (); n++) {
107          net->setInput (*trainset[n]->input);
108          net->propagate ();
109          output=net->getOutput ();
110
111          unsigned int optout = ((trainset[n]->optout->at(0)==1.0)?1:0);
112          unsigned int out    = ((output.at(0)>0.40)?1:0);
113          cout << trainset[n]->input->at(0)  << "\t"
114               << trainset[n]->input->at(1)  << "\t"
115               << optout << "\t"
116               << out    << "\t|\t"
117               << output.at(0) << endl;
118          if (out==optout)
119              good++;
120      }
121      double err = (double)good / trainset.size () * 100.0;
122      cout << "accuracy: " << err << "%" << endl;
123  }
124