sym_operations.cpp

00001 /*          
00002  *             Copyright (C) 2005 Maarten Keijzer
00003  *
00004  *          This program is free software; you can redistribute it and/or modify
00005  *          it under the terms of version 2 of the GNU General Public License as 
00006  *          published by the Free Software Foundation. 
00007  *
00008  *          This program is distributed in the hope that it will be useful,
00009  *          but WITHOUT ANY WARRANTY; without even the implied warranty of
00010  *          MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00011  *          GNU General Public License for more details.
00012  *
00013  *          You should have received a copy of the GNU General Public License
00014  *          along with this program; if not, write to the Free Software
00015  *          Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
00016  */
00017 
00018 #include <FunDef.h>
00019 
00020 using namespace std;
00021 
00022 Sym simplify_constants(Sym sym) {
00023 
00024     SymVec args = sym.args();
00025     token_t token = sym.token();
00026     
00027     bool has_changed = false;
00028     bool all_constants = true;
00029 
00030     for (unsigned i = 0; i < args.size(); ++i) {
00031         
00032         Sym arg = simplify_constants(args[i]);
00033         
00034         if (arg != args[i]) {
00035             has_changed = true;
00036         }
00037         args[i] = arg;
00038 
00039         all_constants &= is_constant(args[i].token());
00040     }
00041     
00042     if (args.size() == 0) {
00043         
00044         if (sym.token() == sum_token) return SymConst(0.0);
00045         if (sym.token() == prod_token) return SymConst(1.0);
00046         
00047         return sym; // variable or constant
00048     }
00049     
00050     if (all_constants) {
00051         // evaluate 
00052         
00053         vector<double> dummy;
00054         
00055         double v = ::eval(sym, dummy);
00056         
00057         Sym result = SymConst(v);
00058         
00059         return result;
00060     }
00061 
00062     if (has_changed) {
00063         return Sym(token, args);
00064     }
00065 
00066     return sym;
00067     
00068 }
00069 
00070 // currently only simplifies constants
00071 Sym simplify(Sym sym) {
00072     
00073     return simplify_constants(sym);
00074     
00075 }
00076 
00077 Sym derivative(token_t token, Sym x) {
00078     Sym one = Sym(prod_token);
00079     
00080     switch (token) {
00081         case inv_token : return Sym(inv_token, sqr(x));
00082         
00083         case sin_token : return -cos(x);
00084         case cos_token : return sin(x);
00085         case tan_token : return one + sqr(tan(x));
00086                          
00087         case asin_token : return inv( sqrt(one - sqr(x)));
00088         case acos_token:  return -inv( sqrt(one - sqr(x)));
00089         case atan_token : return inv( sqrt(one + sqr(x)));
00090         
00091         case cosh_token : return -sinh(x);
00092         case sinh_token : return cosh(x);
00093         case tanh_token : return one - sqr( tanh(x) );
00094         
00095         case asinh_token : return inv( sqrt( one + sqr(x) ));
00096         case acosh_token : return inv( sqrt(x-one) * sqrt(x + one)  );
00097         case atanh_token : return inv(one - sqr(x));
00098                          
00099         case exp_token : return exp(x);
00100         case log_token : return inv(x);
00101 
00102         case sqr_token : return SymConst(2.0) * x;
00103         case sqrt_token : return SymConst(0.5) * inv( sqrt(x));
00104         default :
00105             throw differentiation_error();
00106     }
00107     
00108     return x;
00109 }
00110 
00111 extern Sym differentiate(Sym sym, token_t dx) {
00112     
00113     token_t token = sym.token();
00114     
00115     Sym zero = Sym(sum_token);
00116     Sym one  = Sym(prod_token);
00117     
00118     if (token == dx) {
00119         return one;
00120     }
00121     
00122     SymVec args = sym.args();
00123 
00124     if (args.size() == 0) { // df/dx with f != x
00125         return zero;
00126     }
00127     
00128     switch (token) {
00129         
00130         case sum_token: 
00131             {
00132                 for (unsigned i = 0; i < args.size(); ++i) {
00133                     args[i] = differentiate(args[i], dx);
00134                 }
00135 
00136                 if (args.size() == 1) return args[0];
00137                 return Sym(sum_token, args);
00138             }
00139         case min_token : 
00140             {
00141                 return -differentiate(args[0],dx);
00142             }
00143         case prod_token: 
00144             {
00145                 if (args.size() == 1) return differentiate(args[0], dx);
00146                 
00147                 if (args.size() == 2) {
00148                     return args[0] * differentiate(args[1], dx) + args[1] * differentiate(args[0], dx);
00149                 }
00150                 // else 
00151                 Sym c = args.back();
00152                 args.pop_back();
00153                 Sym f = Sym(prod_token, args);
00154                 Sym df = differentiate( f, dx);
00155 
00156                 return c * df + f * differentiate(c,dx);
00157             }
00158         case pow_token : 
00159             {
00160                 return pow(args[0], args[1]) * args[1] * inv(args[0]);
00161             }
00162         case ifltz_token : 
00163             { // cannot be differentiated
00164                 throw differentiation_error(); // TODO define proper exception
00165             }
00166             
00167         default: // unary function: apply chain rule
00168             {
00169                 Sym arg = args[0];
00170                 return derivative(token, arg) * differentiate(arg, dx);
00171             }
00172     }
00173     
00174 }

Generated on Thu Oct 19 05:06:42 2006 for EO by  doxygen 1.3.9.1