# nonblocking_test.py
# (C) Copyright 2007
# Andreas Kloeckner <inform -at- tiker.net>
#
# Use, modification and distribution is subject to the Boost Software
# License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
# http://www.boost.org/LICENSE_1_0.txt)
#
# Authors: Andreas Kloeckner
  9. import boost.mpi as mpi
  10. import random
  11. import sys
  12. MAX_GENERATIONS = 20
  13. TAG_DEBUG = 0
  14. TAG_DATA = 1
  15. TAG_TERMINATE = 2
  16. TAG_PROGRESS_REPORT = 3
  17. class TagGroupListener:
  18. """Class to help listen for only a given set of tags.
  19. This is contrived: Typicallly you could just listen for
  20. mpi.any_tag and filter."""
  21. def __init__(self, comm, tags):
  22. self.tags = tags
  23. self.comm = comm
  24. self.active_requests = {}
  25. def wait(self):
  26. for tag in self.tags:
  27. if tag not in self.active_requests:
  28. self.active_requests[tag] = self.comm.irecv(tag=tag)
  29. requests = mpi.RequestList(self.active_requests.values())
  30. data, status, index = mpi.wait_any(requests)
  31. del self.active_requests[status.tag]
  32. return status, data
  33. def cancel(self):
  34. for r in self.active_requests.itervalues():
  35. r.cancel()
  36. #r.wait()
  37. self.active_requests = {}
  38. def rank0():
  39. sent_histories = (mpi.size-1)*15
  40. print "sending %d packets on their way" % sent_histories
  41. send_reqs = mpi.RequestList()
  42. for i in range(sent_histories):
  43. dest = random.randrange(1, mpi.size)
  44. send_reqs.append(mpi.world.isend(dest, TAG_DATA, []))
  45. mpi.wait_all(send_reqs)
  46. completed_histories = []
  47. progress_reports = {}
  48. dead_kids = []
  49. tgl = TagGroupListener(mpi.world,
  50. [TAG_DATA, TAG_DEBUG, TAG_PROGRESS_REPORT, TAG_TERMINATE])
  51. def is_complete():
  52. for i in progress_reports.values():
  53. if i != sent_histories:
  54. return False
  55. return len(dead_kids) == mpi.size-1
  56. while True:
  57. status, data = tgl.wait()
  58. if status.tag == TAG_DATA:
  59. #print "received completed history %s from %d" % (data, status.source)
  60. completed_histories.append(data)
  61. if len(completed_histories) == sent_histories:
  62. print "all histories received, exiting"
  63. for rank in range(1, mpi.size):
  64. mpi.world.send(rank, TAG_TERMINATE, None)
  65. elif status.tag == TAG_PROGRESS_REPORT:
  66. progress_reports[len(data)] = progress_reports.get(len(data), 0) + 1
  67. elif status.tag == TAG_DEBUG:
  68. print "[DBG %d] %s" % (status.source, data)
  69. elif status.tag == TAG_TERMINATE:
  70. dead_kids.append(status.source)
  71. else:
  72. print "unexpected tag %d from %d" % (status.tag, status.source)
  73. if is_complete():
  74. break
  75. print "OK"
  76. def comm_rank():
  77. while True:
  78. data, status = mpi.world.recv(return_status=True)
  79. if status.tag == TAG_DATA:
  80. mpi.world.send(0, TAG_PROGRESS_REPORT, data)
  81. data.append(mpi.rank)
  82. if len(data) >= MAX_GENERATIONS:
  83. dest = 0
  84. else:
  85. dest = random.randrange(1, mpi.size)
  86. mpi.world.send(dest, TAG_DATA, data)
  87. elif status.tag == TAG_TERMINATE:
  88. from time import sleep
  89. mpi.world.send(0, TAG_TERMINATE, 0)
  90. break
  91. else:
  92. print "[DIRECTDBG %d] unexpected tag %d from %d" % (mpi.rank, status.tag, status.source)
  93. def main():
  94. # this program sends around messages consisting of lists of visited nodes
  95. # randomly. After MAX_GENERATIONS, they are returned to rank 0.
  96. if mpi.rank == 0:
  97. rank0()
  98. else:
  99. comm_rank()
  100. if __name__ == "__main__":
  101. main()