symreg.cpp

00001 /*          
00002  *             Copyright (C) 2005 Maarten Keijzer
00003  *
00004  *          This program is free software; you can redistribute it and/or modify
00005  *          it under the terms of version 2 of the GNU General Public License as 
00006  *          published by the Free Software Foundation. 
00007  *
00008  *          This program is distributed in the hope that it will be useful,
00009  *          but WITHOUT ANY WARRANTY; without even the implied warranty of
00010  *          MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00011  *          GNU General Public License for more details.
00012  *
00013  *          You should have received a copy of the GNU General Public License
00014  *          along with this program; if not, write to the Free Software
00015  *          Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
00016  */
00017 
00018 
00019 #include <LanguageTable.h>
00020 #include <TreeBuilder.h>
00021 #include <FunDef.h>
00022 #include <Dataset.h>
00023 
00024 #include <eoSymInit.h>
00025 #include <eoSym.h>
00026 #include <eoPop.h>
00027 #include <eoSymMutate.h>
00028 //#include <eoSymLambdaMutate.h>
00029 #include <eoSymCrossover.h>
00030 #include <eoSymEval.h>
00031 #include <eoOpContainer.h>
00032 #include <eoDetTournamentSelect.h>
00033 #include <eoMergeReduce.h>
00034 #include <eoGenContinue.h>
00035 #include <eoEasyEA.h>
00036 #include <eoGeneralBreeder.h>
00037 
00038 #include <utils/eoParser.h>
00039 #include <utils/eoCheckPoint.h>
00040 #include <utils/eoStat.h>
00041 #include <utils/eoStdoutMonitor.h>
00042 #include <utils/eoRNG.h>
00043 
00044 using namespace std;
00045 
00046 typedef EoSym<eoMinimizingFitness> EoType;
00047 
00048 static int functions_added = 0;
00049 
00050 void add_function(LanguageTable& table, eoParser& parser, string name, unsigned arity, token_t token, const FunDef& fun); 
00051 void setup_language(LanguageTable& table, eoParser& parser);
00052 
00053 template <class T>
00054 T& select(bool check, T& a, T& b) { if (check) return a; return b; }
00055 
00056 class eoBestIndividualStat : public eoSortedStat<EoType, string> {
00057     public: 
00058     eoBestIndividualStat() : eoSortedStat<EoType, string>("", "best individual") {}
00059     
00060     void operator()(const vector<const EoType*>& _pop)  {
00061         ostringstream os;
00062         os << (Sym) *_pop[0];
00063         value() = os.str();
00064     }
00065     
00066 };
00067 
00068 class AverageSizeStat : public eoStat<EoType, double> {
00069     public:
00070         AverageSizeStat() : eoStat<EoType, double>(0.0, "Average size population") {}
00071     
00072         void operator()(const eoPop<EoType>& _pop) {
00073             double total = 0.0;
00074             for (unsigned i = 0; i < _pop.size(); ++i) {
00075                 total += _pop[i].size();
00076             }
00077             value() = total/_pop.size();
00078         }
00079 };
00080 
00081 class SumSizeStat : public eoStat<EoType, unsigned> {
00082     public:
00083         SumSizeStat() : eoStat<EoType, unsigned>(0u, "Number of subtrees") {}
00084     
00085         void operator()(const eoPop<EoType>& _pop) {
00086             unsigned total = 0;
00087             for (unsigned i = 0; i < _pop.size(); ++i) {
00088                 total += _pop[i].size();
00089             }
00090             value() = total;
00091         }
00092 };
00093 
00094 class DagSizeStat : public eoStat<EoType, unsigned> {
00095     public:
00096         DagSizeStat() : eoStat<EoType, unsigned>(0u, "Number of distinct subtrees") {}
00097 
00098         void operator()(const eoPop<EoType>& _pop) {
00099             value() = Sym::get_dag().size();
00100         }
00101 };
00102 
00103 int main(int argc, char* argv[]) {
00104    
00105     eoParser parser(argc, argv);
00106   
00107     /* Language */
00108     LanguageTable table;
00109     setup_language(table, parser);
00110    
00111     /* Data */
00112     
00113     eoValueParam<string> datafile = parser.createParam(string(""), "datafile", "Training data", 'd', string("Regression"), true); // mandatory 
00114     double train_percentage = parser.createParam(1.0, "trainperc", "Percentage of data used for training", 0, string("Regression")).value();
00115     
00116     /* Population */
00117 
00118     unsigned pop_size = parser.createParam(1500u, "population-size", "Population Size", 'p', string("Population")).value();
00119   
00120     uint32_t seed = parser.createParam( uint32_t(time(0)), "random-seed", "Seed for rng", 'D').value();
00121 
00122     cout << "Seed " << seed << endl;
00123     rng.reseed(seed);
00124     
00125     double var_prob = parser.createParam(
00126             0.9, 
00127             "var-prob", 
00128             "Probability of selecting a var vs. const when creating a terminal",
00129             0,
00130             "Population").value();
00131 
00132     
00133     double grow_prob = parser.createParam(
00134             0.5,
00135             "grow-prob",
00136             "Probability of selecting 'grow' method instead of 'full' in initialization and mutation",
00137             0,
00138             "Population").value();
00139     
00140     unsigned max_depth = parser.createParam(
00141             8u,
00142             "max-depth",
00143             "Maximum depth used in initialization and mutation",
00144             0,
00145             "Population").value();
00146             
00147    
00148     bool use_uniform = parser.createParam(
00149             false,
00150             "use-uniform",
00151             "Use uniform node selection instead of bias towards internal nodes (functions)",
00152             0,
00153             "Population").value();
00154             
00155     double constant_mut_prob = parser.createParam(
00156             0.1,
00157             "constant-mut-rate",
00158             "Probability of performing constant mutation",
00159             0,
00160             "Population").value();
00161     
00162     
00163     double subtree_mut_prob = parser.createParam(
00164             0.2,
00165             "subtree-mut-rate",
00166             "Probability of performing subtree mutation",
00167             0,
00168             "Population").value();
00169     
00170     double node_mut_prob = parser.createParam(
00171             0.2,
00172             "node-mut-rate",
00173             "Probability of performing node mutation",
00174             0,
00175             "Population").value();
00176     
00177 /*    double lambda_mut_prob = parser.createParam(
00178             1.0,
00179             "lambda-mut-rate",
00180             "Probability of performing (neutral) lambda extraction/expansion",
00181             0,
00182             "Population").value();
00183 */
00184     double subtree_xover_prob = parser.createParam(
00185             0.4,
00186             "xover-rate",
00187             "Probability of performing subtree crossover",
00188             0,
00189             "Population").value();
00190 
00191     double homologous_prob = parser.createParam(
00192             0.4,
00193             "homologous-rate",
00194             "Probability of performing homologous crossover",
00195             0,
00196             "Population").value();
00197 
00198     unsigned max_gens = parser.createParam(
00199             50,
00200             "max-gens",
00201             "Maximum number of generations to run",
00202             'g',
00203             "Population").value();
00204     
00205     unsigned tournamentsize = parser.createParam(
00206             5,
00207             "tournament-size",
00208             "Tournament size used for selection",
00209             't',
00210             "Population").value();
00211 
00212     unsigned maximumSize = parser.createParam(
00213             -1u,
00214             "maximum-size",
00215             "Maximum size after crossover",
00216             's',
00217             "Population").value();
00218     
00219     unsigned meas_param = parser.createParam(
00220             2u,
00221             "measure",
00222             "Error measure:\n\
00223                 0 -> absolute error\n\
00224                 1 -> mean squared error\n\
00225                 2 -> mean squared error scaled (equivalent with correlation)\n\
00226                 ",
00227                 'm',
00228                 "Regression").value();
00229   
00230     
00231     ErrorMeasure::measure meas = ErrorMeasure::mean_squared_scaled;
00232     if (meas_param == 0) meas = ErrorMeasure::absolute;
00233     if (meas_param == 1) meas = ErrorMeasure::mean_squared;
00234 
00235     
00236     /* End parsing */
00237     if (parser.userNeedsHelp())
00238     {
00239         parser.printHelp(std::cout);
00240         return 1;
00241     }
00242     
00243     if (functions_added == 0) {
00244         cout << "ERROR: no functions defined" << endl;
00245         exit(1);
00246     }
00247     
00248     
00249     Dataset dataset;
00250     dataset.load_data(datafile.value());
00251     
00252     cout << "Data " << datafile.value() << " loaded " << endl;
00253    
00254     /* Add Variables */
00255     unsigned nvars = dataset.n_fields();
00256     for (unsigned i = 0; i < nvars; ++i) {
00257         table.add_function( SymVar(i).token(), 0);
00258     }
00259     
00260     TreeBuilder builder(table, var_prob);
00261     eoSymInit<EoType> init(builder, grow_prob, max_depth);
00262     
00263     eoPop<EoType> pop(pop_size, init);
00264     
00265     BiasedNodeSelector biased_sel;
00266     RandomNodeSelector random_sel;
00267 
00268     NodeSelector& node_selector = select<NodeSelector>(use_uniform, random_sel, biased_sel);
00269     
00270     //eoProportionalOp<EoType> genetic_operator;
00271     eoSequentialOp<EoType> genetic_operator;
00272     
00273     eoSymSubtreeMutate<EoType> submutate(builder, node_selector);
00274     genetic_operator.add( submutate, subtree_mut_prob);
00275    
00276     // todo, make this parameter, etc
00277     double std = 0.01;
00278     eoSymConstantMutate<EoType> constmutate(std);
00279     genetic_operator.add(constmutate, constant_mut_prob);
00280     
00281     eoSymNodeMutate<EoType>    nodemutate(table);
00282     genetic_operator.add(nodemutate, node_mut_prob);
00283    
00284 //    eoSymLambdaMutate<EoType> lambda_mutate(node_selector);
00285 //    genetic_operator.add(lambda_mutate, lambda_mut_prob); // TODO: prob should be settable
00286     
00287     //eoQuadSubtreeCrossover<EoType> quad(node_selector);
00288     eoBinSubtreeCrossover<EoType> bin(node_selector);
00289     genetic_operator.add(bin, subtree_xover_prob);
00290     
00291     eoBinHomologousCrossover<EoType> hom;
00292     genetic_operator.add(hom, homologous_prob);
00293 
00294 
00295     IntervalBoundsCheck check(dataset.input_minima(), dataset.input_maxima());
00296     ErrorMeasure measure(dataset, train_percentage, meas);
00297 
00298     eoSymPopEval<EoType> evaluator(check, measure, maximumSize);
00299     
00300     eoDetTournamentSelect<EoType> selectOne(tournamentsize);
00301     eoGeneralBreeder<EoType> breeder(selectOne, genetic_operator);
00302     eoPlusReplacement<EoType> replace;
00303 
00304     // Terminators
00305     eoGenContinue<EoType> term(max_gens);
00306     eoCheckPoint<EoType> checkpoint(term);
00307     
00308     eoBestFitnessStat<EoType> beststat;
00309     checkpoint.add(beststat);
00310    
00311     eoBestIndividualStat printer;
00312     AverageSizeStat avgSize;
00313     DagSizeStat dagSize;
00314     SumSizeStat sumSize;
00315     
00316     checkpoint.add(printer);
00317     checkpoint.add(avgSize);
00318     checkpoint.add(dagSize);
00319     checkpoint.add(sumSize);
00320     
00321     eoStdoutMonitor genmon;
00322     genmon.add(beststat);
00323     genmon.add(printer);
00324     genmon.add(avgSize);
00325     genmon.add(dagSize);
00326     genmon.add(sumSize);
00327     genmon.add(term); // add generation counter
00328     
00329     checkpoint.add(genmon);
00330     
00331     eoPop<EoType> dummy;
00332     evaluator(pop, dummy);
00333     
00334     eoEasyEA<EoType> ea(checkpoint, evaluator, breeder, replace);
00335 
00336     ea(pop); // run
00337     
00338 }
00339 
00340 void add_function(LanguageTable& table, eoParser& parser, string name, unsigned arity, token_t token, const FunDef& fun, bool all) {
00341     ostringstream desc;
00342     desc << "Enable function " << name << " arity = " << arity;
00343     bool enabled = parser.createParam(false, name, desc.str(), 0, "Language").value();
00344     
00345     if (enabled || all) {
00346         cout << "Func " << name << " enabled" << endl;
00347         table.add_function(token, arity);
00348         if (arity > 0) functions_added++;
00349     }
00350 }
00351 
00352 void setup_language(LanguageTable& table, eoParser& parser) {
00353 
00354     bool all = parser.createParam(false,"all", "Enable all functions").value();
00355     bool ratio = parser.createParam(false,"ratio","Enable rational functions (inv,min,sum,prod)").value();
00356     bool poly = parser.createParam(false,"poly","Enable polynomial functions (min,sum,prod)").value();
00357     
00358     // assumes that at this point all tokens are defined (none are zeroed out, which can happen with ERCs)
00359     vector<const FunDef*> lang = get_defined_functions(); 
00360     
00361     for (token_t i = 0; i < lang.size(); ++i) {
00362         
00363         if (lang[i] == 0) continue;
00364         
00365         bool is_poly = false;
00366         if (poly && (i == prod_token || i == sum_token || i == min_token) ) {
00367             is_poly = true;
00368         }
00369         
00370         bool is_ratio = false;
00371         if (ratio && (is_poly || i == inv_token)) {
00372             is_ratio = true;
00373         }
00374         
00375         const FunDef& fun = *lang[i]; 
00376         
00377         if (fun.has_varargs() ) {
00378 
00379             for (unsigned j = fun.min_arity(); j < fun.min_arity() + 8; ++j) {
00380                 if (j==1) continue; // prod 1 and sum 1 are useless
00381                 ostringstream nm;
00382                 nm << fun.name() << j;
00383                 bool addanyway = (all || is_ratio || is_poly) && j == 2;
00384                 add_function(table, parser, nm.str(), j, i, fun, addanyway);
00385             }
00386         }
00387         else {
00388             add_function(table, parser, fun.name(), fun.min_arity(), i, fun, all || is_ratio || is_poly);
00389         }
00390     }
00391 }
00392 
00393 

Generated on Thu Oct 19 05:06:43 2006 for EO by  doxygen 1.3.9.1