00001
00002
00003 #ifndef DUNE_MPICOLLECTIVECOMMUNICATION_HH
00004 #define DUNE_MPICOLLECTIVECOMMUNICATION_HH
00005
00014 #include <iostream>
00015 #include <complex>
00016 #include <algorithm>
00017 #include <functional>
00018 #include <memory>
00019
00020 #include <dune/common/exceptions.hh>
00021 #include <dune/common/binaryfunctions.hh>
00022
00023 #include "collectivecommunication.hh"
00024 #include "mpitraits.hh"
00025
00026 #if HAVE_MPI
00027
00028 #include <mpi.h>
00029
00030 namespace Dune
00031 {
00032
00033
00034
00035
00036
00037
00038 template<typename Type, typename BinaryFunction>
00039 class Generic_MPI_Op
00040 {
00041
00042 public:
00043 static MPI_Op get ()
00044 {
00045 if (!op)
00046 {
00047 op = std::shared_ptr<MPI_Op>(new MPI_Op);
00048 MPI_Op_create((void (*)(void*, void*, int*, MPI_Datatype*))&operation,true,op.get());
00049 }
00050 return *op;
00051 }
00052 private:
00053 static void operation (Type *in, Type *inout, int *len, MPI_Datatype*)
00054 {
00055 BinaryFunction func;
00056
00057 for (int i=0; i< *len; ++i, ++in, ++inout) {
00058 Type temp;
00059 temp = func(*in, *inout);
00060 *inout = temp;
00061 }
00062 }
00063 Generic_MPI_Op () {}
00064 Generic_MPI_Op (const Generic_MPI_Op& ) {}
00065 static std::shared_ptr<MPI_Op> op;
00066 };
00067
00068
00069 template<typename Type, typename BinaryFunction>
00070 std::shared_ptr<MPI_Op> Generic_MPI_Op<Type,BinaryFunction>::op = std::shared_ptr<MPI_Op>(static_cast<MPI_Op*>(0));
00071
// Specialise Generic_MPI_Op for a builtin type and standard functor so that
// get() returns one of MPI's predefined reduction operations directly,
// bypassing the MPI_Op_create machinery of the generic template above.
#define ComposeMPIOp(type,func,op) \
template<> \
class Generic_MPI_Op<type, func<type> >{ \
public:\
static MPI_Op get(){ \
return op; \
} \
private:\
Generic_MPI_Op () {}\
Generic_MPI_Op (const Generic_MPI_Op & ) {}\
}


// Sums over the builtin arithmetic types map to MPI_SUM.
ComposeMPIOp(char, std::plus, MPI_SUM);
ComposeMPIOp(unsigned char, std::plus, MPI_SUM);
ComposeMPIOp(short, std::plus, MPI_SUM);
ComposeMPIOp(unsigned short, std::plus, MPI_SUM);
ComposeMPIOp(int, std::plus, MPI_SUM);
ComposeMPIOp(unsigned int, std::plus, MPI_SUM);
ComposeMPIOp(long, std::plus, MPI_SUM);
ComposeMPIOp(unsigned long, std::plus, MPI_SUM);
ComposeMPIOp(float, std::plus, MPI_SUM);
ComposeMPIOp(double, std::plus, MPI_SUM);
ComposeMPIOp(long double, std::plus, MPI_SUM);

// Products map to MPI_PROD.
ComposeMPIOp(char, std::multiplies, MPI_PROD);
ComposeMPIOp(unsigned char, std::multiplies, MPI_PROD);
ComposeMPIOp(short, std::multiplies, MPI_PROD);
ComposeMPIOp(unsigned short, std::multiplies, MPI_PROD);
ComposeMPIOp(int, std::multiplies, MPI_PROD);
ComposeMPIOp(unsigned int, std::multiplies, MPI_PROD);
ComposeMPIOp(long, std::multiplies, MPI_PROD);
ComposeMPIOp(unsigned long, std::multiplies, MPI_PROD);
ComposeMPIOp(float, std::multiplies, MPI_PROD);
ComposeMPIOp(double, std::multiplies, MPI_PROD);
ComposeMPIOp(long double, std::multiplies, MPI_PROD);

// Dune's Min functor (from binaryfunctions.hh) maps to MPI_MIN.
ComposeMPIOp(char, Min, MPI_MIN);
ComposeMPIOp(unsigned char, Min, MPI_MIN);
ComposeMPIOp(short, Min, MPI_MIN);
ComposeMPIOp(unsigned short, Min, MPI_MIN);
ComposeMPIOp(int, Min, MPI_MIN);
ComposeMPIOp(unsigned int, Min, MPI_MIN);
ComposeMPIOp(long, Min, MPI_MIN);
ComposeMPIOp(unsigned long, Min, MPI_MIN);
ComposeMPIOp(float, Min, MPI_MIN);
ComposeMPIOp(double, Min, MPI_MIN);
ComposeMPIOp(long double, Min, MPI_MIN);

// Dune's Max functor maps to MPI_MAX.
ComposeMPIOp(char, Max, MPI_MAX);
ComposeMPIOp(unsigned char, Max, MPI_MAX);
ComposeMPIOp(short, Max, MPI_MAX);
ComposeMPIOp(unsigned short, Max, MPI_MAX);
ComposeMPIOp(int, Max, MPI_MAX);
ComposeMPIOp(unsigned int, Max, MPI_MAX);
ComposeMPIOp(float, Max, MPI_MAX);
ComposeMPIOp(long, Max, MPI_MAX);
ComposeMPIOp(unsigned long, Max, MPI_MAX);
ComposeMPIOp(double, Max, MPI_MAX);
ComposeMPIOp(long double, Max, MPI_MAX);

// The macro is only needed for the specializations above.
#undef ComposeMPIOp
00134
00135
00136
00137
00138
00139
00140
00144 template<>
00145 class CollectiveCommunication<MPI_Comm>
00146 {
00147 public:
00149 CollectiveCommunication (const MPI_Comm& c = MPI_COMM_WORLD)
00150 : communicator(c)
00151 {
00152 if(communicator!=MPI_COMM_NULL) {
00153 int initialized = 0;
00154 MPI_Initialized(&initialized);
00155 if (!initialized)
00156 DUNE_THROW(ParallelError,"You must call MPIHelper::instance(argc,argv) in your main() function before using the MPI CollectiveCommunication!");
00157 MPI_Comm_rank(communicator,&me);
00158 MPI_Comm_size(communicator,&procs);
00159 }else{
00160 procs=0;
00161 me=-1;
00162 }
00163 }
00164
00166 int rank () const
00167 {
00168 return me;
00169 }
00170
00172 int size () const
00173 {
00174 return procs;
00175 }
00176
00178 template<typename T>
00179 T sum (T& in) const
00180 {
00181 T out;
00182 allreduce<std::plus<T> >(&in,&out,1);
00183 return out;
00184 }
00185
00187 template<typename T>
00188 int sum (T* inout, int len) const
00189 {
00190 return allreduce<std::plus<T> >(inout,len);
00191 }
00192
00194 template<typename T>
00195 T prod (T& in) const
00196 {
00197 T out;
00198 allreduce<std::multiplies<T> >(&in,&out,1);
00199 return out;
00200 }
00201
00203 template<typename T>
00204 int prod (T* inout, int len) const
00205 {
00206 return allreduce<std::multiplies<T> >(inout,len);
00207 }
00208
00210 template<typename T>
00211 T min (T& in) const
00212 {
00213 T out;
00214 allreduce<Min<T> >(&in,&out,1);
00215 return out;
00216 }
00217
00219 template<typename T>
00220 int min (T* inout, int len) const
00221 {
00222 return allreduce<Min<T> >(inout,len);
00223 }
00224
00225
00227 template<typename T>
00228 T max (T& in) const
00229 {
00230 T out;
00231 allreduce<Max<T> >(&in,&out,1);
00232 return out;
00233 }
00234
00236 template<typename T>
00237 int max (T* inout, int len) const
00238 {
00239 return allreduce<Max<T> >(inout,len);
00240 }
00241
00243 int barrier () const
00244 {
00245 return MPI_Barrier(communicator);
00246 }
00247
00249 template<typename T>
00250 int broadcast (T* inout, int len, int root) const
00251 {
00252 return MPI_Bcast(inout,len,MPITraits<T>::getType(),root,communicator);
00253 }
00254
00257 template<typename T>
00258 int gather (T* in, T* out, int len, int root) const
00259 {
00260 return MPI_Gather(in,len,MPITraits<T>::getType(),
00261 out,len,MPITraits<T>::getType(),
00262 root,communicator);
00263 }
00264
00266 template<typename T>
00267 int gatherv (T* in, int sendlen, T* out, int* recvlen, int* displ, int root) const
00268 {
00269 return MPI_Gatherv(in,sendlen,MPITraits<T>::getType(),
00270 out,recvlen,displ,MPITraits<T>::getType(),
00271 root,communicator);
00272 }
00273
00276 template<typename T>
00277 int scatter (T* send, T* recv, int len, int root) const
00278 {
00279 return MPI_Scatter(send,len,MPITraits<T>::getType(),
00280 recv,len,MPITraits<T>::getType(),
00281 root,communicator);
00282 }
00283
00285 template<typename T>
00286 int scatterv (T* send, int* sendlen, int* displ, T* recv, int recvlen, int root) const
00287 {
00288 return MPI_Scatterv(send,sendlen,displ,MPITraits<T>::getType(),
00289 recv,recvlen,MPITraits<T>::getType(),
00290 root,communicator);
00291 }
00292
00293
00294 operator MPI_Comm () const
00295 {
00296 return communicator;
00297 }
00298
00300 template<typename T, typename T1>
00301 int allgather(T* sbuf, int count, T1* rbuf) const
00302 {
00303 return MPI_Allgather(sbuf, count, MPITraits<T>::getType(),
00304 rbuf, count, MPITraits<T1>::getType(),
00305 communicator);
00306 }
00307
00309 template<typename T>
00310 int allgatherv (T* in, int sendlen, T* out, int* recvlen, int* displ) const
00311 {
00312 return MPI_Allgatherv(in,sendlen,MPITraits<T>::getType(),
00313 out,recvlen,displ,MPITraits<T>::getType(),
00314 communicator);
00315 }
00316
00318 template<typename BinaryFunction, typename Type>
00319 int allreduce(Type* inout, int len) const
00320 {
00321 Type* out = new Type[len];
00322 int ret = allreduce<BinaryFunction>(inout,out,len);
00323 std::copy(out, out+len, inout);
00324 delete[] out;
00325 return ret;
00326 }
00327
00329 template<typename BinaryFunction, typename Type>
00330 int allreduce(Type* in, Type* out, int len) const
00331 {
00332 return MPI_Allreduce(in, out, len, MPITraits<Type>::getType(),
00333 (Generic_MPI_Op<Type, BinaryFunction>::get()),communicator);
00334 }
00335
00336 private:
00337 MPI_Comm communicator;
00338 int me;
00339 int procs;
00340 };
00341 }
00342
00343 #endif
00344 #endif