#include #include "MultiFunction.h" #include "Sym.h" #include "FunDef.h" using namespace std; typedef vector::const_iterator data_ptr; typedef vector data_ptrs; typedef data_ptrs::const_iterator arg_ptr; #include "MultiFuncs.cpp" typedef double (*fptr)( arg_ptr ); string print_function( fptr f) { if (f == multi_function::plus) return "+"; if (f == multi_function::mult) return "*"; if (f == multi_function::min) return "-"; if (f == multi_function::inv) return "/"; if (f == multi_function::exp) return "e"; return "unknown"; } struct Function { fptr function; arg_ptr args; double operator()() const { return function(args); } }; static vector token_2_function; Sym make_binary(Sym sym) { if (sym.args().size() == 2) return sym; SymVec args = sym.args(); Sym an = args.back(); args.pop_back(); Sym nw = make_binary( Sym( sym.token(), args) ); args.resize(2); args[0] = nw; args[1] = an; return Sym(sym.token(), args); } class Compiler { public: enum func_type {constant, variable, function}; typedef pair entry; #if USE_TR1 typedef std::tr1::unordered_map HashMap; #else typedef hash_map HashMap; #endif HashMap map; vector constants; vector variables; vector< fptr > functions; vector< vector > function_args; unsigned total_args; vector outputs; Compiler() : total_args(0) {} entry do_add(Sym sym) { HashMap::iterator it = map.find(sym); if (it == map.end()) { // new entry token_t token = sym.token(); if (is_constant(token)) { constants.push_back( get_constant_value(token) ); // set value entry e = make_pair(constant, constants.size()-1); map.insert( make_pair(sym, e) ); return e; } else if (is_variable(token)) { unsigned idx = get_variable_index(token); variables.push_back(idx); entry e = make_pair(variable, variables.size()-1); map.insert( make_pair(sym, e) ); return e; } // else fptr f; vector vec; const SymVec& args = sym.args(); switch (token) { case sum_token: { if (args.size() == 0) { return do_add( SymConst(0.0)); } if (args.size() == 1) { return do_add(args[0]); } if (args.size() == 2) { vec.push_back(do_add(args[0])); vec.push_back(do_add(args[1])); f = multi_function::plus; //cout << "Adding + " << vec[0].second << ' ' << vec[1].second << endl; break; } else { return do_add( make_binary(sym) ); } } case prod_token: { if (args.size() == 0) { return do_add( SymConst(1.0)); } if (args.size() == 1) { return do_add(args[0]); } if (args.size() == 2) { vec.push_back(do_add(args[0])); vec.push_back(do_add(args[1])); f = multi_function::mult; //cout << "Adding * " << vec[0].second << ' ' << vec[1].second << endl; break; } else { return do_add( make_binary(sym) ); } } case sqr_token: { SymVec newargs(2); newargs[0] = args[0]; newargs[1] = args[0]; return do_add( Sym(prod_token, newargs)); } default : { if (args.size() != 1) { cerr << "Unknown function " << sym << " encountered" << endl; exit(1); } vec.push_back(do_add(args[0])); switch (token) { case min_token: f = multi_function::min; break; case inv_token: f = multi_function::inv; break; case exp_token :f = multi_function::exp; break; default : { cerr << "Unimplemented token encountered " << sym << endl; exit(1); } } //cout << "Adding " << print_function(f) << ' ' << vec[0].second << endl; } } total_args += vec.size(); function_args.push_back(vec); functions.push_back(f); entry e = make_pair(function, functions.size()-1); map.insert( make_pair(sym, e) ); return e; } return it->second; // entry } void add(Sym sym) { entry e = do_add(sym); outputs.push_back(e); } }; class MultiFunctionImpl { public: // input mapping vector input_idx; unsigned constant_offset; unsigned var_offset; // evaluation vector data; vector funcs; data_ptrs args; vector output_idx; MultiFunctionImpl() {} void clear() { input_idx.clear(); data.clear(); funcs.clear(); args.clear(); output_idx.clear(); constant_offset = 0; } void eval(const double* x, double* y) { unsigned i; // evaluate variables for (i = constant_offset; i < constant_offset + input_idx.size(); ++i) { data[i] = x[input_idx[i-constant_offset]]; } for(; i < data.size(); ++i) { data[i] = funcs[i-var_offset](); //cout << i << " " << data[i] << endl; } for (i = 0; i < output_idx.size(); ++i) { y[i] = data[output_idx[i]]; } } void eval(const vector& x, vector& y) { eval(&x[0], &y[0]); } void setup(const vector& pop) { clear(); Compiler compiler; for (unsigned i = 0; i < pop.size(); ++i) { Sym sym = (expand_all(pop[i])); compiler.add(sym); } // compiler is setup so get the data constant_offset = compiler.constants.size(); var_offset = constant_offset + compiler.variables.size(); int n = var_offset + compiler.functions.size(); data.resize(n); funcs.resize(compiler.functions.size()); args.resize(compiler.total_args); // constants for (unsigned i = 0; i < constant_offset; ++i) { data[i] = compiler.constants[i]; //cout << i << ' ' << data[i] << endl; } // variables input_idx = compiler.variables; //for (unsigned i = constant_offset; i < var_offset; ++i) { //cout << i << " x" << input_idx[i-constant_offset] << endl; //} // functions unsigned which_arg = 0; for (unsigned i = 0; i < funcs.size(); ++i) { Function f; f.function = compiler.functions[i]; //cout << i+var_offset << ' ' << print_function(f.function); // interpret args for (unsigned j = 0; j < compiler.function_args[i].size(); ++j) { Compiler::entry e = compiler.function_args[i][j]; unsigned idx = e.second; switch (e.first) { case Compiler::function: idx += compiler.variables.size(); case Compiler::variable: idx += compiler.constants.size(); case Compiler::constant: {} } args[which_arg + j] = data.begin() + idx; //cout << ' ' << idx << "(" << e.second << ")"; } //cout << endl; f.args = args.begin() + which_arg; which_arg += compiler.function_args[i].size(); funcs[i] = f; } // output indices output_idx.resize(compiler.outputs.size()); for (unsigned i = 0; i < output_idx.size(); ++i) { output_idx[i] = compiler.outputs[i].second; switch(compiler.outputs[i].first) { case Compiler::function: output_idx[i] += compiler.variables.size(); case Compiler::variable: output_idx[i] += compiler.constants.size(); case Compiler::constant: {} } //cout << "out " << output_idx[i] << endl; } } }; MultiFunction::MultiFunction(const std::vector& pop) : pimpl(new MultiFunctionImpl) { pimpl->setup(pop); } MultiFunction::~MultiFunction() { delete pimpl; } void MultiFunction::operator()(const std::vector& x, std::vector& y) { pimpl->eval(x,y); } void MultiFunction::operator()(const double* x, double* y) { pimpl->eval(x,y); }