A Discrete-Event Network Simulator
API
null-message-mpi-interface.cc
Go to the documentation of this file.
1 /* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */
2 /*
3  * Copyright 2013. Lawrence Livermore National Security, LLC.
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License version 2 as
7  * published by the Free Software Foundation;
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program; if not, write to the Free Software
16  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17  *
18  * Author: Steven Smith <smith84@llnl.gov>
19  *
20  */
21 
29 
32 #include "remote-channel-bundle.h"
33 
34 #include "ns3/mpi-receiver.h"
35 #include "ns3/node.h"
36 #include "ns3/node-list.h"
37 #include "ns3/net-device.h"
38 #include "ns3/nstime.h"
39 #include "ns3/simulator.h"
40 #include "ns3/log.h"
41 
42 #include <mpi.h>
43 
44 #include <iostream>
45 #include <iomanip>
46 #include <list>
47 
48 namespace ns3 {
49 
50 NS_LOG_COMPONENT_DEFINE ("NullMessageMpiInterface");
51 
52 NS_OBJECT_ENSURE_REGISTERED (NullMessageMpiInterface);
53 
// Non-blocking send buffer record for the Null Message implementation:
// pairs a byte buffer with the MPI request posted for the send.
// NOTE(review): the class header line and the constructor/destructor
// declarations are missing from this listing.
62 {
63 public:
66 
// Returns the pointer currently stored in m_buffer.
70  uint8_t* GetBuffer ();
// Stores `buffer` in m_buffer. The destructor releases it with delete[],
// so the pointer must come from new[].
74  void SetBuffer (uint8_t* buffer);
// Returns the address of the embedded MPI request handle.
78  MPI_Request* GetRequest ();
79 
80 private:
81 
// Buffer for the send; released with delete[] in the destructor.
85  uint8_t* m_buffer;
86 
// MPI request posted for the send.
90  MPI_Request m_request;
91 };
92 
// Maximum MPI message size in bytes. Every receive buffer in this file is
// allocated with this size and every MPI_Irecv is posted with this count.
97 const uint32_t NULL_MESSAGE_MAX_MPI_MSG_SIZE = 2000;
98 
// Presumably the NullMessageSentBuffer constructor (signature line missing
// from this listing): starts with no buffer attached.
// NOTE(review): m_request is set to 0 rather than MPI_REQUEST_NULL; the two
// are not guaranteed to be the same value in every MPI implementation.
100 {
101  m_buffer = 0;
102  m_request = 0;
103 }
104 
// Presumably the NullMessageSentBuffer destructor (signature line missing
// from this listing): releases the attached buffer. delete[] on a null
// pointer is a no-op, so an object that never had a buffer set is safe.
106 {
107  delete [] m_buffer;
108 }
109 
// Presumably NullMessageSentBuffer::GetBuffer (signature line missing from
// this listing): returns the stored buffer pointer without transferring
// ownership.
110 uint8_t*
112 {
113  return m_buffer;
114 }
115 
// Presumably NullMessageSentBuffer::SetBuffer (signature line missing from
// this listing): stores `buffer` in m_buffer.
// NOTE(review): any previously stored buffer is overwritten without being
// freed.
116 void
118 {
119  m_buffer = buffer;
120 }
121 
// Presumably NullMessageSentBuffer::GetRequest (signature line missing from
// this listing): returns the address of the embedded request handle so MPI
// calls can write to it in place.
122 MPI_Request*
124 {
125  return &m_request;
126 }
127 
// System ID (rank) for this task; overwritten by Enable().
128 uint32_t NullMessageMpiInterface::g_sid = 0;
133 
// List of pending non-blocking sends; entries are removed once MPI_Test
// reports them complete.
134 std::list<NullMessageSentBuffer> NullMessageMpiInterface::g_pendingTx;
135 
// MPI communicator used for ns-3 traffic; replaced by a duplicate of the
// caller's communicator in Enable().
136 MPI_Comm NullMessageMpiInterface::g_communicator = MPI_COMM_WORLD;
140 
// Presumably NullMessageMpiInterface::GetTypeId (signature line missing
// from this listing): returns the TypeId named "ns3::NullMessageMpiInterface",
// with Object as parent and "Mpi" as group name. The TypeId is a
// function-local static, so it is built once on the first call.
141 TypeId
143 {
144  static TypeId tid = TypeId ("ns3::NullMessageMpiInterface")
145  .SetParent<Object> ()
146  .SetGroupName ("Mpi")
147  ;
148  return tid;
149 }
150 
// Presumably the NullMessageMpiInterface constructor (signature line missing
// from this listing): only logs the call.
152 {
153  NS_LOG_FUNCTION (this);
154 }
155 
// Presumably the NullMessageMpiInterface destructor (signature line missing
// from this listing): only logs the call and releases no MPI resources.
157 {
158  NS_LOG_FUNCTION (this);
159 }
160 
// Presumably NullMessageMpiInterface::Destroy (signature line missing from
// this listing): only logs the call; nothing is released here.
161 void
163 {
164  NS_LOG_FUNCTION (this);
165 }
166 
// Presumably NullMessageMpiInterface::GetSystemId (signature line missing
// from this listing): returns the rank stored in g_sid.
167 uint32_t
169 {
171  return g_sid;
172 }
173 
// Presumably NullMessageMpiInterface::GetSize (signature line missing from
// this listing): returns the number of ranks stored in g_size.
174 uint32_t
176 {
178  return g_size;
179 }
180 
// Presumably NullMessageMpiInterface::GetCommunicator (signature line
// missing from this listing): returns the communicator handle stored in
// g_communicator.
181 MPI_Comm
183 {
185  return g_communicator;
186 }
187 
// Presumably NullMessageMpiInterface::IsEnabled (signature line missing from
// this listing): returns the g_enabled flag.
188 bool
190 {
191  return g_enabled;
192 }
193 
// Initialize MPI and enable this interface on MPI_COMM_WORLD.
//
// pargc, pargv: forwarded unchanged to MPI_Init.
// Asserts that the interface is not already enabled.
// NOTE(review): the log statement dereferences pargc, so it is assumed to be
// non-null here.
194 void
195 NullMessageMpiInterface::Enable (int* pargc, char*** pargv)
196 {
197  NS_LOG_FUNCTION (this << *pargc);
198 
199  NS_ASSERT (g_enabled == false);
200 
201  // Initialize the MPI interface
202  MPI_Init (pargc, pargv);
203  Enable (MPI_COMM_WORLD);
// Record that this interface called MPI_Init itself; the cleanup code later
// in this file only calls MPI_Finalize when this flag is set.
204  g_mpiInitCalled = true;
205 }
206 
207 void
208 NullMessageMpiInterface::Enable (MPI_Comm communicator)
209 {
210  NS_LOG_FUNCTION (this);
211 
212  NS_ASSERT (g_enabled == false);
213 
214  // Standard MPI practice is to duplicate the communicator for
215  // library to use. Library communicates in isolated communication
216  // context.
217  MPI_Comm_dup (communicator, &g_communicator);
218  g_freeCommunicator = true;
219 
220  // SystemId and Size are unit32_t in interface but MPI uses int so convert.
221  int mpiSystemId;
222  int mpiSize;
223  MPI_Comm_rank (g_communicator, &mpiSystemId);
224  MPI_Comm_size (g_communicator, &mpiSize);
225 
226  g_sid = mpiSystemId;
227  g_size = mpiSize;
228 
229  g_enabled = true;
230 
231  MPI_Barrier(g_communicator);
232 }
233 
// Presumably NullMessageMpiInterface::InitializeSendReceiveBuffers
// (signature line missing from this listing): allocates one receive buffer
// and one request slot per neighbor and posts a non-blocking receive on each.
// NOTE(review): the listing also omits the lines that set g_numNeighbors and
// declare `bundle`; presumably `bundle` is the channel bundle looked up for
// `rank` — confirm against the full source.
234 void
236 {
239 
241 
242  // Post a non-blocking receive for all peers
243  g_requests = new MPI_Request[g_numNeighbors];
244  g_pRxBuffers = new char*[g_numNeighbors];
// `index` counts slots actually used. It is only advanced for ranks that have
// a bundle, so it stays in bounds only if at most g_numNeighbors ranks do.
245  int index = 0;
246  for (uint32_t rank = 0; rank < g_size; ++rank)
247  {
249  if (bundle)
250  {
// Each buffer is fixed-size; a larger incoming message will not fit.
251  g_pRxBuffers[index] = new char[NULL_MESSAGE_MAX_MPI_MSG_SIZE];
252  MPI_Irecv (g_pRxBuffers[index], NULL_MESSAGE_MAX_MPI_MSG_SIZE, MPI_CHAR, rank, 0,
253  g_communicator, &g_requests[index]);
254  ++index;
255  }
256  }
257 }
258 
259 void
260 NullMessageMpiInterface::SendPacket (Ptr<Packet> p, const Time& rxTime, uint32_t node, uint32_t dev)
261 {
262  NS_LOG_FUNCTION (this << p << rxTime.GetTimeStep () << node << dev);
263 
265 
266  // Find the system id for the destination node
267  Ptr<Node> destNode = NodeList::GetNode (node);
268  uint32_t nodeSysId = destNode->GetSystemId ();
269 
270  NullMessageSentBuffer sendBuf;
271  g_pendingTx.push_back (sendBuf);
272  std::list<NullMessageSentBuffer>::reverse_iterator iter = g_pendingTx.rbegin (); // Points to the last element
273 
274  uint32_t serializedSize = p->GetSerializedSize ();
275  uint32_t bufferSize = serializedSize + ( 2 * sizeof (uint64_t) ) + ( 2 * sizeof (uint32_t) );
276  uint8_t* buffer = new uint8_t[bufferSize];
277  iter->SetBuffer (buffer);
278  // Add the time, dest node and dest device
279  uint64_t t = rxTime.GetInteger ();
280  uint64_t* pTime = reinterpret_cast <uint64_t *> (buffer);
281  *pTime++ = t;
282 
283  Time guarantee_update = NullMessageSimulatorImpl::GetInstance ()->CalculateGuaranteeTime (nodeSysId);
284  *pTime++ = guarantee_update.GetTimeStep ();
285 
286  uint32_t* pData = reinterpret_cast<uint32_t *> (pTime);
287  *pData++ = node;
288  *pData++ = dev;
289  // Serialize the packet
290  p->Serialize (reinterpret_cast<uint8_t *> (pData), serializedSize);
291 
292  MPI_Isend (reinterpret_cast<void *> (iter->GetBuffer ()), bufferSize, MPI_CHAR, nodeSysId,
293  0, g_communicator, (iter->GetRequest ()));
294 
296 }
297 
// Presumably NullMessageMpiInterface::SendNullMessage (signature line missing
// from this listing): sends a header-only message that carries just a
// guarantee time to the rank owning `bundle`, using a non-blocking send.
// The body reads `guarantee_update` and `bundle`, which must be the
// parameters.
298 void
300 {
301  NS_LOG_FUNCTION (guarantee_update.GetTimeStep () << bundle);
302 
304 
// Append an empty entry, then attach the buffer to the element stored in the
// list so that element owns the buffer.
305  NullMessageSentBuffer sendBuf;
306  g_pendingTx.push_back (sendBuf);
307  std::list<NullMessageSentBuffer>::reverse_iterator iter = g_pendingTx.rbegin (); // Points to the last element
308 
// Header only: two uint64_t followed by two uint32_t, no packet payload.
309  uint32_t bufferSize = 2 * sizeof (uint64_t) + 2 * sizeof (uint32_t);
310  uint8_t* buffer = new uint8_t[bufferSize];
311  iter->SetBuffer (buffer);
312  // Receive time 0 marks a Null Message; node and device fields are zeroed
313  uint64_t* pTime = reinterpret_cast <uint64_t *> (buffer);
314  *pTime++ = 0;
315  *pTime++ = guarantee_update.GetInteger ();
316  uint32_t* pData = reinterpret_cast<uint32_t *> (pTime);
317  *pData++ = 0;
318  *pData++ = 0;
319 
320  // Find the system id for the destination MPI rank
321  uint32_t nodeSysId = bundle->GetSystemId ();
322 
323  MPI_Isend (reinterpret_cast<void *> (iter->GetBuffer ()), bufferSize, MPI_CHAR, nodeSysId,
324  0, g_communicator, (iter->GetRequest ()));
325 }
326 
// Presumably NullMessageMpiInterface::ReceiveMessagesBlocking (signature line
// missing from this listing): calls ReceiveMessages in blocking mode.
327 void
329 {
331 
332  ReceiveMessages(true);
333 }
334 
335 
// Presumably NullMessageMpiInterface::ReceiveMessagesNonBlocking (signature
// line missing from this listing): calls ReceiveMessages in non-blocking mode.
336 void
338 {
340 
341  ReceiveMessages(false);
342 }
343 
344 void
346 {
347  NS_LOG_FUNCTION (blocking);
348 
350 
351  // stop flag set to true when no more messages are found to
352  // process.
353  bool stop = false;
354 
355 
356  if (!g_numNeighbors) {
357  // Not communicating with anyone.
358  return;
359  }
360 
361  do
362  {
363  int messageReceived = 0;
364  int index = 0;
365  MPI_Status status;
366 
367  if (blocking)
368  {
369  MPI_Waitany (g_numNeighbors, g_requests, &index, &status);
370  messageReceived = 1; /* Wait always implies message was received */
371  stop = true;
372  }
373  else
374  {
375  MPI_Testany (g_numNeighbors, g_requests, &index, &messageReceived, &status);
376  }
377 
378  if (messageReceived)
379  {
380  int count;
381  MPI_Get_count (&status, MPI_CHAR, &count);
382 
383  // Get the meta data first
384  uint64_t* pTime = reinterpret_cast<uint64_t *> (g_pRxBuffers[index]);
385  uint64_t time = *pTime++;
386  uint64_t guaranteeUpdate = *pTime++;
387 
388  uint32_t* pData = reinterpret_cast<uint32_t *> (pTime);
389  uint32_t node = *pData++;
390  uint32_t dev = *pData++;
391 
392  Time rxTime (time);
393 
394  // rxtime == 0 means this is a Null Message
395  if (rxTime > Time (0))
396  {
397  count -= sizeof (time) + sizeof (guaranteeUpdate) + sizeof (node) + sizeof (dev);
398 
399  Ptr<Packet> p = Create<Packet> (reinterpret_cast<uint8_t *> (pData), count, true);
400 
401  // Find the correct node/device to schedule receive event
402  Ptr<Node> pNode = NodeList::GetNode (node);
403  Ptr<MpiReceiver> pMpiRec = 0;
404  uint32_t nDevices = pNode->GetNDevices ();
405  for (uint32_t i = 0; i < nDevices; ++i)
406  {
407  Ptr<NetDevice> pThisDev = pNode->GetDevice (i);
408  if (pThisDev->GetIfIndex () == dev)
409  {
410  pMpiRec = pThisDev->GetObject<MpiReceiver> ();
411  break;
412  }
413  }
414  NS_ASSERT (pNode && pMpiRec);
415 
416  // Schedule the rx event
417  Simulator::ScheduleWithContext (pNode->GetId (), rxTime - Simulator::Now (),
418  &MpiReceiver::Receive, pMpiRec, p);
419 
420  }
421 
422  // Update guarantee time for both packet receives and Null Messages.
423  Ptr<RemoteChannelBundle> bundle = RemoteChannelBundleManager::Find (status.MPI_SOURCE);
424  NS_ASSERT (bundle);
425 
426  bundle->SetGuaranteeTime (Time (guaranteeUpdate));
427 
428  // Re-queue the next read
429  MPI_Irecv (g_pRxBuffers[index], NULL_MESSAGE_MAX_MPI_MSG_SIZE, MPI_CHAR, status.MPI_SOURCE, 0,
430  g_communicator, &g_requests[index]);
431 
432  }
433  else
434  {
435  // if non-blocking and no message received in testany then stop message loop
436  stop = true;
437  }
438  }
439  while (!stop);
440 }
441 
// Presumably NullMessageMpiInterface::TestSendComplete (signature line
// missing from this listing): polls every pending non-blocking send and
// removes the entries that have completed. Erasing an entry runs the
// NullMessageSentBuffer destructor.
442 void
444 {
446 
448 
449  std::list<NullMessageSentBuffer>::iterator iter = g_pendingTx.begin ();
450  while (iter != g_pendingTx.end ())
451  {
452  MPI_Status status;
453  int flag = 0;
// Non-blocking completion check; flag becomes non-zero once the send is done.
454  MPI_Test (iter->GetRequest (), &flag, &status);
// Advance before erasing: erase invalidates only the erased iterator.
455  std::list<NullMessageSentBuffer>::iterator current = iter; // Save current for erasing
456  ++iter; // Advance to next
457  if (flag)
458  { // This message is complete
459  g_pendingTx.erase (current);
460  }
461  }
462 }
463 
464 void
466 {
467  NS_LOG_FUNCTION (this);
468 
469  if (g_enabled)
470  {
471  for (std::list<NullMessageSentBuffer>::iterator iter = g_pendingTx.begin ();
472  iter != g_pendingTx.end ();
473  ++iter)
474  {
475  MPI_Cancel (iter->GetRequest ());
476  MPI_Request_free (iter->GetRequest ());
477  }
478 
479  for (uint32_t i = 0; i < g_numNeighbors; ++i)
480  {
481  MPI_Cancel (&g_requests[i]);
482  MPI_Request_free (&g_requests[i]);
483  }
484 
485 
486  for (uint32_t i = 0; i < g_numNeighbors; ++i)
487  {
488  delete [] g_pRxBuffers[i];
489  }
490  delete [] g_pRxBuffers;
491  delete [] g_requests;
492 
493  g_pendingTx.clear ();
494 
495 
496  if (g_freeCommunicator)
497  {
498  MPI_Comm_free (&g_communicator);
499  g_freeCommunicator = false;
500  }
501 
502  if (g_mpiInitCalled)
503  {
504  int flag = 0;
505  MPI_Initialized (&flag);
506  if (flag)
507  {
508  MPI_Finalize ();
509  }
510  else
511  {
512  NS_FATAL_ERROR ("Cannot disable MPI environment without Initializing it first");
513  }
514  }
515 
516  g_enabled = false;
517  g_mpiInitCalled = false;
518  }
519  else
520  {
521  NS_FATAL_ERROR ("Cannot disable MPI environment without Initializing it first");
522  }
523 }
524 
525 } // namespace ns3
virtual void Destroy()
Deletes storage used by the parallel environment.
Time CalculateGuaranteeTime(uint32_t systemId)
static MPI_Request * g_requests
Pending non-blocking receives.
Simulation virtual time values and global simulation resolution.
Definition: nstime.h:103
Smart pointer class similar to boost::intrusive_ptr.
Definition: ptr.h:73
#define NS_LOG_FUNCTION(parameters)
If log level LOG_FUNCTION is enabled, this macro will output all input parameters separated by "...
uint32_t GetId(void) const
Definition: node.cc:109
#define NS_OBJECT_ENSURE_REGISTERED(type)
Register an Object subclass with the TypeId system.
Definition: object-base.h:45
const uint32_t NULL_MESSAGE_MAX_MPI_MSG_SIZE
maximum MPI message size for easy buffer creation
int64_t GetInteger(void) const
Get the raw time value, in the current resolution unit.
Definition: nstime.h:424
static Ptr< Node > GetNode(uint32_t n)
Definition: node-list.cc:241
Ptr< NetDevice > GetDevice(uint32_t index) const
Retrieve the index-th NetDevice associated to this node.
Definition: node.cc:144
uint32_t GetSerializedSize(void) const
Returns number of bytes required for packet serialization.
Definition: packet.cc:585
virtual uint32_t GetSystemId()
Get the id number of this rank.
Declaration of class ns3::NullMessageSimulatorImpl.
#define NS_ASSERT(condition)
At runtime, in debugging builds, if this condition is not true, the program prints the source file...
Definition: assert.h:67
#define NS_LOG_COMPONENT_DEFINE(name)
Define a Log component with a specific name.
Definition: log.h:205
static void ReceiveMessagesNonBlocking()
Non-blocking check for received messages complete.
#define NS_FATAL_ERROR(msg)
Report a fatal error with a message and terminate.
Definition: fatal-error.h:165
void RescheduleNullMessageEvent(Ptr< RemoteChannelBundle > bundle)
#define NS_LOG_FUNCTION_NOARGS()
Output the name of the function.
static bool g_freeCommunicator
Did we create the communicator? Have to free it.
virtual void SendPacket(Ptr< Packet > p, const Time &rxTime, uint32_t node, uint32_t dev)
Send a packet to a remote node.
static NullMessageSimulatorImpl * GetInstance(void)
static void TestSendComplete()
Check for completed sends.
virtual MPI_Comm GetCommunicator()
Return the communicator used to run ns-3.
virtual uint32_t GetSize()
Get the number of ranks used by ns-3.
virtual void Enable(int *pargc, char ***pargv)
Setup the parallel communication interface.
void Receive(Ptr< Packet > p)
Direct an incoming packet to the device Receive() method.
Definition: mpi-receiver.cc:51
static void SendNullMessage(const Time &guaranteeUpdate, Ptr< RemoteChannelBundle > bundle)
Send a Null Message across the specified bundle.
Class to aggregate to a NetDevice if it supports MPI capability.
Definition: mpi-receiver.h:47
static void ReceiveMessages(bool blocking=false)
Check for received messages complete.
virtual void Disable()
Clean up the ns-3 parallel communications interface.
static void ScheduleWithContext(uint32_t context, Time const &delay, FUNC f, Ts &&... args)
Schedule an event with the given context.
Definition: simulator.h:572
Every class exported by the ns3 library is enclosed in the ns3 namespace.
static void ReceiveMessagesBlocking()
Blocking message receive.
uint32_t GetSystemId(void) const
Definition: node.cc:123
static Ptr< RemoteChannelBundle > Find(uint32_t systemId)
Get the bundle corresponding to a remote rank.
virtual bool IsEnabled()
Returns enabled state of parallel environment.
static MPI_Comm g_communicator
MPI communicator being used for ns-3 tasks.
Declaration of class ns3::RemoteChannelBundleManager.
static Time Now(void)
Return the current simulation virtual time.
Definition: simulator.cc:195
static TypeId GetTypeId(void)
Register this type.
static bool g_mpiInitCalled
Has MPI Init been called by this interface.
static bool g_enabled
Has this interface been enabled.
uint8_t * m_buffer
Buffer for send.
Declaration of class ns3::RemoteChannelBundle.
Declaration of classes ns3::NullMessageSentBuffer and ns3::NullMessageMpiInterface.
static std::size_t Size(void)
Get the number of ns-3 channels in this bundle.
static uint32_t g_size
Size of the MPI communicator group used by ns-3 (MPI_COMM_WORLD by default).
Non-blocking send buffers for Null Message implementation.
static char ** g_pRxBuffers
Data buffers for non-blocking receives.
A base class which provides memory management and object aggregation.
Definition: object.h:87
static void InitializeSendReceiveBuffers(void)
Initialize send and receive buffers.
static uint32_t g_sid
System ID (rank) for this task.
static uint32_t g_numNeighbors
Number of neighbor tasks, tasks that this task shares a link with.
a unique identifier for an interface.
Definition: type-id.h:58
TypeId SetParent(TypeId tid)
Set the parent TypeId.
Definition: type-id.cc:923
static std::list< NullMessageSentBuffer > g_pendingTx
List of pending non-blocking sends.
uint32_t GetNDevices(void) const
Definition: node.cc:152
uint32_t Serialize(uint8_t *buffer, uint32_t maxSize) const
Serialize a packet, tags, and metadata into a byte buffer.
Definition: packet.cc:638
MPI_Request m_request
MPI request posted for the send.
int64_t GetTimeStep(void) const
Get the raw time value, in the current resolution unit.
Definition: nstime.h:416