Actual source code: curand.c

  1: #include <petsc/private/deviceimpl.h>
  2: #include <petsc/private/randomimpl.h>
  3: #include <curand.h>

  5: typedef struct {
  6:   curandGenerator_t gen;
  7: } PetscRandom_CURAND;

  9: PetscErrorCode PetscRandomSeed_CURAND(PetscRandom r)
 10: {
 11:   PetscRandom_CURAND *curand = (PetscRandom_CURAND*)r->data;

 13:   curandSetPseudoRandomGeneratorSeed(curand->gen,r->seed);
 14:   return 0;
 15: }

 17: PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom,size_t,PetscReal*,PetscBool);

 19: PetscErrorCode  PetscRandomGetValuesReal_CURAND(PetscRandom r, PetscInt n, PetscReal *val)
 20: {
 21:   PetscRandom_CURAND *curand = (PetscRandom_CURAND*)r->data;
 22:   size_t             nn = n < 0 ? (size_t)(-2*n) : n; /* handle complex case */

 24: #if defined(PETSC_USE_REAL_SINGLE)
 25:   curandGenerateUniform(curand->gen,val,nn);
 26: #else
 27:   curandGenerateUniformDouble(curand->gen,val,nn);
 28: #endif
 29:   if (r->iset) {
 30:     PetscRandomCurandScale_Private(r,nn,val,(PetscBool)(n<0));
 31:   }
 32:   return 0;
 33: }

 35: PetscErrorCode PetscRandomGetValues_CURAND(PetscRandom r, PetscInt n, PetscScalar *val)
 36: {
 37: #if defined(PETSC_USE_COMPLEX)
 38:   /* pass negative size to flag complex scaling (if needed) */
 39:   PetscRandomGetValuesReal_CURAND(r,-n,(PetscReal*)val);
 40: #else
 41:   PetscRandomGetValuesReal_CURAND(r,n,val);
 42: #endif
 43:   return 0;
 44: }

 46: PetscErrorCode PetscRandomDestroy_CURAND(PetscRandom r)
 47: {
 48:   PetscRandom_CURAND *curand = (PetscRandom_CURAND*)r->data;

 50:   curandDestroyGenerator(curand->gen);
 51:   PetscFree(r->data);
 52:   return 0;
 53: }

 55: static struct _PetscRandomOps PetscRandomOps_Values = {
 56:   PetscDesignatedInitializer(seed,PetscRandomSeed_CURAND),
 57:   PetscDesignatedInitializer(getvalue,NULL),
 58:   PetscDesignatedInitializer(getvaluereal,NULL),
 59:   PetscDesignatedInitializer(getvalues,PetscRandomGetValues_CURAND),
 60:   PetscDesignatedInitializer(getvaluesreal,PetscRandomGetValuesReal_CURAND),
 61:   PetscDesignatedInitializer(destroy,PetscRandomDestroy_CURAND),
 62: };

 64: /*MC
  65:    PETSCCURAND - access to the NVIDIA cuRAND (CUDA) pseudo-random number generator

 67:   Level: beginner

 69: .seealso: PetscRandomCreate(), PetscRandomSetType()
 70: M*/

 72: PETSC_EXTERN PetscErrorCode PetscRandomCreate_CURAND(PetscRandom r)
 73: {
 74:   PetscRandom_CURAND *curand;

 76:   PetscDeviceInitialize(PETSC_DEVICE_CUDA);
 77:   PetscNewLog(r,&curand);
 78:   curandCreateGenerator(&curand->gen,CURAND_RNG_PSEUDO_DEFAULT);
 79:   /* https://docs.nvidia.com/cuda/curand/host-api-overview.html#performance-notes2 */
 80:   curandSetGeneratorOrdering(curand->gen,CURAND_ORDERING_PSEUDO_SEEDED);
 81:   PetscMemcpy(r->ops,&PetscRandomOps_Values,sizeof(PetscRandomOps_Values));
 82:   PetscObjectChangeTypeName((PetscObject)r,PETSCCURAND);
 83:   r->data = curand;
 84:   r->seed = 1234ULL; /* taken from example */
 85:   PetscRandomSeed_CURAND(r);
 86:   return 0;
 87: }