/* libtest_xor
 * - uses a trained xor network (loaded from networks/xor_net.net) and
 *   prints the network's answer for all four xor input patterns
 * How to use:
 * $> ./libtest_xor
 * $> perl process_dumps.pl <format>
 * $> <viewer> dump/libtest_xor*
 * __________________________________________________________________________*/
8
#include <iostream>
#include <vector>

#include "snnl_network.h"
#include "snnl_teacher.h"
12
/* Bipolar encoding of the xor truth values: LO stands for logical 0 and
 * HI for logical 1. Typed constants instead of object-like macros avoid
 * the precedence surprises of an unparenthesized negative macro value. */
static const double LO = -1.0;
static const double HI = 1.0;
15
16 Network *net;
17 vector <double>* train_input [4];
18 vector <double> output;
19
20 static void trainNetwork ();
21
22
23 int
24 main (int argc, char *argv[])
25 {
26 /* load xor-network from file */
27 net = new Network ();
28 net->load ("networks/xor_net.net");
29 net->name = "xor snnl network (xor_net.net)";
30
31 /* generate xor input set */
32 for (unsigned int n=0; n<4; n++)
33 train_input [n] = new vector <double> (2);
34
35 train_input[0]->at(0) = LO;
36 train_input[0]->at(1) = LO;
37
38 train_input[1]->at(0) = LO;
39 train_input[1]->at(1) = HI;
40
41 train_input[2]->at(0) = HI;
42 train_input[2]->at(1) = LO;
43
44 train_input[3]->at(0) = HI;
45 train_input[3]->at(1) = HI;
46
47 /* dump network graph (with data set 2)*/
48 net->setInput (*train_input[1]);
49 net->propagate ();
50 net->dumpGraph ("dump/libtest_xor.dot");
51
52 /* use the xor-network */
53 cout << "x\ty\tnet_output\n";
54 for (unsigned int m=0; m<4; m++) {
55 net->setInput (*train_input[m]);
56 net->propagate ();
57 output=net->getOutput ();
58 cout << (train_input[m]->at(0)>0.0 ? 1:0) << "\t"
59 << (train_input[m]->at(1)>0.0 ? 1:0) << "\t"
60 << (output[0]>0.0 ? 1:0) << endl;
61 }
62
63 /* clean up */
64 for(unsigned int n=0; n<4; n++)
65 delete train_input[n];
66 delete net;
67 }
68