// BinaryOPNode.cpp 14 KB
  1. /*
  2. * BinaryOPNode.cpp
  3. *
  4. * Created on: 6 Nov 2013
  5. * Author: s0965328
  6. */
  7. #include "auto_diff_types.h"
  8. #include "BinaryOPNode.h"
  9. #include "PNode.h"
  10. #include "Stack.h"
  11. #include "Tape.h"
  12. #include "EdgeSet.h"
  13. #include "Node.h"
  14. #include "VNode.h"
  15. #include "OPNode.h"
  16. #include "ActNode.h"
  17. #include "EdgeSet.h"
  18. namespace AutoDiff {
  19. BinaryOPNode::BinaryOPNode(OPCODE op_, Node* left_, Node* right_):OPNode(op_,left_),right(right_)
  20. {
  21. }
  22. OPNode* BinaryOPNode::createBinaryOpNode(OPCODE op, Node* left, Node* right)
  23. {
  24. assert(left!=NULL && right!=NULL);
  25. OPNode* node = NULL;
  26. node = new BinaryOPNode(op,left,right);
  27. return node;
  28. }
  29. BinaryOPNode::~BinaryOPNode() {
  30. if(right->getType()!=VNode_Type)
  31. {
  32. delete right;
  33. right = NULL;
  34. }
  35. }
  36. void BinaryOPNode::inorder_visit(int level,ostream& oss){
  37. if(left!=NULL){
  38. left->inorder_visit(level+1,oss);
  39. }
  40. oss<<this->toString(level)<<endl;
  41. if(right!=NULL){
  42. right->inorder_visit(level+1,oss);
  43. }
  44. }
  45. void BinaryOPNode::collect_vnodes(boost::unordered_set<Node*>& nodes,unsigned int& total){
  46. total++;
  47. if (left != NULL) {
  48. left->collect_vnodes(nodes,total);
  49. }
  50. if (right != NULL) {
  51. right->collect_vnodes(nodes,total);
  52. }
  53. }
  54. void BinaryOPNode::eval_function()
  55. {
  56. assert(left!=NULL && right!=NULL);
  57. left->eval_function();
  58. right->eval_function();
  59. this->calc_eval_function();
  60. }
  61. void BinaryOPNode::calc_eval_function()
  62. {
  63. double x = NaN_Double;
  64. double rx = SV->pop_back();
  65. double lx = SV->pop_back();
  66. switch (op)
  67. {
  68. case OP_PLUS:
  69. x = lx + rx;
  70. break;
  71. case OP_MINUS:
  72. x = lx - rx;
  73. break;
  74. case OP_TIMES:
  75. x = lx * rx;
  76. break;
  77. case OP_DIVID:
  78. x = lx / rx;
  79. break;
  80. case OP_POW:
  81. x = pow(lx,rx);
  82. break;
  83. default:
  84. cerr<<"op["<<op<<"] not yet implemented!!"<<endl;
  85. assert(false);
  86. break;
  87. }
  88. SV->push_back(x);
  89. }
  90. //1. visiting left if not NULL
  91. //2. then, visiting right if not NULL
  92. //3. calculating the immediate derivative hu and hv
  93. void BinaryOPNode::grad_reverse_0()
  94. {
  95. assert(left!=NULL && right != NULL);
  96. this->adj = 0;
  97. left->grad_reverse_0();
  98. right->grad_reverse_0();
  99. this->calc_grad_reverse_0();
  100. }
  101. //right left - right most traversal
  102. void BinaryOPNode::grad_reverse_1()
  103. {
  104. assert(right!=NULL && left!=NULL);
  105. double r_adj = SD->pop_back()*this->adj;
  106. right->update_adj(r_adj);
  107. double l_adj = SD->pop_back()*this->adj;
  108. left->update_adj(l_adj);
  109. right->grad_reverse_1();
  110. left->grad_reverse_1();
  111. }
  112. void BinaryOPNode::calc_grad_reverse_0()
  113. {
  114. assert(left!=NULL && right != NULL);
  115. double l_dh = NaN_Double;
  116. double r_dh = NaN_Double;
  117. double rx = SV->pop_back();
  118. double lx = SV->pop_back();
  119. double x = NaN_Double;
  120. switch (op)
  121. {
  122. case OP_PLUS:
  123. x = lx + rx;
  124. l_dh = 1;
  125. r_dh = 1;
  126. break;
  127. case OP_MINUS:
  128. x = lx - rx;
  129. l_dh = 1;
  130. r_dh = -1;
  131. break;
  132. case OP_TIMES:
  133. x = lx * rx;
  134. l_dh = rx;
  135. r_dh = lx;
  136. break;
  137. case OP_DIVID:
  138. x = lx / rx;
  139. l_dh = 1 / rx;
  140. r_dh = -(lx) / pow(rx, 2);
  141. break;
  142. case OP_POW:
  143. if(right->getType()==PNode_Type){
  144. x = pow(lx,rx);
  145. l_dh = rx*pow(lx,(rx-1));
  146. r_dh = 0;
  147. }
  148. else{
  149. assert(lx>0.0); //otherwise log(lx) is not defined in read number
  150. x = pow(lx,rx);
  151. l_dh = rx*pow(lx,(rx-1));
  152. r_dh = pow(lx,rx)*log(lx); //this is for x1^x2 when x1=0 cause r_dh become +inf, however d(0^x2)/d(x2) = 0
  153. }
  154. break;
  155. default:
  156. cerr<<"error op not impl"<<endl;
  157. break;
  158. }
  159. SV->push_back(x);
  160. SD->push_back(l_dh);
  161. SD->push_back(r_dh);
  162. }
  163. void BinaryOPNode::hess_reverse_0_init_n_in_arcs()
  164. {
  165. this->left->hess_reverse_0_init_n_in_arcs();
  166. this->right->hess_reverse_0_init_n_in_arcs();
  167. this->Node::hess_reverse_0_init_n_in_arcs();
  168. }
  169. void BinaryOPNode::hess_reverse_1_clear_index()
  170. {
  171. this->left->hess_reverse_1_clear_index();
  172. this->right->hess_reverse_1_clear_index();
  173. this->Node::hess_reverse_1_clear_index();
  174. }
  175. unsigned int BinaryOPNode::hess_reverse_0()
  176. {
  177. assert(this->left!=NULL && right!=NULL);
  178. if(index==0)
  179. {
  180. unsigned int lindex=0, rindex=0;
  181. lindex = left->hess_reverse_0();
  182. rindex = right->hess_reverse_0();
  183. assert(lindex!=0 && rindex !=0);
  184. II->set(lindex);
  185. II->set(rindex);
  186. double rx,rx_bar,rw,rw_bar;
  187. double lx,lx_bar,lw,lw_bar;
  188. double x,x_bar,w,w_bar;
  189. double r_dh, l_dh;
  190. right->hess_reverse_0_get_values(rindex,rx,rx_bar,rw,rw_bar);
  191. left->hess_reverse_0_get_values(lindex,lx,lx_bar,lw,lw_bar);
  192. switch(op)
  193. {
  194. case OP_PLUS:
  195. // cout<<"lindex="<<lindex<<"\trindex="<<rindex<<"\tI="<<I<<endl;
  196. x = lx + rx;
  197. // cout<<lx<<"\t+"<<rx<<"\t="<<x<<"\t\t"<<toString(0)<<endl;
  198. x_bar = 0;
  199. l_dh = 1;
  200. r_dh = 1;
  201. w = lw * l_dh + rw * r_dh;
  202. // cout<<lw<<"\t+"<<rw<<"\t="<<w<<"\t\t"<<toString(0)<<endl;
  203. w_bar = 0;
  204. break;
  205. case OP_MINUS:
  206. x = lx - rx;
  207. x_bar = 0;
  208. l_dh = 1;
  209. r_dh = -1;
  210. w = lw * l_dh + rw * r_dh;
  211. w_bar = 0;
  212. break;
  213. case OP_TIMES:
  214. x = lx * rx;
  215. x_bar = 0;
  216. l_dh = rx;
  217. r_dh = lx;
  218. w = lw * l_dh + rw * r_dh;
  219. w_bar = 0;
  220. break;
  221. case OP_DIVID:
  222. x = lx / rx;
  223. x_bar = 0;
  224. l_dh = 1/rx;
  225. r_dh = -lx/pow(rx,2);
  226. w = lw * l_dh + rw * r_dh;
  227. w_bar = 0;
  228. break;
  229. case OP_POW:
  230. if(right->getType()==PNode_Type)
  231. {
  232. x = pow(lx,rx);
  233. x_bar = 0;
  234. l_dh = rx*pow(lx,(rx-1));
  235. r_dh = 0;
  236. w = lw * l_dh + rw * r_dh;
  237. w_bar = 0;
  238. }
  239. else
  240. {
  241. assert(lx>0.0); //otherwise log(lx) undefined in real number
  242. x = pow(lx,rx);
  243. x_bar = 0;
  244. l_dh = rx*pow(lx,(rx-1));
  245. r_dh = pow(lx,rx)*log(lx); //log(lx) cause -inf when lx=0;
  246. w = lw * l_dh + rw * r_dh;
  247. w_bar = 0;
  248. }
  249. break;
  250. default:
  251. cerr<<"op["<<op<<"] not yet implemented!"<<endl;
  252. assert(false);
  253. break;
  254. }
  255. TT->set(x);
  256. TT->set(x_bar);
  257. TT->set(w);
  258. TT->set(w_bar);
  259. TT->set(l_dh);
  260. TT->set(r_dh);
  261. assert(TT->index == TT->index);
  262. index = TT->index;
  263. }
  264. return index;
  265. }
  266. void BinaryOPNode::hess_reverse_0_get_values(unsigned int i,double& x, double& x_bar, double& w, double& w_bar)
  267. {
  268. --i; // skip the r_dh (ie, dh/du)
  269. --i; // skip the l_dh (ie. dh/dv)
  270. w_bar = TT->get(--i);
  271. w = TT->get(--i);
  272. x_bar = TT->get(--i);
  273. x = TT->get(--i);
  274. }
  275. void BinaryOPNode::hess_reverse_1(unsigned int i)
  276. {
  277. n_in_arcs--;
  278. if(n_in_arcs==0)
  279. {
  280. assert(right!=NULL && left!=NULL);
  281. unsigned int rindex = II->get(--(II->index));
  282. unsigned int lindex = II->get(--(II->index));
  283. // cout<<"ri["<<rindex<<"]\tli["<<lindex<<"]\t"<<this->toString(0)<<endl;
  284. double r_dh = TT->get(--i);
  285. double l_dh = TT->get(--i);
  286. double w_bar = TT->get(--i);
  287. --i; //skip w
  288. double x_bar = TT->get(--i);
  289. --i; //skip x
  290. double lw_bar=0,rw_bar=0;
  291. double lw=0,lx=0; left->hess_reverse_1_get_xw(lindex,lw,lx);
  292. double rw=0,rx=0; right->hess_reverse_1_get_xw(rindex,rw,rx);
  293. switch(op)
  294. {
  295. case OP_PLUS:
  296. assert(l_dh==1);
  297. assert(r_dh==1);
  298. lw_bar += w_bar*l_dh;
  299. rw_bar += w_bar*r_dh;
  300. break;
  301. case OP_MINUS:
  302. assert(l_dh==1);
  303. assert(r_dh==-1);
  304. lw_bar += w_bar*l_dh;
  305. rw_bar += w_bar*r_dh;
  306. break;
  307. case OP_TIMES:
  308. assert(rx == l_dh);
  309. assert(lx == r_dh);
  310. lw_bar += w_bar*rx;
  311. lw_bar += x_bar*lw*0 + x_bar*rw*1;
  312. rw_bar += w_bar*lx;
  313. rw_bar += x_bar*lw*1 + x_bar*rw*0;
  314. break;
  315. case OP_DIVID:
  316. lw_bar += w_bar*l_dh;
  317. lw_bar += x_bar*lw*0 + x_bar*rw*-1/(pow(rx,2));
  318. rw_bar += w_bar*r_dh;
  319. rw_bar += x_bar*lw*-1/pow(rx,2) + x_bar*rw*2*lx/pow(rx,3);
  320. break;
  321. case OP_POW:
  322. if(right->getType()==PNode_Type){
  323. lw_bar += w_bar*l_dh;
  324. lw_bar += x_bar*lw*pow(lx,rx-2)*rx*(rx-1) + 0;
  325. rw_bar += w_bar*r_dh; assert(r_dh==0.0);
  326. rw_bar += 0;
  327. }
  328. else{
  329. assert(lx>0.0); //otherwise log(lx) is not define in Real
  330. lw_bar += w_bar*l_dh;
  331. lw_bar += x_bar*lw*pow(lx,rx-2)*rx*(rx-1) + x_bar*rw*pow(lx,rx-1)*(rx*log(lx)+1); //cause log(lx)=-inf when
  332. rw_bar += w_bar*r_dh;
  333. rw_bar += x_bar*lw*pow(lx,rx-1)*(rx*log(lx)+1) + x_bar*rw*pow(lx,rx)*pow(log(lx),2);
  334. }
  335. break;
  336. default:
  337. cerr<<"op["<<op<<"] not yet implemented !"<<endl;
  338. assert(false);
  339. break;
  340. }
  341. double rx_bar = x_bar*r_dh;
  342. double lx_bar = x_bar*l_dh;
  343. right->update_x_bar(rindex,rx_bar);
  344. left->update_x_bar(lindex,lx_bar);
  345. right->update_w_bar(rindex,rw_bar);
  346. left->update_w_bar(lindex,lw_bar);
  347. this->right->hess_reverse_1(rindex);
  348. this->left->hess_reverse_1(lindex);
  349. }
  350. }
  351. void BinaryOPNode::hess_reverse_1_init_x_bar(unsigned int i)
  352. {
  353. TT->at(i-5) = 1;
  354. }
  355. void BinaryOPNode::update_x_bar(unsigned int i ,double v)
  356. {
  357. TT->at(i-5) += v;
  358. }
  359. void BinaryOPNode::update_w_bar(unsigned int i ,double v)
  360. {
  361. TT->at(i-3) += v;
  362. }
  363. void BinaryOPNode::hess_reverse_1_get_xw(unsigned int i,double& w,double& x)
  364. {
  365. w = TT->get(i-4);
  366. x = TT->get(i-6);
  367. }
  368. void BinaryOPNode::hess_reverse_get_x(unsigned int i,double& x)
  369. {
  370. x = TT->get(i-6);
  371. }
  372. void BinaryOPNode::nonlinearEdges(EdgeSet& edges)
  373. {
  374. for(list<Edge>::iterator it=edges.edges.begin();it!=edges.edges.end();)
  375. {
  376. Edge e = *it;
  377. if(e.a==this || e.b == this){
  378. if(e.a == this && e.b == this)
  379. {
  380. Edge e1(left,left);
  381. Edge e2(right,right);
  382. Edge e3(left,right);
  383. edges.insertEdge(e1);
  384. edges.insertEdge(e2);
  385. edges.insertEdge(e3);
  386. }
  387. else
  388. {
  389. Node* o = e.a==this? e.b: e.a;
  390. Edge e1(left,o);
  391. Edge e2(right,o);
  392. edges.insertEdge(e1);
  393. edges.insertEdge(e2);
  394. }
  395. it = edges.edges.erase(it);
  396. }
  397. else
  398. {
  399. it++;
  400. }
  401. }
  402. Edge e1(left,right);
  403. Edge e2(left,left);
  404. Edge e3(right,right);
  405. switch(op)
  406. {
  407. case OP_PLUS:
  408. case OP_MINUS:
  409. //do nothing for linear operator
  410. break;
  411. case OP_TIMES:
  412. edges.insertEdge(e1);
  413. break;
  414. case OP_DIVID:
  415. edges.insertEdge(e1);
  416. edges.insertEdge(e3);
  417. break;
  418. case OP_POW:
  419. edges.insertEdge(e1);
  420. edges.insertEdge(e2);
  421. edges.insertEdge(e3);
  422. break;
  423. default:
  424. cerr<<"op["<<op<<"] not yet implmented !"<<endl;
  425. assert(false);
  426. break;
  427. }
  428. left->nonlinearEdges(edges);
  429. right->nonlinearEdges(edges);
  430. }
  431. #if FORWARD_ENABLED
  432. void BinaryOPNode::hess_forward(unsigned int len, double** ret_vec)
  433. {
  434. double* lvec = NULL;
  435. double* rvec = NULL;
  436. if(left!=NULL){
  437. left->hess_forward(len,&lvec);
  438. }
  439. if(right!=NULL){
  440. right->hess_forward(len,&rvec);
  441. }
  442. *ret_vec = new double[len];
  443. hess_forward_calc0(len,lvec,rvec,*ret_vec);
  444. //delete lvec, rvec
  445. delete[] lvec;
  446. delete[] rvec;
  447. }
  448. void BinaryOPNode::hess_forward_calc0(unsigned int& len, double* lvec, double* rvec, double* ret_vec)
  449. {
  450. double hu = NaN_Double, hv= NaN_Double;
  451. double lval = NaN_Double, rval = NaN_Double;
  452. double val = NaN_Double;
  453. unsigned int index = 0;
  454. switch (op)
  455. {
  456. case OP_PLUS:
  457. rval = SV->pop_back();
  458. lval = SV->pop_back();
  459. val = lval + rval;
  460. SV->push_back(val);
  461. //calculate the first order derivatives
  462. for(unsigned int i=0;i<AutoDiff::num_var;++i)
  463. {
  464. ret_vec[i] = lvec[i]+rvec[i];
  465. }
  466. //calculate the second order
  467. index = AutoDiff::num_var;
  468. for(unsigned int i=0;i<AutoDiff::num_var;++i)
  469. {
  470. for(unsigned int j=i;j<AutoDiff::num_var;++j){
  471. ret_vec[index] = lvec[index] + 0 + rvec[index] + 0;
  472. ++index;
  473. }
  474. }
  475. assert(index==len);
  476. break;
  477. case OP_MINUS:
  478. rval = SV->pop_back();
  479. lval = SV->pop_back();
  480. val = lval + rval;
  481. SV->push_back(val);
  482. //calculate the first order derivatives
  483. for(unsigned int i=0;i<AutoDiff::num_var;++i)
  484. {
  485. ret_vec[i] = lvec[i] - rvec[i];
  486. }
  487. //calculate the second order
  488. index = AutoDiff::num_var;
  489. for(unsigned int i=0;i<AutoDiff::num_var;++i)
  490. {
  491. for(unsigned int j=i;j<AutoDiff::num_var;++j){
  492. ret_vec[index] = lvec[index] + 0 - rvec[index] + 0;
  493. ++index;
  494. }
  495. }
  496. assert(index==len);
  497. break;
  498. case OP_TIMES:
  499. rval = SV->pop_back();
  500. lval = SV->pop_back();
  501. val = lval * rval;
  502. SV->push_back(val);
  503. hu = rval;
  504. hv = lval;
  505. //calculate the first order derivatives
  506. for(unsigned int i =0;i<AutoDiff::num_var;++i)
  507. {
  508. ret_vec[i] = hu*lvec[i] + hv*rvec[i];
  509. }
  510. //calculate the second order
  511. index = AutoDiff::num_var;
  512. for(unsigned int i=0;i<AutoDiff::num_var;++i)
  513. {
  514. for(unsigned int j=i;j<AutoDiff::num_var;++j)
  515. {
  516. ret_vec[index] = hu * lvec[index] + lvec[i] * rvec[j]+hv * rvec[index] + rvec[i] * lvec[j];
  517. ++index;
  518. }
  519. }
  520. assert(index==len);
  521. break;
  522. case OP_POW:
  523. rval = SV->pop_back();
  524. lval = SV->pop_back();
  525. val = pow(lval,rval);
  526. SV->push_back(val);
  527. if(left->getType()==PNode_Type && right->getType()==PNode_Type)
  528. {
  529. std::fill_n(ret_vec,len,0);
  530. }
  531. else
  532. {
  533. hu = rval*pow(lval,(rval-1));
  534. hv = pow(lval,rval)*log(lval);
  535. if(left->getType()==PNode_Type)
  536. {
  537. double coeff = pow(log(lval),2)*pow(lval,rval);
  538. //calculate the first order derivatives
  539. for(unsigned int i =0;i<AutoDiff::num_var;++i)
  540. {
  541. ret_vec[i] = hu*lvec[i] + hv*rvec[i];
  542. }
  543. //calculate the second order
  544. index = AutoDiff::num_var;
  545. for(unsigned int i=0;i<AutoDiff::num_var;++i)
  546. {
  547. for(unsigned int j=i;j<AutoDiff::num_var;++j)
  548. {
  549. ret_vec[index] = 0 + 0 + hv * rvec[index] + rvec[i] * coeff * rvec[j];
  550. ++index;
  551. }
  552. }
  553. }
  554. else if(right->getType()==PNode_Type)
  555. {
  556. double coeff = rval*(rval-1)*pow(lval,rval-2);
  557. //calculate the first order derivatives
  558. for(unsigned int i =0;i<AutoDiff::num_var;++i)
  559. {
  560. ret_vec[i] = hu*lvec[i] + hv*rvec[i];
  561. }
  562. //calculate the second order
  563. index = AutoDiff::num_var;
  564. for(unsigned int i=0;i<AutoDiff::num_var;++i)
  565. {
  566. for(unsigned int j=i;j<AutoDiff::num_var;++j)
  567. {
  568. ret_vec[index] = hu*lvec[index] + lvec[i] * coeff * lvec[j] + 0 + 0;
  569. ++index;
  570. }
  571. }
  572. }
  573. else
  574. {
  575. assert(false);
  576. }
  577. }
  578. assert(index==len);
  579. break;
  580. case OP_SIN: //TODO should move to UnaryOPNode.cpp?
  581. assert(left!=NULL&&right==NULL);
  582. lval = SV->pop_back();
  583. val = sin(lval);
  584. SV->push_back(val);
  585. hu = cos(lval);
  586. double coeff;
  587. coeff = -val; //=sin(left->val); -- and avoid cross initialisation
  588. //calculate the first order derivatives
  589. for(unsigned int i =0;i<AutoDiff::num_var;++i)
  590. {
  591. ret_vec[i] = hu*lvec[i] + 0;
  592. }
  593. //calculate the second order
  594. index = AutoDiff::num_var;
  595. for(unsigned int i=0;i<AutoDiff::num_var;++i)
  596. {
  597. for(unsigned int j=i;j<AutoDiff::num_var;++j)
  598. {
  599. ret_vec[index] = hu*lvec[index] + lvec[i] * coeff * lvec[j] + 0 + 0;
  600. ++index;
  601. }
  602. }
  603. assert(index==len);
  604. break;
  605. default:
  606. cerr<<"op["<<op<<"] not yet implemented!";
  607. break;
  608. }
  609. }
  610. #endif
  611. string BinaryOPNode::toString(int level){
  612. ostringstream oss;
  613. string s(level,'\t');
  614. oss<<s<<"[BinaryOPNode]("<<op<<")";
  615. return oss.str();
  616. }
  617. } /* namespace AutoDiff */