1 /* libtest_binclassify.cpp
2 * - classification of binary input values
3 * ---------------------------------------------------------------------------
4 * howto use:
5 * $> ./libtest_sin
6 * process and view the output
7 * $> perl process_dumps.pl
8 * now you can view the coverted output with your favorit viewer
9 * ___________________________________________________________________________*/
10
11 #include <iostream>
12 #include "snnl_network.h"
13 #include "snnl_teacher.h"
14
15 #define TRAIN_SIZE 10
16 #define LO -1.0
17 #define HI 1.0
18
19 Network * net;
20 Teacher * teacher;
21
22 vector <TrainSet*> trainset (TRAIN_SIZE);
23 vector <double> output;
24
25 static void genTrainData ();
26 static void trainNetwork ();
27 static void useNetwork ();
28
29 int
30 main (int argc, char *argv[])
31 {
32 /* the network description and the network */
33 vector <unsigned int> net_descr;
34 net_descr.push_back (4); // input layer (4 neurons)
35 net_descr.push_back (9); // output layer (9 neurons)
36
37 /* build the network and do initializing */
38 net = new Network (net_descr, "binary classification network");
39 net->setActivationFunction (Output, &ActivationFunctions::fact_binary);
40 net->dumpGraph ("dump/libtest_binclassify_before.dot");
41
42 /* create and initialize the teacher */
43 teacher = new Teacher (net);
44 teacher->setLearningParameter (0.3);
45
46 /* initial randomization of the network */
47 teacher->randomizeParameters ();
48
49 /* generate train dataset */
50 genTrainData ();
51 teacher->setTrainSet (trainset);
52 teacher->saveTrainSet ("trainsets/bin_classify.train");
53
54 /* train the network */
55 trainNetwork ();
56
57 /* use the network */
58 useNetwork ();
59 net->dumpGraph ("dump/libtest_binclassify_after.dot");
60
61 /* save the network to disc */
62 net->save ("dump/binnet.net");
63
64 /* clean up */
65 delete net;
66 delete teacher;
67 /*NOTE: the Teacher cleans up automatically the trainset */
68 }
69
70
71 /* Generate train data set
72 * */
73 static void
74 genTrainData ()
75 {
76 for(unsigned int n=0; n<TRAIN_SIZE; n++) {
77 trainset[n] = new TrainSet (4,9);
78
79 for (unsigned int bit=0; bit<4; bit++)
80 trainset[n]->input->at(bit)=LO;
81 for (unsigned int bit=0; bit<9; bit++)
82 trainset[n]->optout->at(bit)=LO;
83 for (unsigned int bit=0; bit<n; bit++)
84 trainset[n]->optout->at(bit)=HI;
85 }
86
87 trainset[1]->input->at(3)=HI;
88 trainset[2]->input->at(2)=HI;
89 trainset[3]->input->at(3)=HI;
90 trainset[3]->input->at(2)=HI;
91 trainset[4]->input->at(1)=HI;
92 trainset[5]->input->at(3)=HI;
93 trainset[5]->input->at(1)=HI;
94 trainset[6]->input->at(2)=HI;
95 trainset[6]->input->at(1)=HI;
96 trainset[7]->input->at(1)=HI;
97 trainset[7]->input->at(2)=HI;
98 trainset[7]->input->at(3)=HI;
99 trainset[8]->input->at(0)=HI;
100 trainset[9]->input->at(3)=HI;
101 trainset[9]->input->at(0)=HI;
102 }
103
104 /*
105 * Train the network using online backpropagation
106 * */
107 static void
108 trainNetwork ()
109 {
110 double eterm_sum;
111 unsigned int epochs;
112
113 /* train the network until it is perfect :) */
114 cout << "train network ..." << endl;
115 /* activate per epoch errorTerm dumping */
116 if (teacher->setErrorFile ("dump/libtest_binclassify_error.dat"))
117 cout << " errorfile activated (dump/libtest_binclassify_error.dat)\n";
118 else
119 cerr << "can't open errorfile, deactivated\n";
120
121 teacher->shuffleTrainSet ();
122 for (epochs=1; ;epochs++) {
123 /* NOTE: you can use online or batch backpropagation */
124 eterm_sum = teacher->onlineBackProp ();
125 //eterm_sum = teacher->batchBackProp ();
126 if (eterm_sum == 0.0)
127 break;
128 }
129 cout << "done, need " << epochs << " epoch's" << endl;
130 }
131
132 static void
133 useNetwork ()
134 {
135 cout << "network output:" << endl;
136 cout << "dec\tinput\topt_output\tnet_output" << endl;
137 for (unsigned int n=0; n < trainset.size (); n++) {
138 cout << n << "\t";
139 for (unsigned int bit=0; bit<4; bit++)
140 cout << ((trainset[n]->input->at(bit)>0.0) ? 1 : 0);
141 cout << "\t";
142 for (unsigned int bit=0; bit<9; bit++)
143 cout << ((trainset[n]->optout->at(bit)>0.0) ? 1 : 0);
144 cout << "\t";
145
146 net->setInput (*trainset[n]->input);
147 net->propagate ();
148 output=net->getOutput ();
149
150 for (unsigned int bit=0; bit<9; bit++)
151 cout << ((output[bit]>0.0) ? 1 : 0);
152 cout << "\n";
153 }
154 }
155