/* libtest_encoder-decoder.cpp
 * - This network learns to encode 8 bits into 3 bits and to decode
 *   those 3 bits back into 8 bits.
 * ---------------------------------------------------------------------------
 * How to use:
 * $> ./libtest_encoder_decoder
 * To see the per-epoch error curve:
 * $> gnuplot
 * gnuplot> plot "dump/libtest_encoder_error.dat" with lines
 * ___________________________________________________________________________*/

#include <cstdlib>
#include <iostream>
#include <vector>

#include "snnl_network.h"
#include "snnl_teacher.h"

17 int
18 main (int argc, char *argv[])
19 {
20 vector <TrainSet*> trainset;
21 vector <double> output;
22
23 /* the network description and the network */
24 vector <unsigned int> net_descr;
25 Network * net;
26 Teacher * teacher;
27
28 net_descr.push_back (8); // Input - 8 Bits
29 net_descr.push_back (3); // Hidden - 3 Bits
30 net_descr.push_back (8); // Output - 8 Bits
31
32 /* build the network and initialize it */
33 net = new Network (net_descr);
34 net->setActivationFunction (Output, &ActivationFunctions::fact_binary);
35
36 /* build and initialize the teacher */
37 teacher = new Teacher (net);
38 teacher->setLearningParameter (0.1);
39 //teacher->setMomentumTermParameter (0.5);
40 teacher->setWeightDecayParameter (0.001);
41 //teacher->setOptimalTolerance (0.05);
42
43 /* randomize weights and theta-values */
44 teacher->randomizeParameters ();
45
46 /* load trainset from file */
47 if (! teacher->loadTrainSet ("trainsets/encoder_8to8.train")) {
48 cerr << "Can't load trainset\n";
49 return EXIT_FAILURE;
50 }
51
52 /* set errorfile and activate per epoch error output */
53 if (teacher->setErrorFile ("dump/libtest_encoder_error.dat"))
54 cout << " errorfile activated (dump/libtest_encoder_error.dat)\n";
55 else
56 cerr << "can't open errorfile, deactivated\n";
57
58 /* train the trainset data */
59 cout << "train network ..." << endl;
60 unsigned int epoch;
61 for (epoch=1; epoch<2000; epoch++) {
62 teacher->shuffleTrainSet ();
63 if (teacher->onlineBackProp () == 0.0)
64 //if (teacher->batchBackProp () == 0.0)
65 break;
66 }
67 cout << "need " << epoch << " epoch's" << endl;
68
69 /* use the trained network */
70 trainset = teacher->getTrainSet ();
71 cout << "network output:" << endl;
72 cout << "--------+-------+--------" << endl;
73 for (unsigned int n=0; n < trainset.size (); n++) {
74 for (unsigned int bit=0; bit<8; bit++)
75 cout << ((trainset[n]->input->at(bit)>0.0) ? 1 : 0);
76 cout << "|->-3->-|";
77
78 net->setInput (*trainset[n]->input);
79 net->propagate ();
80 output=net->getOutput ();
81
82 for (unsigned int bit=0; bit<8; bit++)
83 cout << ((output.at(bit)>0.0) ? 1 : 0);
84 cout << endl;
85 }
86
87 /* clean up after usage */
88 delete net;
89 delete teacher;
90
91 return EXIT_SUCCESS;
92 }