/* libtest_sin.cpp
 * - Demonstrates libsnnl usage by teaching a network the sine function
 * --------------------------------------------------------------------------
 * How to use:
 * to generate a network with default parameters
 * $> ./libtest_sin
 * to generate a network with non-default parameters
 * $> ./libtest_sin <MomentumTerm> <OptimalTolerance> <WeightDecay>
 * to load a network from file
 * $> ./libtest_sin f
 * process and view the output
 * $> perl process_dumps.pl
 * $> gnuplot sin_plot.dem
 * or all in one
 * $> ./libtest_sin 0.1 0.0 && ./process_dumps.pl && gnuplot sin_plot.dem
 * Output files are written to ./dump/ and ./trainsets/; both directories
 * must already exist (they are not created by this program).
 * __________________________________________________________________________*/
17
#include <iostream>
#include <fstream>
#include <string>
#include <vector>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "snnl_network.h"
#include "snnl_teacher.h"
24
25
26 static void testNetwork (Network *net, Teacher *teacher);
27
28
29 int
30 main (int argc, char *argv[])
31 {
32 Network * net;
33 Teacher * teacher;
34
35 vector <unsigned int> gen;
36 gen.push_back (1);
37 gen.push_back (8);
38 gen.push_back (6);
39 gen.push_back (1);
40
41 if (argc==2 && argv[1][0]=='f') {
42 /* load a network from file */
43 net = new Network ();
44 teacher = new Teacher (net);
45 net->load ("dump/sin_network.net");
46 } else {
47 /* or generate a new network */
48 net = new Network (gen, "snnl - 'sinus emulator'");
49 teacher = new Teacher (net);
50 teacher->randomizeParameters ();
51 teacher->setLearningParameter (0.2);
52
53 if (argc>=2) {
54 teacher->setMomentumTermParameter (atof(argv[1]));
55 } else {
56 teacher->setMomentumTermParameter (0.0);
57 }
58
59 if (argc>=3) {
60 teacher->setOptimalTolerance (atof(argv[2]));
61 } else {
62 teacher->setOptimalTolerance (0.0);
63 }
64
65 if (argc>=4) {
66 teacher->setWeightDecayParameter (atof(argv[3]));
67 } else {
68 teacher->setWeightDecayParameter (0.00004);
69 }
70 }
71
72 cout << teacher->getLearningParameter () << " "
73 << teacher->getMomentumTermParameter () << " "
74 << teacher->getOptimalTolerance () << " "
75 << teacher->getWeightDecayParameter () << endl;
76
77 testNetwork (net, teacher);
78
79 delete net;
80 }
81
82
83 static void
84 testNetwork (Network *net, Teacher *teacher)
85 {
86 #define TRAIN_LOW (-3.14)
87 #define TRAIN_HIGH (3.14)
88 #define TRAIN_SIZE 100
89 #define TRAIN_CYLCES 10
90 #define TRAIN_EPOCH 10
91 #define TRAIN_LENGTH (TRAIN_EPOCHS * TRAIN_SIZE)
92
93 vector <TrainSet*> trainset (TRAIN_SIZE);
94 vector <double> output;
95
96 /* generate train data */
97 unsigned int n;
98 fstream trainout ("dump/sin_train_data.dat", ios::out | ios::trunc);
99 for (n = 0; n < trainset.size (); ++n) {
100 trainset[n] = new TrainSet (1,1);
101 trainset[n]->input->at(0) = n * ((TRAIN_HIGH - TRAIN_LOW)/TRAIN_SIZE) + TRAIN_LOW;
102 trainset[n]->optout->at(0) = sin (trainset[n]->input->at(0));
103 trainout << trainset[n]->input->at(0) << "\t" << trainset[n]->optout->at(0) << endl;
104 }
105 trainout.close ();
106
107 teacher->setTrainSet (trainset);
108 teacher->saveTrainSet ("trainsets/sin.train");
109 //trainset = teacher->getTrainSet ();
110
111 if (teacher->setErrorFile ("dump/libtest_sin_error.dat"))
112 cout << " errorfile activated (dump/libtest_sin_error.dat)\n";
113 else
114 cerr << "can't open errorfile, deactivated\n";
115
116 /* train the sin()-function and show output in each cycle*/
117 unsigned int cycle, ne, ni;
118 for (cycle=1; cycle<=TRAIN_CYLCES; cycle++) {
119
120 teacher->shuffleTrainSet ();
121
122 for (ne=0; ne<TRAIN_EPOCH; ne++) {
123 teacher->onlineBackProp ();
124 }
125
126 /* let the network show what is has learned */
127 string filename ("dump/sin_network_output_");
128 char buf[10]; sprintf(buf,"%02u",cycle);
129 filename += buf; filename += ".dat";
130 fstream netout (filename.c_str(), ios::out | ios::trunc);
131
132 double error_sum = 0.0;
133 for (ni = 0; ni < trainset.size (); ++ni) {
134 net->setInput (*trainset[ni]->input);
135 net->setOptimalOutput (*trainset[ni]->optout);
136 net->propagate ();
137 output = net->getOutput ();
138 error_sum += net->errorTerm ();
139
140 netout << trainset[ni]->input->at(0) << "\t" << output[0] << endl;
141 }
142 netout.close ();
143 cout << "overall errorterm of testset: " << error_sum << endl;
144 }
145
146 /* dump the network graph */
147 net->dumpGraph ("dump/sin_network_graph.dot");
148
149 /* save network to file */
150 if (net->save ("dump/sin_network.net"))
151 cout << "network saved to dump/sin_network.net" << endl;
152
153 /* clean up train data */
154 for (n = 0; n < trainset.size (); ++n) {
155 delete trainset[n];
156 }
157 }
158
159