MultiFunction.cpp

00001 #include <vector.h>
00002 
00003 
00004 #include "MultiFunction.h"
00005 #include "Sym.h"
00006 #include "FunDef.h"
00007 
00008 using namespace std;
00009 
00010 
// Read-only cursor into the evaluation buffer (MultiFunctionImpl::data).
typedef std::vector<double>::const_iterator data_ptr;
// Flat table of operand cursors, shared by all compiled functions.
typedef std::vector<data_ptr> data_ptrs;
// Cursor into a data_ptrs table: the first operand slot of one function.
typedef data_ptrs::const_iterator arg_ptr;
00014 
00015 #include "MultiFuncs.cpp"
00016 
00017 typedef double (*fptr)( arg_ptr );
00018 
00019 string print_function( fptr f) {
00020     if (f == multi_function::plus) return "+";
00021     if (f == multi_function::mult) return "*";
00022     if (f == multi_function::min) return "-";
00023     if (f == multi_function::inv) return "/";
00024     if (f == multi_function::exp) return "e";
00025     return "unknown";
00026 }
00027 
00028 
00029 struct Function {
00030     
00031     fptr function;
00032     arg_ptr args;
00033 
00034     double operator()() const { return function(args); }
00035 };
00036 
00037 static vector<Function> token_2_function;
00038 
00039 Sym make_binary(Sym sym) {
00040     if (sym.args().size() == 2) return sym;
00041     SymVec args = sym.args();
00042     Sym an = args.back();
00043     args.pop_back();
00044     Sym nw = make_binary( Sym( sym.token(), args) );
00045     args.resize(2);
00046     args[0] = nw;
00047     args[1] = an;
00048     return Sym(sym.token(), args); 
00049 }
00050 
00051 class Compiler {
00052     
00053     public:
00054 
00055     enum func_type {constant, variable, function};
00056     
00057     typedef pair<func_type, unsigned> entry;
00058 
00059 #if USE_TR1
00060             typedef std::tr1::unordered_map<Sym, entry, HashSym> HashMap;
00061 #else
00062             typedef hash_map<Sym, entry, HashSym> HashMap;
00063 #endif      
00064    
00065     HashMap map;
00066    
00067     vector<double> constants;
00068     vector<unsigned> variables;
00069     vector< fptr > functions;
00070     vector< vector<entry> > function_args;
00071     
00072     unsigned total_args;
00073     
00074     vector<entry> outputs;
00075    
00076     Compiler() : total_args(0) {}
00077     
00078     entry do_add(Sym sym) {
00079 
00080         HashMap::iterator it = map.find(sym);
00081 
00082         if (it == map.end()) { // new entry
00083             
00084             token_t token = sym.token();
00085 
00086             if (is_constant(token)) {
00087                 constants.push_back( get_constant_value(token) ); // set value
00088                 entry e = make_pair(constant, constants.size()-1);
00089                 map.insert( make_pair(sym, e) );
00090                 return e;
00091                 
00092             } else if (is_variable(token)) {
00093                 unsigned idx = get_variable_index(token);
00094                 variables.push_back(idx);
00095                 entry e = make_pair(variable, variables.size()-1);
00096                 map.insert( make_pair(sym, e) );
00097                 return e;
00098             } // else 
00099                 
00100             fptr f;
00101             vector<entry> vec;
00102             const SymVec& args = sym.args();
00103             
00104             switch (token) {
00105                 case sum_token:
00106                     {
00107                         if (args.size() == 0) {
00108                             return do_add( SymConst(0.0));
00109                         }
00110                         if (args.size() == 1) {
00111                             return do_add(args[0]);
00112                         }
00113                         if (args.size() == 2) {
00114                             vec.push_back(do_add(args[0]));
00115                             vec.push_back(do_add(args[1]));
00116                             f = multi_function::plus;
00117                             //cout << "Adding + " << vec[0].second << ' ' << vec[1].second << endl;
00118                             break;
00119 
00120                         } else {
00121                             return do_add( make_binary(sym) );
00122                         }
00123                         
00124                     }
00125                 case prod_token:
00126                     {
00127                         if (args.size() == 0) {
00128                             return do_add( SymConst(1.0));
00129                         }
00130                         if (args.size() == 1) {
00131                             return do_add(args[0]);
00132                         }
00133                         if (args.size() == 2) {
00134                             vec.push_back(do_add(args[0]));
00135                             vec.push_back(do_add(args[1]));
00136                             f = multi_function::mult;
00137                             //cout << "Adding * " << vec[0].second << ' ' << vec[1].second << endl;
00138                             break;
00139                             
00140 
00141                         } else {
00142                             return do_add( make_binary(sym) );
00143                         }
00144                     }
00145                 case sqr_token: 
00146                     {
00147                         SymVec newargs(2);
00148                         newargs[0] = args[0];
00149                         newargs[1] = args[0];
00150                        return do_add( Sym(prod_token, newargs)); 
00151                     }
00152                 default :
00153                     {
00154                         if (args.size() != 1) {
00155                             cerr << "Unknown function " << sym << " encountered" << endl;
00156                             exit(1);
00157                         }
00158                         
00159                         vec.push_back(do_add(args[0]));
00160 
00161                         switch (token) {
00162                             case min_token: f = multi_function::min; break;
00163                             case inv_token: f = multi_function::inv; break;
00164                             case exp_token :f = multi_function::exp; break;
00165                             default :
00166                                 {
00167                                     cerr << "Unimplemented token encountered " << sym << endl;
00168                                     exit(1);
00169                                 }
00170                         }
00171                         
00172                         //cout << "Adding " << print_function(f) << ' ' << vec[0].second << endl;
00173                         
00174                         
00175                     }
00176 
00177             }
00178             
00179             total_args += vec.size();
00180             function_args.push_back(vec);
00181             functions.push_back(f);
00182             
00183             entry e = make_pair(function, functions.size()-1);
00184             map.insert( make_pair(sym, e) );
00185             return e;
00186             
00187         }
00188         
00189         return it->second; // entry
00190     }
00191    
00192     void add(Sym sym) {
00193         entry e = do_add(sym);
00194         outputs.push_back(e);
00195     }
00196     
00197 };
00198 
00199 class MultiFunctionImpl {
00200     public:
00201         
00202     // input mapping
00203     vector<unsigned> input_idx;
00204     
00205     unsigned constant_offset;
00206     unsigned var_offset;
00207 
00208     // evaluation
00209     vector<double> data;
00210     vector<Function> funcs;
00211     data_ptrs args;     
00212     
00213     vector<unsigned> output_idx;
00214     
00215     MultiFunctionImpl() {}
00216 
00217     void clear() {
00218         input_idx.clear();
00219         data.clear();
00220         funcs.clear();
00221         args.clear();
00222         output_idx.clear();
00223         constant_offset = 0;
00224     }
00225     
00226     void eval(const double* x, double* y) {
00227         unsigned i;
00228         // evaluate variables
00229         for (i = constant_offset; i < constant_offset + input_idx.size(); ++i) {
00230             data[i] = x[input_idx[i-constant_offset]];
00231         }
00232 
00233         for(; i < data.size(); ++i) {
00234             data[i] = funcs[i-var_offset]();
00235             //cout << i << " " << data[i] << endl;
00236         }
00237 
00238         for (i = 0; i < output_idx.size(); ++i) {
00239             y[i] = data[output_idx[i]];
00240         }
00241     }
00242 
00243     void eval(const vector<double>& x, vector<double>& y) {
00244         eval(&x[0], &y[0]);
00245     }
00246     
00247     void setup(const vector<Sym>& pop) {
00248         
00249         clear(); 
00250         Compiler compiler;
00251         
00252         for (unsigned i = 0; i < pop.size(); ++i) {
00253             Sym sym = (expand_all(pop[i]));
00254             compiler.add(sym);
00255         }
00256         
00257         // compiler is setup so get the data
00258         constant_offset = compiler.constants.size();
00259         var_offset = constant_offset + compiler.variables.size();
00260         int n = var_offset + compiler.functions.size();
00261 
00262         data.resize(n);
00263         funcs.resize(compiler.functions.size());
00264         args.resize(compiler.total_args);
00265         
00266         // constants
00267         for (unsigned i = 0; i < constant_offset; ++i) {
00268             data[i] = compiler.constants[i];
00269             //cout << i << ' ' << data[i] << endl;
00270         }
00271         
00272         // variables
00273         input_idx = compiler.variables;
00274 
00275         //for (unsigned i = constant_offset; i < var_offset; ++i) {
00276             //cout << i << " x" << input_idx[i-constant_offset] << endl;
00277         //}
00278         
00279         // functions
00280         unsigned which_arg = 0;
00281         for (unsigned i = 0; i < funcs.size(); ++i) {
00282             
00283             Function f;
00284             f.function = compiler.functions[i];
00285             
00286             //cout << i+var_offset << ' ' << print_function(f.function);
00287                 
00288             // interpret args
00289             for (unsigned j = 0; j < compiler.function_args[i].size(); ++j) {
00290                 
00291                 Compiler::entry e = compiler.function_args[i][j];
00292                 
00293                 unsigned idx = e.second;
00294                 
00295                 switch (e.first) {
00296                     case Compiler::function: idx += compiler.variables.size();
00297                     case Compiler::variable: idx += compiler.constants.size();
00298                     case Compiler::constant: {}
00299                 }
00300 
00301                 args[which_arg + j] = data.begin() + idx;
00302                 //cout << ' ' << idx << "(" << e.second << ")";
00303             }
00304             
00305             //cout << endl;
00306 
00307             f.args = args.begin() + which_arg;
00308             which_arg += compiler.function_args[i].size();
00309             funcs[i] = f;    
00310         }
00311 
00312         // output indices
00313         output_idx.resize(compiler.outputs.size());
00314         for (unsigned i = 0; i < output_idx.size(); ++i) {
00315             output_idx[i] = compiler.outputs[i].second;
00316             switch(compiler.outputs[i].first) {
00317                     case Compiler::function: output_idx[i] += compiler.variables.size();
00318                     case Compiler::variable: output_idx[i] += compiler.constants.size();
00319                     case Compiler::constant: {}
00320             }
00321             //cout << "out " << output_idx[i] << endl;
00322         }
00323     }
00324     
00325 };  
00326 
00327 
00328 
00329 MultiFunction::MultiFunction(const std::vector<Sym>& pop) : pimpl(new MultiFunctionImpl) {
00330     pimpl->setup(pop);
00331 }
00332 
00333 MultiFunction::~MultiFunction() { delete pimpl; }
00334 
00335 void MultiFunction::operator()(const std::vector<double>& x, std::vector<double>& y) {
00336     pimpl->eval(x,y);
00337 }
00338 
00339 void MultiFunction::operator()(const double* x, double* y) {
00340     pimpl->eval(x,y);
00341 }

Generated on Thu Oct 19 05:06:41 2006 for EO by  doxygen 1.3.9.1