/* libtest_spirals
 * - The task is to learn to discriminate between two sets of training
 *   points which lie on two distinct spirals in the x-y plane.
 * How to use:
 *   $> ./libtest_spirals      (build, train and save a new network)
 *   $> ./libtest_spirals f    (load the saved network from
 *                              dump/spiral_network.net instead of training)
 * While the network is learning, you can watch the error curve with gnuplot:
 *   $> gnuplot
 *   plot "dump/libtest_spiral_error.dat" with lines
 *   (hit 'a' to refresh)
 * __________________________________________________________________________*/
11
#include <cstdlib>
#include <iostream>
#include <iomanip>
#include "snnl_network.h"
#include "snnl_teacher.h"
16
17 void useNetwork (void);
18
19 Network * net;
20 Teacher * teacher;
21
22 int
23 main (int argc, char *argv[])
24 {
25 /* the network description */
26 vector <unsigned int> net_descr;
27 net_descr.push_back (2); // Input
28 net_descr.push_back (20); // Hidden
29 net_descr.push_back (10); // Hidden
30 net_descr.push_back (1); // Output
31
32 /* load network from file or
33 * build new network */
34 if (argc==2 && argv[1][0]=='f') {
35 net = new Network ();
36 teacher = new Teacher (net);
37 if (! net->load ("dump/spiral_network.net")) {
38 cerr << "Can't load network\n";
39 return EXIT_FAILURE;
40 }
41 } else {
42 net = new Network (net_descr);
43
44 /* build and initialize the teacher */
45 teacher = new Teacher (net);
46 teacher->setLearningParameter (0.002);
47 teacher->setMomentumTermParameter (0.7);
48 teacher->setOptimalTolerance (0.16);
49
50 /* randomize weights and theta-values */
51 teacher->randomizeParameters (-0.5, 0.5);
52
53 /* load trainset from file */
54 if (! teacher->loadTrainSet ("trainsets/two-spiral.train")) {
55 cerr << "Can't load trainset\n";
56 return EXIT_FAILURE;
57 }
58
59 /* set errorfile and activate per epoch error output */
60 if (teacher->setErrorFile ("dump/libtest_spiral_error.dat"))
61 cout << " errorfile activated (dump/libtest_spiral_error.dat)\n";
62 else
63 cerr << "can't open errorfile, deactivated\n";
64
65 /* train the trainset data */
66 cout << "train network ...";
67 unsigned int epoch;
68 for (epoch=1; epoch<20000; epoch++) {
69 teacher->shuffleTrainSet ();
70 if (teacher->onlineBackProp () <= 0.005)
71 break;
72 }
73 cout << epoch << " epoch's needed" << endl;
74
75 /* save network */
76 if (net->save ("dump/spiral_network.net"))
77 cout << "network saved to dump/spiral_network.net" << endl;
78 delete teacher;
79 }
80
81 teacher = new Teacher (net);
82
83 /* use the network */
84 if (! teacher->loadTrainSet ("trainsets/two-spiral.test")) {
85 cerr << "Can't load testset\n";
86 return EXIT_FAILURE;
87 }
88 cout << "--- verifing with test-data-set ---" << endl;
89 useNetwork ();
90
91 /* clean up */
92 delete net;
93 delete teacher;
94
95 }
96
97 void useNetwork (void)
98 {
99 vector <TrainSet*> trainset;
100 vector <double> output;
101 trainset = teacher->getTrainSet ();
102 cout << setprecision (2);
103 cout << fixed;
104 cout << "x\ty\toptout\tnetwork\t \tnetwork - analog value\n";
105 unsigned int good = 0;
106 for (unsigned int n=0; n < trainset.size (); n++) {
107 net->setInput (*trainset[n]->input);
108 net->propagate ();
109 output=net->getOutput ();
110
111 unsigned int optout = ((trainset[n]->optout->at(0)==1.0)?1:0);
112 unsigned int out = ((output.at(0)>0.40)?1:0);
113 cout << trainset[n]->input->at(0) << "\t"
114 << trainset[n]->input->at(1) << "\t"
115 << optout << "\t"
116 << out << "\t|\t"
117 << output.at(0) << endl;
118 if (out==optout)
119 good++;
120 }
121 double err = (double)good / trainset.size () * 100.0;
122 cout << "accuracy: " << err << "%" << endl;
123 }
124