00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include <LanguageTable.h>
#include <TreeBuilder.h>
#include <FunDef.h>
#include <Dataset.h>

#include <eoSymInit.h>
#include <eoSym.h>
#include <eoPop.h>
#include <eoSymMutate.h>

#include <eoSymCrossover.h>
#include <eoSymEval.h>
#include <eoOpContainer.h>
#include <eoDetTournamentSelect.h>
#include <eoMergeReduce.h>
#include <eoGenContinue.h>
#include <eoEasyEA.h>
#include <eoGeneralBreeder.h>

#include <utils/eoParser.h>
#include <utils/eoCheckPoint.h>
#include <utils/eoStat.h>
#include <utils/eoStdoutMonitor.h>
#include <utils/eoRNG.h>
00043
00044 using namespace std;
00045
00046 typedef EoSym<eoMinimizingFitness> EoType;
00047
00048 static int functions_added = 0;
00049
00050 void add_function(LanguageTable& table, eoParser& parser, string name, unsigned arity, token_t token, const FunDef& fun);
00051 void setup_language(LanguageTable& table, eoParser& parser);
00052
/// Chooses one of two objects by reference.
///
/// @param check  selects `a` when true, `b` when false
/// @param a      object returned when `check` holds
/// @param b      object returned otherwise
/// @return a reference to the chosen argument (no copy is made)
template <class T>
T& select(bool check, T& a, T& b) {
    return check ? a : b;
}
00055
00056 class eoBestIndividualStat : public eoSortedStat<EoType, string> {
00057 public:
00058 eoBestIndividualStat() : eoSortedStat<EoType, string>("", "best individual") {}
00059
00060 void operator()(const vector<const EoType*>& _pop) {
00061 ostringstream os;
00062 os << (Sym) *_pop[0];
00063 value() = os.str();
00064 }
00065
00066 };
00067
00068 class AverageSizeStat : public eoStat<EoType, double> {
00069 public:
00070 AverageSizeStat() : eoStat<EoType, double>(0.0, "Average size population") {}
00071
00072 void operator()(const eoPop<EoType>& _pop) {
00073 double total = 0.0;
00074 for (unsigned i = 0; i < _pop.size(); ++i) {
00075 total += _pop[i].size();
00076 }
00077 value() = total/_pop.size();
00078 }
00079 };
00080
00081 class SumSizeStat : public eoStat<EoType, unsigned> {
00082 public:
00083 SumSizeStat() : eoStat<EoType, unsigned>(0u, "Number of subtrees") {}
00084
00085 void operator()(const eoPop<EoType>& _pop) {
00086 unsigned total = 0;
00087 for (unsigned i = 0; i < _pop.size(); ++i) {
00088 total += _pop[i].size();
00089 }
00090 value() = total;
00091 }
00092 };
00093
00094 class DagSizeStat : public eoStat<EoType, unsigned> {
00095 public:
00096 DagSizeStat() : eoStat<EoType, unsigned>(0u, "Number of distinct subtrees") {}
00097
00098 void operator()(const eoPop<EoType>& _pop) {
00099 value() = Sym::get_dag().size();
00100 }
00101 };
00102
00103 int main(int argc, char* argv[]) {
00104
00105 eoParser parser(argc, argv);
00106
00107
00108 LanguageTable table;
00109 setup_language(table, parser);
00110
00111
00112
00113 eoValueParam<string> datafile = parser.createParam(string(""), "datafile", "Training data", 'd', string("Regression"), true);
00114 double train_percentage = parser.createParam(1.0, "trainperc", "Percentage of data used for training", 0, string("Regression")).value();
00115
00116
00117
00118 unsigned pop_size = parser.createParam(1500u, "population-size", "Population Size", 'p', string("Population")).value();
00119
00120 uint32_t seed = parser.createParam( uint32_t(time(0)), "random-seed", "Seed for rng", 'D').value();
00121
00122 cout << "Seed " << seed << endl;
00123 rng.reseed(seed);
00124
00125 double var_prob = parser.createParam(
00126 0.9,
00127 "var-prob",
00128 "Probability of selecting a var vs. const when creating a terminal",
00129 0,
00130 "Population").value();
00131
00132
00133 double grow_prob = parser.createParam(
00134 0.5,
00135 "grow-prob",
00136 "Probability of selecting 'grow' method instead of 'full' in initialization and mutation",
00137 0,
00138 "Population").value();
00139
00140 unsigned max_depth = parser.createParam(
00141 8u,
00142 "max-depth",
00143 "Maximum depth used in initialization and mutation",
00144 0,
00145 "Population").value();
00146
00147
00148 bool use_uniform = parser.createParam(
00149 false,
00150 "use-uniform",
00151 "Use uniform node selection instead of bias towards internal nodes (functions)",
00152 0,
00153 "Population").value();
00154
00155 double constant_mut_prob = parser.createParam(
00156 0.1,
00157 "constant-mut-rate",
00158 "Probability of performing constant mutation",
00159 0,
00160 "Population").value();
00161
00162
00163 double subtree_mut_prob = parser.createParam(
00164 0.2,
00165 "subtree-mut-rate",
00166 "Probability of performing subtree mutation",
00167 0,
00168 "Population").value();
00169
00170 double node_mut_prob = parser.createParam(
00171 0.2,
00172 "node-mut-rate",
00173 "Probability of performing node mutation",
00174 0,
00175 "Population").value();
00176
00177
00178
00179
00180
00181
00182
00183
00184 double subtree_xover_prob = parser.createParam(
00185 0.4,
00186 "xover-rate",
00187 "Probability of performing subtree crossover",
00188 0,
00189 "Population").value();
00190
00191 double homologous_prob = parser.createParam(
00192 0.4,
00193 "homologous-rate",
00194 "Probability of performing homologous crossover",
00195 0,
00196 "Population").value();
00197
00198 unsigned max_gens = parser.createParam(
00199 50,
00200 "max-gens",
00201 "Maximum number of generations to run",
00202 'g',
00203 "Population").value();
00204
00205 unsigned tournamentsize = parser.createParam(
00206 5,
00207 "tournament-size",
00208 "Tournament size used for selection",
00209 't',
00210 "Population").value();
00211
00212 unsigned maximumSize = parser.createParam(
00213 -1u,
00214 "maximum-size",
00215 "Maximum size after crossover",
00216 's',
00217 "Population").value();
00218
00219 unsigned meas_param = parser.createParam(
00220 2u,
00221 "measure",
00222 "Error measure:\n\
00223 0 -> absolute error\n\
00224 1 -> mean squared error\n\
00225 2 -> mean squared error scaled (equivalent with correlation)\n\
00226 ",
00227 'm',
00228 "Regression").value();
00229
00230
00231 ErrorMeasure::measure meas = ErrorMeasure::mean_squared_scaled;
00232 if (meas_param == 0) meas = ErrorMeasure::absolute;
00233 if (meas_param == 1) meas = ErrorMeasure::mean_squared;
00234
00235
00236
00237 if (parser.userNeedsHelp())
00238 {
00239 parser.printHelp(std::cout);
00240 return 1;
00241 }
00242
00243 if (functions_added == 0) {
00244 cout << "ERROR: no functions defined" << endl;
00245 exit(1);
00246 }
00247
00248
00249 Dataset dataset;
00250 dataset.load_data(datafile.value());
00251
00252 cout << "Data " << datafile.value() << " loaded " << endl;
00253
00254
00255 unsigned nvars = dataset.n_fields();
00256 for (unsigned i = 0; i < nvars; ++i) {
00257 table.add_function( SymVar(i).token(), 0);
00258 }
00259
00260 TreeBuilder builder(table, var_prob);
00261 eoSymInit<EoType> init(builder, grow_prob, max_depth);
00262
00263 eoPop<EoType> pop(pop_size, init);
00264
00265 BiasedNodeSelector biased_sel;
00266 RandomNodeSelector random_sel;
00267
00268 NodeSelector& node_selector = select<NodeSelector>(use_uniform, random_sel, biased_sel);
00269
00270
00271 eoSequentialOp<EoType> genetic_operator;
00272
00273 eoSymSubtreeMutate<EoType> submutate(builder, node_selector);
00274 genetic_operator.add( submutate, subtree_mut_prob);
00275
00276
00277 double std = 0.01;
00278 eoSymConstantMutate<EoType> constmutate(std);
00279 genetic_operator.add(constmutate, constant_mut_prob);
00280
00281 eoSymNodeMutate<EoType> nodemutate(table);
00282 genetic_operator.add(nodemutate, node_mut_prob);
00283
00284
00285
00286
00287
00288 eoBinSubtreeCrossover<EoType> bin(node_selector);
00289 genetic_operator.add(bin, subtree_xover_prob);
00290
00291 eoBinHomologousCrossover<EoType> hom;
00292 genetic_operator.add(hom, homologous_prob);
00293
00294
00295 IntervalBoundsCheck check(dataset.input_minima(), dataset.input_maxima());
00296 ErrorMeasure measure(dataset, train_percentage, meas);
00297
00298 eoSymPopEval<EoType> evaluator(check, measure, maximumSize);
00299
00300 eoDetTournamentSelect<EoType> selectOne(tournamentsize);
00301 eoGeneralBreeder<EoType> breeder(selectOne, genetic_operator);
00302 eoPlusReplacement<EoType> replace;
00303
00304
00305 eoGenContinue<EoType> term(max_gens);
00306 eoCheckPoint<EoType> checkpoint(term);
00307
00308 eoBestFitnessStat<EoType> beststat;
00309 checkpoint.add(beststat);
00310
00311 eoBestIndividualStat printer;
00312 AverageSizeStat avgSize;
00313 DagSizeStat dagSize;
00314 SumSizeStat sumSize;
00315
00316 checkpoint.add(printer);
00317 checkpoint.add(avgSize);
00318 checkpoint.add(dagSize);
00319 checkpoint.add(sumSize);
00320
00321 eoStdoutMonitor genmon;
00322 genmon.add(beststat);
00323 genmon.add(printer);
00324 genmon.add(avgSize);
00325 genmon.add(dagSize);
00326 genmon.add(sumSize);
00327 genmon.add(term);
00328
00329 checkpoint.add(genmon);
00330
00331 eoPop<EoType> dummy;
00332 evaluator(pop, dummy);
00333
00334 eoEasyEA<EoType> ea(checkpoint, evaluator, breeder, replace);
00335
00336 ea(pop);
00337
00338 }
00339
00340 void add_function(LanguageTable& table, eoParser& parser, string name, unsigned arity, token_t token, const FunDef& fun, bool all) {
00341 ostringstream desc;
00342 desc << "Enable function " << name << " arity = " << arity;
00343 bool enabled = parser.createParam(false, name, desc.str(), 0, "Language").value();
00344
00345 if (enabled || all) {
00346 cout << "Func " << name << " enabled" << endl;
00347 table.add_function(token, arity);
00348 if (arity > 0) functions_added++;
00349 }
00350 }
00351
00352 void setup_language(LanguageTable& table, eoParser& parser) {
00353
00354 bool all = parser.createParam(false,"all", "Enable all functions").value();
00355 bool ratio = parser.createParam(false,"ratio","Enable rational functions (inv,min,sum,prod)").value();
00356 bool poly = parser.createParam(false,"poly","Enable polynomial functions (min,sum,prod)").value();
00357
00358
00359 vector<const FunDef*> lang = get_defined_functions();
00360
00361 for (token_t i = 0; i < lang.size(); ++i) {
00362
00363 if (lang[i] == 0) continue;
00364
00365 bool is_poly = false;
00366 if (poly && (i == prod_token || i == sum_token || i == min_token) ) {
00367 is_poly = true;
00368 }
00369
00370 bool is_ratio = false;
00371 if (ratio && (is_poly || i == inv_token)) {
00372 is_ratio = true;
00373 }
00374
00375 const FunDef& fun = *lang[i];
00376
00377 if (fun.has_varargs() ) {
00378
00379 for (unsigned j = fun.min_arity(); j < fun.min_arity() + 8; ++j) {
00380 if (j==1) continue;
00381 ostringstream nm;
00382 nm << fun.name() << j;
00383 bool addanyway = (all || is_ratio || is_poly) && j == 2;
00384 add_function(table, parser, nm.str(), j, i, fun, addanyway);
00385 }
00386 }
00387 else {
00388 add_function(table, parser, fun.name(), fun.min_arity(), i, fun, all || is_ratio || is_poly);
00389 }
00390 }
00391 }
00392
00393