1  /* libtest_xor
 2   *     - use a trained xor-network (load from file)
 3   * how to use:
 4   *     $> ./libtest_xor
 5   *     $> perl process_dumps.pl <format>
 6   *     $> <viewer> dump/libtest_xor*
 7   * __________________________________________________________________________*/
 8
 9  #include <iostream>
10  #include "snnl_network.h"
11  #include "snnl_teacher.h"
12
/* Input encoding of the xor patterns: LO stands for boolean 0 and HI for
 * boolean 1. Typed, scoped constants instead of object-like macros; a
 * namespace-scope const has internal linkage, so this stays file-local. */
const double LO = -1.0;
const double HI = 1.0;
15
16  Network *net;
17  vector <double>* train_input [4];
18  vector <double>  output;
19
20  static void trainNetwork ();
21
22
23  int
24  main (int argc, char *argv[])
25  {
26      /* load xor-network from file */
27      net = new Network ();
28      net->load ("networks/xor_net.net");
29      net->name = "xor snnl network (xor_net.net)";
30
31      /* generate xor input set */
32      for (unsigned int n=0; n<4; n++)
33          train_input [n] = new vector <double> (2);
34
35      train_input[0]->at(0) = LO;
36      train_input[0]->at(1) = LO;
37
38      train_input[1]->at(0) = LO;
39      train_input[1]->at(1) = HI;
40
41      train_input[2]->at(0) = HI;
42      train_input[2]->at(1) = LO;
43
44      train_input[3]->at(0) = HI;
45      train_input[3]->at(1) = HI;
46
47      /* dump network graph (with data set 2)*/
48      net->setInput (*train_input[1]);
49      net->propagate ();
50      net->dumpGraph ("dump/libtest_xor.dot");
51
52      /* use the xor-network */
53      cout << "x\ty\tnet_output\n";
54      for (unsigned int m=0; m<4; m++) {
55          net->setInput (*train_input[m]);
56          net->propagate ();
57          output=net->getOutput ();
58          cout << (train_input[m]->at(0)>0.0 ? 1:0) << "\t"
59               << (train_input[m]->at(1)>0.0 ? 1:0) << "\t"
60               << (output[0]>0.0 ? 1:0)             << endl;
61      }
62
63      /* clean up */
64      for(unsigned int n=0; n<4; n++)
65          delete train_input[n];
66      delete net;
67  }
68