diff --git a/eo/app/gprop/mlp.h b/eo/app/gprop/mlp.h index 3628325f1..06036e607 100644 --- a/eo/app/gprop/mlp.h +++ b/eo/app/gprop/mlp.h @@ -20,6 +20,10 @@ #include #include +#ifdef HAVE_LIBYAML_CPP +#include +#endif // HAVE_LIBYAML_CPP + namespace mlp { @@ -125,6 +129,36 @@ namespace mlp if ( probability >= 1.0 || rng.uniform() < probability) perturb_num(bias, magnitude); } + + #ifdef HAVE_LIBYAML_CPP + friend YAML::Emitter& operator<<(YAML::Emitter& out, const mlp::neuron& n) { + n.emit_yaml(out); + return out; + } + + void emit_yaml(YAML::Emitter&out) const { + out << YAML::BeginMap; + out << YAML::Key << "Class" << YAML::Value << "mlp::neuron"; + #define MY_EMIT_MEMBER(emitter,member) emitter << YAML::Key << #member << YAML::Value << this->member + MY_EMIT_MEMBER(out,bias); + MY_EMIT_MEMBER(out,weight); + out << YAML::EndMap; + #undef MY_EMIT_MEMBER + } + + friend void operator >>(const YAML::Node& node, mlp::neuron& n) { + n.load_yaml(node); + } + + void load_yaml(const YAML::Node& node) { + #define MY_LOAD_MEMBER(doc,member) doc[#member] >> member + MY_LOAD_MEMBER(node, bias); + MY_LOAD_MEMBER(node, weight); + #undef MY_LOAD_MEMBER + } + + + #endif }; } @@ -140,8 +174,10 @@ namespace std { return is >> n.bias >> n.weight; } + } + namespace mlp { //--------------------------------------------------------------------------- @@ -190,6 +226,21 @@ namespace mlp { for(iterator n = begin(); n != end(); ++n) n->perturb(); } + #ifdef HAVE_LIBYAML_CPP + friend ostream& operator<<(YAML::Emitter& e, const layer &l) { + e << ((std::vector)l); + } + + friend void operator>>(const YAML::Node& n, layer &l) { + // These temporary variable shenanegins are necessary because + // the compiler gets very confused about which template operator>> + // function to use. + // This does not work: n >> l; + // So we use a temporary variable thusly: + std::vector *obviously_a_vector = &l; + n >> *obviously_a_vector; + } + #endif }; @@ -235,6 +286,11 @@ namespace mlp { net(istream &is) { load(is); } + #ifdef HAVE_LIBYAML_CPP + net (YAML::Node &node) { + node >> *((std::vector*)this); + } + #endif /** Virtual destructor */ virtual ~net() {}; @@ -305,10 +361,13 @@ namespace mlp { for(const_iterator l = begin(); l != end(); ++l) os << l->size() << " "; os << "\n"; - os << *this; - os << "\n"; + os << "< "; + for(const_iterator l = begin(); l != end(); ++l) + os << *l << " "; + os << ">\n"; } + unsigned num_inputs() const { return front().front().length() - 1; } unsigned num_outputs() const { return back().size(); } unsigned num_hidden_layers() const {