sym_compile.cpp

00001 /*          
00002  *             Copyright (C) 2005 Maarten Keijzer
00003  *
00004  *          This program is free software; you can redistribute it and/or modify
00005  *          it under the terms of version 2 of the GNU General Public License as 
00006  *          published by the Free Software Foundation. 
00007  *
00008  *          This program is distributed in the hope that it will be useful,
00009  *          but WITHOUT ANY WARRANTY; without even the implied warranty of
00010  *          MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00011  *          GNU General Public License for more details.
00012  *
00013  *          You should have received a copy of the GNU General Public License
00014  *          along with this program; if not, write to the Free Software
00015  *          Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
00016  */
00017 
00018 
00019 #include "Sym.h"
00020 #include "FunDef.h"
00021 #include "sym_compile.h"
00022 
00023 #include <sstream>
00024 
00025 using namespace std;
00026 
00027 extern "C" {
00028     void  symc_init();
00029     int  symc_compile(const char* func_str);
00030     int  symc_link();
00031     void* symc_get_fun(const char* func_name);
00032     void* symc_make(const char* func_str, const char* func_name);
00033 }
00034 
00035 string make_prototypes() {
00036     string prot = get_prototypes();
00037     prot += "double sqr(double x) { return x*x; }";
00038     return prot;
00039 }
00040 
00041 // contains variable names, like 'a0', 'a1', etc. or regular code
00042 
00043 #if USE_TR1 
00044 typedef std::tr1::unordered_map<Sym, string, HashSym> HashMap;
00045 #else
00046 typedef hash_map<Sym, string, HashSym> HashMap;
00047 #endif
00048 
// Returns a C identifier for temporary number `num`: the letter 'a' followed by
// the decimal digits of `num`, least-significant digit first (12 -> "a21").
// The reversed digit order is harmless: distinct numbers still map to distinct
// names, and the name only has to be unique.
string make_var(unsigned num) {
    string name(1, 'a');
    for (;;) {
        name.push_back(static_cast<char>('0' + num % 10));
        num /= 10;
        if (num == 0) {
            break;
        }
    }
    return name;
}
00058 
// Formats any value that supports operator<< on an ostream and returns the text.
// Note: with `using namespace std` in effect, the C++11 std::to_string overloads
// are also visible for unqualified calls. For arithmetic arguments the
// non-template std:: overload wins overload resolution; it produces the same
// decimal text for integers.
template <class T>
string to_string(T t) {
    ostringstream formatter;
    formatter << t;
    string text = formatter.str();
    return text;
}
00065 
00066 
00067 HashMap::iterator find_entry(const Sym& sym, string& str, HashMap& map) {
00068     HashMap::iterator result = map.find(sym);
00069 
00070     if (result == map.end()) { // new entry
00071         const SymVec& args = sym.args();
00072         
00073         vector<string> argstr(args.size());
00074         for (unsigned i = 0; i < args.size(); ++i) {
00075             argstr[i] = find_entry(args[i], str, map)->second;
00076         }
00077 
00078         string var = make_var(map.size()); // map.size(): unique id
00079         string code;    
00080         // write out the code
00081         const FunDef& fun = get_element(sym.token());
00082         code = fun.c_print(argstr, vector<string>() );
00083             
00084         str += "double " + var + "=" + code + ";\n";
00085         result = map.insert( make_pair(sym, var ) ).first; // only want iterator
00086     }
00087     
00088     return result;
00089 }
00090 
00091 void write_entry(const Sym& sym, string& str, HashMap& map, unsigned out) {
00092     HashMap::iterator it = find_entry(sym, str, map);
00093     
00094     str += "y[" + to_string(out) + "]=" + it->second + ";\n";
00095     //cout << "wrote " << out << '\n';
00096 }
00097 
00098 #include <fstream>
00099 multi_function compile(const std::vector<Sym>& syms) {
00100     
00101     //cout << "Multifunction " << syms.size() << endl;
00102     // static stream to avoid fragmentation of these LARGE strings
00103     static string str;
00104     str.clear();
00105     str += make_prototypes();
00106 
00107     str += "extern double func(const double* x, double* y) { \n ";
00108    
00109     multi_function result;
00110     HashMap map(Sym::get_dag().size());
00111     
00112     for (unsigned i = 0; i < syms.size(); ++i) {
00113         write_entry(syms[i], str, map, i);
00114     }
00115     
00116     str += ";}";
00117 
00118     
00119     /*static int counter = 0;
00120     ostringstream nm;
00121     nm << "cmp/compiled" << (counter++) << ".c";
00122     cout << "Saving as " << nm.str() << endl;
00123     ofstream cmp(nm.str().c_str());
00124     cmp << str;
00125     cmp.close();
00126 
00127     //cout << "Multifunction " << syms.size() << endl;
00128     cout << "Size of map " << map.size() << endl;
00129 */
00130 
00131     result = (multi_function) symc_make(str.c_str(), "func"); 
00132 
00133     if (result==0) { // error
00134         cout << "Error in compile " << endl;
00135     }
00136 
00137     return result;
00138 }
00139 
00140 single_function compile(Sym sym) {
00141 
00142     ostringstream os;
00143 
00144     os << make_prototypes();
00145     os << "double func(const double* x) { return ";
00146     
00147     string code = c_print(sym);
00148     os << code;
00149     os << ";}";
00150     string func_str = os.str();
00151   
00152     //cout << "compiling " << func_str << endl;
00153     
00154     return  (single_function) symc_make(func_str.c_str(), "func"); 
00155 }
00156 
00157 /* finds and inserts the full code in a hashmap */
00158 HashMap::iterator find_code(Sym sym, HashMap& map) {
00159     HashMap::iterator result = map.find(sym);
00160 
00161     if (result == map.end()) { // new entry
00162         const SymVec& args = sym.args();
00163         vector<string> argstr(args.size());
00164         for (unsigned i = 0; i < args.size(); ++i) {
00165             argstr[i] = find_code(args[i], map)->second;
00166         }
00167 
00168         // write out the code
00169         const FunDef& fun = get_element(sym.token());
00170         string code = fun.c_print(argstr, vector<string>());
00171         result = map.insert( make_pair(sym, code) ).first; // only want iterator
00172     }
00173     
00174     return result;
00175 }
00176 
00177 string print_code(Sym sym, HashMap& map) {
00178     HashMap::iterator it = find_code(sym, map);
00179     return it->second;
00180 }
00181 
00182 void compile(const std::vector<Sym>& syms, std::vector<single_function>& functions) {
00183     symc_init();
00184     
00185     static ostringstream os;
00186     os.str("");
00187     
00188     os << make_prototypes();
00189     HashMap map(Sym::get_dag().size());
00190     for (unsigned i = 0; i < syms.size(); ++i) {
00191         
00192         os << "double func" << i << "(const double* x) { return ";
00193         os << print_code(syms[i], map); //c_print(syms[i]);
00194         os << ";}\n";
00195 
00196         //symc_compile(os.str().c_str());
00197         //cout << "compiling " << os.str() << endl;     
00198     }
00199 
00200     os << ends;
00201 #ifdef INTERVAL_DEBUG
00202     //cout << "Compiling " << os.str() << endl;
00203 #endif
00204     
00205     symc_compile(os.str().c_str()); 
00206     symc_link();
00207 
00208     functions.resize(syms.size());
00209     for (unsigned i = 0; i < syms.size(); ++i) {
00210         ostringstream os2;
00211         os2 << "func" << i;
00212         
00213         functions[i] = (single_function) symc_get_fun(os2.str().c_str());
00214     }
00215 
00216 }
00217 
00218 
00219 

Generated on Thu Oct 19 05:06:42 2006 for EO by  doxygen 1.3.9.1