VNode.cpp 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. /*
  2. * VNode.cpp
  3. *
  4. * Created on: 8 Apr 2013
  5. * Author: s0965328
  6. */
  7. #include "VNode.h"
  8. #include "Tape.h"
  9. #include "Stack.h"
  10. using namespace std;
  11. namespace AutoDiff {
  12. #if FORWARD_ENABLED
  13. int VNode::DEFAULT_ID = -1;
  14. #endif
  15. VNode::VNode(double v) : ActNode(), val(v),u(NaN_Double)
  16. #if FORWARD_ENABLED
  17. ,id(DEFAULT_ID)
  18. #endif
  19. {
  20. }
  21. VNode::~VNode() {
  22. }
  23. void VNode::collect_vnodes(boost::unordered_set<Node*>& nodes,unsigned int& total)
  24. {
  25. total++;
  26. boost::unordered_set<Node*>::iterator it = nodes.find(this);
  27. if(it==nodes.end())
  28. nodes.insert(this);
  29. }
  30. void VNode::inorder_visit(int level,ostream& oss)
  31. {
  32. oss<<this->toString(level)<<endl;
  33. }
  34. void VNode::eval_function()
  35. {
  36. SV->push_back(val);
  37. }
  38. string VNode::toString(int level)
  39. {
  40. ostringstream oss;
  41. string s(level,'\t');
  42. oss<<s<<"[VNode](index:"<<index<<",val:"<<val<<",u:"<<u<<") - "<<this;
  43. return oss.str();
  44. }
  45. void VNode::grad_reverse_0()
  46. {
  47. this->adj = 0;
  48. SV->push_back(val);
  49. }
  50. void VNode::grad_reverse_1()
  51. {
  52. //do nothing
  53. //this is a leaf node
  54. }
  55. #if FORWARD_ENABLED
  56. void VNode::hess_forward(unsigned int len, double** ret_vec)
  57. {
  58. assert(id!=DEFAULT_ID);
  59. (*ret_vec) = new double[len];
  60. std::fill_n(*ret_vec,len,0);
  61. (*ret_vec)[id]=1;
  62. SV->push_back(this->val);
  63. }
  64. #endif
  65. unsigned int VNode::hess_reverse_0()
  66. {
  67. if(index==0)
  68. {//this node is not on tape
  69. double nan = NaN_Double;
  70. TT->set(val); //x_i
  71. TT->set(nan); //x_bar_i
  72. TT->set(u); //w_i
  73. TT->set(nan); //w_bar_i
  74. index = TT->index;
  75. }
  76. // cout<<toString(0)<<" -- "<<index<<endl;
  77. return index;
  78. }
  79. void VNode::hess_reverse_0_get_values(unsigned int i,double& x, double& x_bar, double& w, double& w_bar)
  80. {
  81. w_bar = TT->get(--i);
  82. w = TT->get(--i);
  83. x_bar = TT->get(--i);
  84. x = TT->get(--i);
  85. }
  86. void VNode::hess_reverse_1(unsigned int i)
  87. {
  88. n_in_arcs--;
  89. //leaf node do nothing
  90. }
  91. void VNode::hess_reverse_1_init_x_bar(unsigned int i)
  92. {
  93. TT->at(i-3) = 1;
  94. }
  95. void VNode::update_x_bar(unsigned int i ,double v)
  96. {
  97. // cout<<toString(0)<<" --- "<<__FUNCTION__<<" v="<<TT->at(i-3)<<"+"<<v<<endl;
  98. TT->at(i-3) = isnan(TT->get(i-3))? v: TT->get(i-3) + v;
  99. }
  100. void VNode::update_w_bar(unsigned int i,double v)
  101. {
  102. // cout<<toString(0)<<" --- "<<__FUNCTION__<<" v="<<TT->at(i-1)<<"+"<<v<<endl;
  103. TT->at(i-1) = isnan(TT->get(i-1))? v: TT->get(i-1) + v;
  104. }
  105. void VNode::hess_reverse_1_get_xw(unsigned int i, double& w,double& x)
  106. {
  107. //cout<<toString(0)<<" --- "<<__FUNCTION__<<" w="<<TT->get(i-2)<<"-- "<<"x="<<TT->get(i-4)<<endl;
  108. w = TT->get(i-2);
  109. x = TT->get(i-4);
  110. }
  111. void VNode::hess_reverse_get_x(unsigned int i ,double& x)
  112. {
  113. x = TT->get(i-4);
  114. }
  115. void VNode::nonlinearEdges(EdgeSet& edges)
  116. {
  117. // for(list<Edge>::iterator it = edges.edges.begin();it!=edges.edges.end();)
  118. // {
  119. // Edge e=*it;
  120. //
  121. // }
  122. }
  123. TYPE VNode::getType()
  124. {
  125. return VNode_Type;
  126. }
  127. }