Actual source code: bjkokkos.kokkos.cxx
1: #include <petscvec_kokkos.hpp>
2: #include <petsc/private/pcimpl.h>
3: #include <petsc/private/kspimpl.h>
4: #include <petscksp.h>
5: #include "petscsection.h"
6: #include <petscdmcomposite.h>
7: #include <Kokkos_Core.hpp>
9: typedef Kokkos::TeamPolicy<>::member_type team_member;
11: #include <../src/mat/impls/aij/seq/aij.h>
12: #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
14: #define PCBJKOKKOS_SHARED_LEVEL 1
15: #define PCBJKOKKOS_VEC_SIZE 16
16: #define PCBJKOKKOS_TEAM_SIZE 16
17: #define PCBJKOKKOS_VERBOSE_LEVEL 0
19: typedef enum {BATCH_KSP_BICG_IDX,BATCH_KSP_TFQMR_IDX,BATCH_KSP_GMRES_IDX,NUM_BATCH_TYPES} KSPIndex;
20: typedef struct {
21: Vec vec_diag;
22: PetscInt nBlocks; /* total number of blocks */
23: PetscInt n; // cache host version of d_bid_eqOffset_k[nBlocks]
24: KSP ksp; // Used just for options. Should have one for each block
25: Kokkos::View<PetscInt*, Kokkos::LayoutRight> *d_bid_eqOffset_k;
26: Kokkos::View<PetscScalar*, Kokkos::LayoutRight> *d_idiag_k;
27: Kokkos::View<PetscInt*> *d_isrow_k;
28: Kokkos::View<PetscInt*> *d_isicol_k;
29: KSPIndex ksp_type_idx;
30: PetscInt nwork;
31: PetscInt const_block_size; // used to decide to use shared memory for work vectors
32: PetscInt *dm_Nf; // Number of fields in each DM
33: PetscInt num_dms;
34: // diagnostics
35: PetscBool reason;
36: PetscBool monitor;
37: PetscInt batch_target;
38: } PC_PCBJKOKKOS;
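// Note on d_bid_eqOffset_k: it holds nBlocks+1 prefix offsets, so block b owns the global
// equations [offset[b], offset[b+1]); PCSetUp_BJKOKKOS builds it as a running sum of block
// sizes. Illustrative (hypothetical) example: two fields with 100 equations each give
// offsets {0, 100, 200} and nBlocks = 2.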
40: static PetscErrorCode PCBJKOKKOSCreateKSP_BJKOKKOS(PC pc)
41: {
42: const char *prefix;
43: PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS*)pc->data;
44: DM dm;
46: KSPCreate(PetscObjectComm((PetscObject)pc),&jac->ksp);
47: KSPSetErrorIfNotConverged(jac->ksp,pc->erroriffailure);
48: PetscObjectIncrementTabLevel((PetscObject)jac->ksp,(PetscObject)pc,1);
49: PCGetOptionsPrefix(pc,&prefix);
50: KSPSetOptionsPrefix(jac->ksp,prefix);
51: KSPAppendOptionsPrefix(jac->ksp,"pc_bjkokkos_");
52: PCGetDM(pc,&dm);
53: if (dm) {
54: KSPSetDM(jac->ksp, dm);
55: KSPSetDMActive(jac->ksp, PETSC_FALSE);
56: }
57: jac->reason = PETSC_FALSE;
58: jac->monitor = PETSC_FALSE;
59: jac->batch_target = 0;
60: return 0;
61: }
63: // y <-- Ax
64: KOKKOS_INLINE_FUNCTION PetscErrorCode MatMult(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
65: {
66: Kokkos::parallel_for(Kokkos::TeamThreadRange(team,start,end), [=] (const int rowb) {
67: int rowa = ic[rowb];
68: int n = glb_Aai[rowa+1] - glb_Aai[rowa];
69: const PetscInt *aj = glb_Aaj + glb_Aai[rowa];
70: const PetscScalar *aa = glb_Aaa + glb_Aai[rowa];
71: PetscScalar sum;
72: Kokkos::parallel_reduce(Kokkos::ThreadVectorRange (team, n), [=] (const int i, PetscScalar& lsum) {
73: lsum += aa[i] * x_loc[r[aj[i]]-start];
74: }, sum);
75: Kokkos::single(Kokkos::PerThread (team),[=]() {y_loc[rowb-start] = sum;});
76: });
77: team.team_barrier();
78: return 0;
79: }
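// Note: MatMult above is a per-block CSR mat-vec. Each team owns the rows of one block,
// TeamThreadRange spreads those rows over the team's threads, and ThreadVectorRange reduces
// each row's dot product across vector lanes. The r/ic index arrays translate between the
// permuted (block) ordering and the original matrix rows, so x_loc and y_loc are addressed
// with local indices in [0, end-start).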
81: // temp buffer per thread with reduction at end?
82: KOKKOS_INLINE_FUNCTION PetscErrorCode MatMultTranspose(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
83: {
84: Kokkos::parallel_for(Kokkos::TeamVectorRange(team,end-start), [=] (int i) { y_loc[i] = 0;});
85: team.team_barrier();
86: Kokkos::parallel_for(Kokkos::TeamThreadRange(team,start,end), [=] (const int rowb) {
87: int rowa = ic[rowb];
88: int n = glb_Aai[rowa+1] - glb_Aai[rowa];
89: const PetscInt *aj = glb_Aaj + glb_Aai[rowa];
90: const PetscScalar *aa = glb_Aaa + glb_Aai[rowa];
91: const PetscScalar xx = x_loc[rowb-start]; // rowb = ic[rowa] = ic[r[rowb]]
92: Kokkos::parallel_for(Kokkos::ThreadVectorRange(team,n), [=] (const int &i) {
93: PetscScalar val = aa[i] * xx;
94: Kokkos::atomic_fetch_add(&y_loc[r[aj[i]]-start], val);
95: });
96: });
97: team.team_barrier();
98: return 0;
99: }
101: typedef struct Batch_MetaData_TAG
102: {
103: PetscInt flops;
104: PetscInt its;
105: KSPConvergedReason reason;
106: }Batch_MetaData;
108: // Solve Ax = b with TFQMR, right preconditioned (solve A B^-1 (B x) = b) so the computed residual is the un-preconditioned residual
109: KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_TFQMR(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space, const PetscInt stride, PetscReal rtol, PetscReal atol, PetscReal dtol,PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
110: {
111: using Kokkos::parallel_reduce;
112: using Kokkos::parallel_for;
113: int Nblk = end-start, i,m;
114: PetscReal dp,dpold,w,dpest,tau,psi,cm,r0;
115: PetscScalar *ptr = work_space, rho,rhoold,a,s,b,eta,etaold,psiold,cf,dpi;
116: const PetscScalar *Diag = &glb_idiag[start];
117: PetscScalar *XX = ptr; ptr += stride;
118: PetscScalar *R = ptr; ptr += stride;
119: PetscScalar *RP = ptr; ptr += stride;
120: PetscScalar *V = ptr; ptr += stride;
121: PetscScalar *T = ptr; ptr += stride;
122: PetscScalar *Q = ptr; ptr += stride;
123: PetscScalar *P = ptr; ptr += stride;
124: PetscScalar *U = ptr; ptr += stride;
125: PetscScalar *D = ptr; ptr += stride;
126: PetscScalar *AUQ = V;
128: // init: get b, zero x
129: parallel_for(Kokkos::TeamVectorRange(team, start, end), [=] (int rowb) {
130: int rowa = ic[rowb];
131: R[rowb-start] = glb_b[rowa];
132: XX[rowb-start] = 0;
133: });
134: team.team_barrier();
135: parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum += R[idx]*PetscConj(R[idx]);}, dpi);
136: team.team_barrier();
137: r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
138: // diagnostics
139: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
140: if (monitor) Kokkos::single (Kokkos::PerTeam (team), [=] () { printf("%3d KSP Residual norm %14.12e \n", 0, (double)dp);});
141: #endif
142: if (dp < atol) {metad->reason = KSP_CONVERGED_ATOL_NORMAL; return 0;}
143: if (0 == maxit) {metad->reason = KSP_DIVERGED_ITS; return 0;}
145: /* Make the initial Rp = R */
146: parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {RP[idx] = R[idx];});
147: team.team_barrier();
148: /* Set the initial conditions */
149: etaold = 0.0;
150: psiold = 0.0;
151: tau = dp;
152: dpold = dp;
154: /* rhoold = (r,rp) */
155: parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& dot) {dot += R[idx]*PetscConj(RP[idx]);}, rhoold);
156: team.team_barrier();
157: parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {U[idx] = R[idx]; P[idx] = R[idx]; T[idx] = Diag[idx]*P[idx]; D[idx] = 0;});
158: team.team_barrier();
159: MatMult (team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,T,V);
161: i=0;
162: do {
163: /* s <- (v,rp) */
164: parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& dot) {dot += V[idx]*PetscConj(RP[idx]);}, s);
165: team.team_barrier();
166: a = rhoold / s; /* a <- rho / s */
167: /* q <- u - a v VecWAXPY(w,alpha,x,y): w = alpha x + y. */
168: /* t <- u + q */
169: parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Q[idx] = U[idx] - a*V[idx]; T[idx] = U[idx] + Q[idx];});
170: team.team_barrier();
171: // KSP_PCApplyBAorAB
172: parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {T[idx] = Diag[idx]*T[idx]; });
173: team.team_barrier();
174: MatMult (team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,T,AUQ);
175: /* r <- r - a K (u + q) */
176: parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {R[idx] = R[idx] - a*AUQ[idx]; });
177: team.team_barrier();
178: parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum += R[idx]*PetscConj(R[idx]);}, dpi);
179: team.team_barrier();
180: dp = PetscSqrtReal(PetscRealPart(dpi));
181: for (m=0; m<2; m++) {
182: if (!m) w = PetscSqrtReal(dp*dpold);
183: else w = dp;
184: psi = w / tau;
185: cm = 1.0 / PetscSqrtReal(1.0 + psi * psi);
186: tau = tau * psi * cm;
187: eta = cm * cm * a;
188: cf = psiold * psiold * etaold / a;
189: if (!m) {
190: /* D = U + cf D */
191: parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {D[idx] = U[idx] + cf*D[idx]; });
192: } else {
193: /* D = Q + cf D */
194: parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {D[idx] = Q[idx] + cf*D[idx]; });
195: }
196: team.team_barrier();
197: parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {XX[idx] = XX[idx] + eta*D[idx]; });
198: team.team_barrier();
199: dpest = PetscSqrtReal(2*i + m + 2.0) * tau;
200: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
201: if (monitor && m==1) Kokkos::single (Kokkos::PerTeam (team), [=] () { printf("%3d KSP Residual norm %14.12e \n", i+1, (double)dpest);});
202: #endif
203: if (dpest < atol) {metad->reason = KSP_CONVERGED_ATOL_NORMAL; goto done;}
204: if (dpest/r0 < rtol) {metad->reason = KSP_CONVERGED_RTOL_NORMAL; goto done;}
205: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
206: if (dpest/r0 > dtol) {metad->reason = KSP_DIVERGED_DTOL; Kokkos::single (Kokkos::PerTeam (team), [=] () {printf("ERROR block %d diverged: %d it, res=%e, r_0=%e\n",team.league_rank(),i,dpest,r0);}); goto done;}
207: #else
208: if (dpest/r0 > dtol) {metad->reason = KSP_DIVERGED_DTOL; goto done;}
209: #endif
210: if (i+1 == maxit) {metad->reason = KSP_DIVERGED_ITS; goto done;}
212: etaold = eta;
213: psiold = psi;
214: }
216: /* rho <- (r,rp) */
217: parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& dot) {dot += R[idx]*PetscConj(RP[idx]);}, rho);
218: team.team_barrier();
219: b = rho / rhoold; /* b <- rho / rhoold */
220: /* u <- r + b q */
221: /* p <- u + b(q + b p) */
222: parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {U[idx] = R[idx] + b*Q[idx]; Q[idx] = Q[idx] + b*P[idx]; P[idx] = U[idx] + b*Q[idx];});
223: /* v <- K p */
224: team.team_barrier();
225: parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {T[idx] = Diag[idx]*P[idx]; });
226: team.team_barrier();
227: MatMult (team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,T,V);
229: rhoold = rho;
230: dpold = dp;
232: i++;
233: } while (i<maxit);
234: done:
235: // KSPUnwindPreconditioner
236: parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {XX[idx] = Diag[idx]*XX[idx]; });
237: team.team_barrier();
238: parallel_for(Kokkos::TeamVectorRange(team, start, end), [=] (int rowb) {
239: int rowa = ic[rowb];
240: glb_x[rowa] = XX[rowb-start];
241: });
242: metad->its = i+1;
243: if (1) {
244: int nnz;
245: parallel_reduce(Kokkos::TeamVectorRange (team, start, end), [=] (const int idx, int& lsum) {lsum += (glb_Aai[idx+1] - glb_Aai[idx]);}, nnz);
246: metad->flops = 2*(metad->its*(10*Nblk + 2*nnz) + 5*Nblk);
247: } else {
248: metad->flops = 2*(metad->its*(10*Nblk + 2*50*Nblk) + 5*Nblk); // guess
249: }
250: return 0;
251: }
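// Work-space layout for BJSolve_TFQMR: nine vectors of length `stride` are carved out of
// work_space in order (XX, R, RP, V, T, Q, P, U, D); AUQ aliases V and consumes no extra
// storage. PCSetUp_BJKOKKOS reserves jac->nwork = 10 stride-length vectors for this solver.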
253: // Solve Ax = y with biCG
254: KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_BICG(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space, const PetscInt stride, PetscReal rtol, PetscReal atol, PetscReal dtol,PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
255: {
256: using Kokkos::parallel_reduce;
257: using Kokkos::parallel_for;
258: int Nblk = end-start, i;
259: PetscReal dp, r0;
260: PetscScalar *ptr = work_space, dpi, a=1.0, beta, betaold=1.0, b, b2, ma, mac;
261: const PetscScalar *Di = &glb_idiag[start];
262: PetscScalar *XX = ptr; ptr += stride;
263: PetscScalar *Rl = ptr; ptr += stride;
264: PetscScalar *Zl = ptr; ptr += stride;
265: PetscScalar *Pl = ptr; ptr += stride;
266: PetscScalar *Rr = ptr; ptr += stride;
267: PetscScalar *Zr = ptr; ptr += stride;
268: PetscScalar *Pr = ptr; ptr += stride;
270: /* r <- b (x is 0) */
271: parallel_for(Kokkos::TeamVectorRange(team, start, end), [=] (int rowb) {
272: int rowa = ic[rowb];
273: //VecCopy(Rr,Rl);
274: Rl[rowb-start] = Rr[rowb-start] = glb_b[rowa];
275: XX[rowb-start] = 0;
276: });
277: team.team_barrier();
278: /* z <- Br */
279: parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Zr[idx] = Di[idx]*Rr[idx]; Zl[idx] = Di[idx]*Rl[idx]; });
280: team.team_barrier();
281: /* dp <- r'*r */
282: parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum += Rr[idx]*PetscConj(Rr[idx]);}, dpi);
283: team.team_barrier();
284: r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
285: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
286: if (monitor) Kokkos::single (Kokkos::PerTeam (team), [=] () { printf("%3d KSP Residual norm %14.12e \n", 0, (double)dp);});
287: #endif
288: if (dp < atol) {metad->reason = KSP_CONVERGED_ATOL_NORMAL; return 0;}
289: if (0 == maxit) {metad->reason = KSP_DIVERGED_ITS; return 0;}
290: i = 0;
291: do {
292: /* beta <- r'z */
293: parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& dot) {dot += Zr[idx]*PetscConj(Rl[idx]);}, beta);
294: team.team_barrier();
295: #if PCBJKOKKOS_VERBOSE_LEVEL >= 6
296: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
297: Kokkos::single (Kokkos::PerTeam (team), [=] () {printf("%7d beta = Z.R = %22.14e \n",i,(double)beta);});
298: #endif
299: #endif
300: if (!i) {
301: if (beta == 0.0) {
302: metad->reason = KSP_DIVERGED_BREAKDOWN_BICG;
303: goto done;
304: }
305: /* p <- z */
306: parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Pr[idx] = Zr[idx]; Pl[idx] = Zl[idx];});
307: } else {
308: b = beta/betaold;
309: /* p <- z + b* p */
310: b2 = PetscConj(b);
311: parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Pr[idx] = b*Pr[idx] + Zr[idx]; Pl[idx] = b2*Pl[idx] + Zl[idx];});
312: }
313: team.team_barrier();
314: betaold = beta;
315: /* z <- Kp */
316: MatMult (team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,Pr,Zr);
317: MatMultTranspose(team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,Pl,Zl);
318: /* dpi <- z'p */
319: parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum += Zr[idx]*PetscConj(Pl[idx]);}, dpi);
320: team.team_barrier();
321: //
322: a = beta/dpi; /* a = beta/p'z */
323: ma = -a;
324: mac = PetscConj(ma);
325: /* x <- x + ap */
326: parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {XX[idx] = XX[idx] + a*Pr[idx]; Rr[idx] = Rr[idx] + ma*Zr[idx]; Rl[idx] = Rl[idx] + mac*Zl[idx];});
327: team.team_barrier();
328: /* dp <- r'*r */
329: parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum += Rr[idx]*PetscConj(Rr[idx]);}, dpi);
330: team.team_barrier();
331: dp = PetscSqrtReal(PetscRealPart(dpi));
332: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
333: if (monitor) Kokkos::single (Kokkos::PerTeam (team), [=] () { printf("%3d KSP Residual norm %14.12e \n", i+1, (double)dp);});
334: #endif
335: if (dp < atol) {metad->reason = KSP_CONVERGED_ATOL_NORMAL; goto done;}
336: if (dp/r0 < rtol) {metad->reason = KSP_CONVERGED_RTOL_NORMAL; goto done;}
337: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
338: if (dp/r0 > dtol) {metad->reason = KSP_DIVERGED_DTOL; Kokkos::single (Kokkos::PerTeam (team), [=] () {printf("ERROR block %d diverged: %d it, res=%e, r_0=%e\n",team.league_rank(),i,dp,r0);}); goto done;}
339: #else
340: if (dp/r0 > dtol) {metad->reason = KSP_DIVERGED_DTOL; goto done;}
341: #endif
342: if (i+1 == maxit) {metad->reason = KSP_DIVERGED_ITS; goto done;}
343: /* z <- Br */
344: parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Zr[idx] = Di[idx]*Rr[idx]; Zl[idx] = Di[idx]*Rl[idx];});
345: i++;
346: team.team_barrier();
347: } while (i<maxit);
348: done:
349: parallel_for(Kokkos::TeamVectorRange(team, start, end), [=] (int rowb) {
350: int rowa = ic[rowb];
351: glb_x[rowa] = XX[rowb-start];
352: });
353: metad->its = i+1;
354: if (1) {
355: int nnz;
356: parallel_reduce(Kokkos::TeamVectorRange (team, start, end), [=] (const int idx, int& lsum) {lsum += (glb_Aai[idx+1] - glb_Aai[idx]);}, nnz);
357: metad->flops = 2*(metad->its*(10*Nblk + 2*nnz) + 5*Nblk);
358: } else {
359: metad->flops = 2*(metad->its*(10*Nblk + 2*50*Nblk) + 5*Nblk); // guess
360: }
361: return 0;
362: }
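// Work-space layout for BJSolve_BICG: seven vectors of length `stride` (XX, Rl, Zl, Pl, Rr,
// Zr, Pr), matching jac->nwork = 7 set in PCSetUp_BJKOKKOS. The "l" vectors are advanced with
// MatMultTranspose and the "r" vectors with MatMult, as in standard biconjugate gradients.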
364: // KSP solver: solve Ax = b; xout is the output, bin is the input
365: static PetscErrorCode PCApply_BJKOKKOS(PC pc,Vec bin,Vec xout)
366: {
367: PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS*)pc->data;
368: Mat A = pc->pmat;
369: Mat_SeqAIJKokkos *aijkok;
372: aijkok = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
373: if (!aijkok) {
374: SETERRQ(PetscObjectComm((PetscObject)pc),PETSC_ERR_USER,"No aijkok");
375: } else {
376: using scr_mem_t = Kokkos::DefaultExecutionSpace::scratch_memory_space;
377: using vect2D_scr_t = Kokkos::View<PetscScalar**, Kokkos::LayoutLeft, scr_mem_t>;
378: PetscInt *d_bid_eqOffset, maxit = jac->ksp->max_it, scr_bytes_team, stride, global_buff_size;
379: const PetscInt conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp==0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
380: const PetscInt nwork = jac->nwork, nBlk = jac->nBlocks;
381: PetscScalar *glb_xdata=NULL;
382: PetscReal rtol = jac->ksp->rtol, atol = jac->ksp->abstol, dtol = jac->ksp->divtol;
383: const PetscScalar *glb_idiag =jac->d_idiag_k->data(), *glb_bdata=NULL;
384: const PetscInt *glb_Aai = aijkok->i_device_data(), *glb_Aaj = aijkok->j_device_data();
385: const PetscScalar *glb_Aaa = aijkok->a_device_data();
386: Kokkos::View<Batch_MetaData*, Kokkos::DefaultExecutionSpace> d_metadata("solver meta data", nBlk);
387: PCFailedReason pcreason;
388: KSPIndex ksp_type_idx = jac->ksp_type_idx;
389: PetscMemType mtype;
390: PetscContainer container;
391: PetscInt batch_sz;
392: VecScatter plex_batch=NULL;
393: Vec bvec;
394: PetscBool monitor = jac->monitor; // captured
395: PetscInt view_bid = jac->batch_target;
396: // "plex_batch_is" maps plex I/O ordering to/from block/field-major ordering
397: PetscObjectQuery((PetscObject) A, "plex_batch_is", (PetscObject *) &container);
398: VecDuplicate(bin,&bvec);
399: if (container) {
400: PetscContainerGetPointer(container, (void **) &plex_batch);
401: VecScatterBegin(plex_batch,bin,bvec,INSERT_VALUES,SCATTER_FORWARD);
402: VecScatterEnd(plex_batch,bin,bvec,INSERT_VALUES,SCATTER_FORWARD);
403: } else {
404: VecCopy(bin, bvec);
405: }
406: // get x
407: VecGetArrayAndMemType(xout,&glb_xdata,&mtype);
408: #if defined(PETSC_HAVE_CUDA)
410: #endif
411: VecGetArrayReadAndMemType(bvec,&glb_bdata,&mtype);
412: #if defined(PETSC_HAVE_CUDA)
414: #endif
415: // get batch size
416: PetscObjectQuery((PetscObject) A, "batch size", (PetscObject *) &container);
417: if (container) {
418: PetscInt *pNf=NULL;
419: PetscContainerGetPointer(container, (void **) &pNf);
420: batch_sz = *pNf;
421: } else batch_sz = 1;
423: d_bid_eqOffset = jac->d_bid_eqOffset_k->data();
424: // solve each block independently
425: if (jac->const_block_size) { // use shared memory for work vectors only if constant block size - todo: test efficiency loss
426: scr_bytes_team = jac->const_block_size*nwork*sizeof(PetscScalar);
427: stride = jac->const_block_size; // captured
428: global_buff_size = 0;
429: } else {
430: scr_bytes_team = 0;
431: stride = jac->n; // captured
432: global_buff_size = jac->n*nwork;
433: }
434: Kokkos::View<PetscScalar*, Kokkos::DefaultExecutionSpace> d_work_vecs_k("workvectors", global_buff_size); // global work vectors
435: PetscInfo(pc,"\tn = %" PetscInt_FMT ". %d shared mem words/team. %" PetscInt_FMT " global mem words, rtol=%e, num blocks %" PetscInt_FMT ", team_size=%" PetscInt_FMT ", %" PetscInt_FMT " vector threads\n",jac->n,(int)(scr_bytes_team/sizeof(PetscScalar)),global_buff_size,(double)rtol,nBlk,
436: team_size, (PetscInt)PCBJKOKKOS_VEC_SIZE);
437: PetscScalar *d_work_vecs = scr_bytes_team ? NULL : d_work_vecs_k.data();
438: const PetscInt *d_isicol = jac->d_isicol_k->data(), *d_isrow = jac->d_isrow_k->data();
439: Kokkos::parallel_for("Solve", Kokkos::TeamPolicy<>(nBlk, team_size, PCBJKOKKOS_VEC_SIZE).set_scratch_size(PCBJKOKKOS_SHARED_LEVEL, Kokkos::PerTeam(scr_bytes_team)),
440: KOKKOS_LAMBDA (const team_member team) {
441: const int blkID = team.league_rank(), start = d_bid_eqOffset[blkID], end = d_bid_eqOffset[blkID+1];
442: vect2D_scr_t work_vecs(team.team_scratch(PCBJKOKKOS_SHARED_LEVEL), scr_bytes_team ? (end-start) : 0, nwork);
443: PetscScalar *work_buff = (scr_bytes_team) ? work_vecs.data() : &d_work_vecs[start];
444: bool print = monitor && (blkID==view_bid);
445: switch (ksp_type_idx) {
446: case BATCH_KSP_BICG_IDX:
447: BJSolve_BICG(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff, stride, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print);
448: break;
449: case BATCH_KSP_TFQMR_IDX:
450: BJSolve_TFQMR(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff, stride, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print);
451: break;
452: case BATCH_KSP_GMRES_IDX:
453: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
454: printf("GMRES not implemented %d\n",ksp_type_idx);
455: #else
456: /* void */
457: #endif
458: break;
459: default:
460: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
461: printf("Unknown KSP type %d\n",ksp_type_idx);
462: #else
463: /* void */;
464: #endif
465: }
466: });
467: auto h_metadata = Kokkos::create_mirror(Kokkos::HostSpace::memory_space(), d_metadata);
468: Kokkos::fence();
469: Kokkos::deep_copy (h_metadata, d_metadata);
470: #if PCBJKOKKOS_VERBOSE_LEVEL >= 3
471: #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
472: PetscPrintf(PETSC_COMM_WORLD,"Iterations\n");
473: #endif
474: // assume species major
475: #if PCBJKOKKOS_VERBOSE_LEVEL < 4
476: PetscPrintf(PETSC_COMM_WORLD,"max iterations per species (%s) :",ksp_type_idx==BATCH_KSP_BICG_IDX ? "bicg" : "tfqmr");
477: #endif
478: for (PetscInt dmIdx=0, s=0, head=0 ; dmIdx < jac->num_dms; dmIdx += batch_sz) {
479: for (PetscInt f=0, idx=head ; f < jac->dm_Nf[dmIdx] ; f++,s++,idx++) {
480: #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
481: PetscPrintf(PETSC_COMM_WORLD,"%2D:", s);
482: for (int bid=0 ; bid<batch_sz ; bid++) {
483: PetscPrintf(PETSC_COMM_WORLD,"%3D ", h_metadata[idx + bid*jac->dm_Nf[dmIdx]].its);
484: }
485: PetscPrintf(PETSC_COMM_WORLD,"\n");
486: #else
487: PetscInt count=0;
488: for (int bid=0 ; bid<batch_sz ; bid++) {
489: if (h_metadata[idx + bid*jac->dm_Nf[dmIdx]].its > count) count = h_metadata[idx + bid*jac->dm_Nf[dmIdx]].its;
490: }
491: PetscPrintf(PETSC_COMM_WORLD,"%3D ", count);
492: #endif
493: }
494: head += batch_sz*jac->dm_Nf[dmIdx];
495: }
496: #if PCBJKOKKOS_VERBOSE_LEVEL < 4
497: PetscPrintf(PETSC_COMM_WORLD,"\n");
498: #endif
499: #endif
500: PetscInt count=0, mbid=0;
501: for (int blkID=0;blkID<nBlk;blkID++) {
502: PetscLogGpuFlops((PetscLogDouble)h_metadata[blkID].flops);
503: if (jac->reason) {
504: if (jac->batch_target==blkID) {
505: PetscPrintf(PETSC_COMM_SELF, " Linear solve converged due to %s iterations %" PetscInt_FMT ", batch %" PetscInt_FMT ", species %" PetscInt_FMT "\n", KSPConvergedReasons[h_metadata[blkID].reason], h_metadata[blkID].its, blkID%batch_sz, blkID/batch_sz);
506: } else if (jac->batch_target==-1 && h_metadata[blkID].its > count) {
507: count = h_metadata[blkID].its;
508: mbid = blkID;
509: }
510: if (h_metadata[blkID].reason < 0) {
511: PetscCall(PetscPrintf(PETSC_COMM_SELF, "ERROR reason=%s, its=%" PetscInt_FMT ". species %" PetscInt_FMT ", batch %" PetscInt_FMT "\n",
512: KSPConvergedReasons[h_metadata[blkID].reason],h_metadata[blkID].its,blkID/batch_sz,blkID%batch_sz));
513: }
514: }
515: }
516: if (jac->batch_target==-1 && jac->reason) {
517: PetscPrintf(PETSC_COMM_SELF, " Linear solve converged due to %s iterations %" PetscInt_FMT ", batch %" PetscInt_FMT ", species %" PetscInt_FMT "\n", KSPConvergedReasons[h_metadata[mbid].reason], h_metadata[mbid].its,mbid%batch_sz,mbid/batch_sz);
518: }
519: VecRestoreArrayAndMemType(xout,&glb_xdata);
520: VecRestoreArrayReadAndMemType(bvec,&glb_bdata);
521: {
522: int errsum;
523: Kokkos::parallel_reduce(nBlk, KOKKOS_LAMBDA (const int idx, int& lsum) {
524: if (d_metadata[idx].reason < 0) ++lsum;
525: }, errsum);
526: pcreason = errsum ? PC_SUBPC_ERROR : PC_NOERROR;
527: }
528: PCSetFailedReason(pc,pcreason);
529: // map back to Plex space
530: if (plex_batch) {
531: VecCopy(xout, bvec);
532: VecScatterBegin(plex_batch,bvec,xout,INSERT_VALUES,SCATTER_REVERSE);
533: VecScatterEnd(plex_batch,bvec,xout,INSERT_VALUES,SCATTER_REVERSE);
534: }
535: VecDestroy(&bvec);
536: }
537: return 0;
538: }
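// Memory note for PCApply_BJKOKKOS: when every block has the same size
// (jac->const_block_size > 0) each team requests const_block_size*nwork scalars of scratch
// at level PCBJKOKKOS_SHARED_LEVEL; otherwise a single global buffer of n*nwork scalars is
// allocated and each block indexes it starting at its first equation.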
540: static PetscErrorCode PCSetUp_BJKOKKOS(PC pc)
541: {
542: PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS*)pc->data;
543: Mat A = pc->pmat;
544: Mat_SeqAIJKokkos *aijkok;
545: PetscBool flg;
549: PetscObjectTypeCompareAny((PetscObject)A,&flg,MATSEQAIJKOKKOS,MATMPIAIJKOKKOS,MATAIJKOKKOS,"");
551: if (!(aijkok = static_cast<Mat_SeqAIJKokkos*>(A->spptr))) {
552: SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_USER,"No aijkok");
553: } else {
554: if (!jac->vec_diag) {
555: Vec *subX;
556: DM pack,*subDM;
557: PetscInt nDMs, n;
558: PetscContainer container;
559: PetscObjectQuery((PetscObject) A, "plex_batch_is", (PetscObject *) &container);
560: { // Permute the matrix to get a block diagonal system: d_isrow_k, d_isicol_k
561: MatOrderingType rtype;
562: IS isrow,isicol;
563: const PetscInt *rowindices,*icolindices;
565: if (container) rtype = MATORDERINGNATURAL; // if we have a vecscatter then don't reorder here (all the reorder stuff goes away in future)
566: else rtype = MATORDERINGRCM;
567: // get permutation; it is not in the form needed here, so it is inverted below
568: MatGetOrdering(A,rtype,&isrow,&isicol);
569: ISDestroy(&isrow);
570: ISInvertPermutation(isicol,PETSC_DECIDE,&isrow);
571: ISGetIndices(isrow,&rowindices);
572: ISGetIndices(isicol,&icolindices);
573: const Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_isrow_k((PetscInt*)rowindices,A->rmap->n);
574: const Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_isicol_k ((PetscInt*)icolindices,A->rmap->n);
575: jac->d_isrow_k = new Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_isrow_k));
576: jac->d_isicol_k = new Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_isicol_k));
577: Kokkos::deep_copy (*jac->d_isrow_k, h_isrow_k);
578: Kokkos::deep_copy (*jac->d_isicol_k, h_isicol_k);
579: ISRestoreIndices(isrow,&rowindices);
580: ISRestoreIndices(isicol,&icolindices);
581: ISDestroy(&isrow);
582: ISDestroy(&isicol);
583: }
584: // get block sizes
585: PCGetDM(pc, &pack);
587: PetscObjectTypeCompare((PetscObject)pack,DMCOMPOSITE,&flg);
589: DMCompositeGetNumberDM(pack,&nDMs);
590: jac->num_dms = nDMs;
591: DMCreateGlobalVector(pack, &jac->vec_diag);
592: VecGetLocalSize(jac->vec_diag,&n);
593: jac->n = n;
594: jac->d_idiag_k = new Kokkos::View<PetscScalar*, Kokkos::LayoutRight>("idiag", n);
595: // options
596: PCBJKOKKOSCreateKSP_BJKOKKOS(pc);
597: KSPSetFromOptions(jac->ksp);
598: PetscObjectTypeCompareAny((PetscObject)jac->ksp,&flg,KSPBICG,"");
599: if (flg) {jac->ksp_type_idx = BATCH_KSP_BICG_IDX; jac->nwork = 7;}
600: else {
601: PetscObjectTypeCompareAny((PetscObject)jac->ksp,&flg,KSPTFQMR,"");
602: if (flg) {jac->ksp_type_idx = BATCH_KSP_TFQMR_IDX; jac->nwork = 10;}
603: else {
604: PetscObjectTypeCompareAny((PetscObject)jac->ksp,&flg,KSPGMRES,"");
605: if (flg) {jac->ksp_type_idx = BATCH_KSP_GMRES_IDX; jac->nwork = 0;}
606: SETERRQ(PetscObjectComm((PetscObject)jac->ksp),PETSC_ERR_ARG_WRONG,"unsupported type %s", ((PetscObject)jac->ksp)->type_name);
607: }
608: }
609: {
610: PetscViewer viewer;
611: PetscBool flg;
612: PetscViewerFormat format;
613: PetscOptionsGetViewer(PetscObjectComm((PetscObject)jac->ksp),((PetscObject)jac->ksp)->options,((PetscObject)jac->ksp)->prefix,"-ksp_converged_reason",&viewer,&format,&flg);
614: jac->reason = flg;
615: PetscViewerDestroy(&viewer);
616: PetscOptionsGetViewer(PetscObjectComm((PetscObject)jac->ksp),((PetscObject)jac->ksp)->options,((PetscObject)jac->ksp)->prefix,"-ksp_monitor",&viewer,&format,&flg);
617: jac->monitor = flg;
618: PetscViewerDestroy(&viewer);
619: PetscOptionsGetInt(((PetscObject)jac->ksp)->options,((PetscObject)jac->ksp)->prefix,"-ksp_batch_target",&jac->batch_target,&flg);
621: if (!jac->monitor && !flg) jac->batch_target = -1; // turn it off
622: }
623: // get blocks - jac->d_bid_eqOffset_k
624: PetscMalloc(sizeof(*subX)*nDMs, &subX);
625: PetscMalloc(sizeof(*subDM)*nDMs, &subDM);
626: PetscMalloc(sizeof(*jac->dm_Nf)*nDMs, &jac->dm_Nf);
627: PetscInfo(pc, "Have %" PetscInt_FMT " DMs, n=%" PetscInt_FMT " rtol=%g type = %s\n", nDMs, n, jac->ksp->rtol, ((PetscObject)jac->ksp)->type_name);
628: DMCompositeGetEntriesArray(pack,subDM);
629: jac->nBlocks = 0;
630: for (PetscInt ii=0;ii<nDMs;ii++) {
631: PetscSection section;
632: PetscInt Nf;
633: DM dm = subDM[ii];
634: DMGetLocalSection(dm, &section);
635: PetscSectionGetNumFields(section, &Nf);
636: jac->nBlocks += Nf;
637: #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
638: if (ii==0) PetscInfo(pc,"%" PetscInt_FMT ") %" PetscInt_FMT " blocks (%" PetscInt_FMT " total)\n",ii,Nf,jac->nBlocks);
639: #else
640: PetscInfo(pc,"%" PetscInt_FMT ") %" PetscInt_FMT " blocks (%" PetscInt_FMT " total)\n",ii,Nf,jac->nBlocks);
641: #endif
642: jac->dm_Nf[ii] = Nf;
643: }
644: { // d_bid_eqOffset_k
645: Kokkos::View<PetscInt*, Kokkos::LayoutRight, Kokkos::HostSpace> h_block_offsets("block_offsets", jac->nBlocks+1);
646: DMCompositeGetAccessArray(pack, jac->vec_diag, nDMs, NULL, subX);
647: h_block_offsets[0] = 0;
648: jac->const_block_size = -1;
649: for (PetscInt ii=0, idx = 0;ii<nDMs;ii++) {
650: PetscInt nloc,nblk;
651: VecGetSize(subX[ii],&nloc);
652: nblk = nloc/jac->dm_Nf[ii];
654: for (PetscInt jj=0;jj<jac->dm_Nf[ii];jj++, idx++) {
655: h_block_offsets[idx+1] = h_block_offsets[idx] + nblk;
656: #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
657: if (idx==0) PetscInfo(pc,"\t%" PetscInt_FMT ") Add block with %" PetscInt_FMT " equations of %" PetscInt_FMT "\n",idx+1,nblk,jac->nBlocks);
658: #else
659: PetscInfo(pc,"\t%" PetscInt_FMT ") Add block with %" PetscInt_FMT " equations of %" PetscInt_FMT "\n",idx+1,nblk,jac->nBlocks);
660: #endif
661: if (jac->const_block_size == -1) jac->const_block_size = nblk;
662: else if (jac->const_block_size > 0 && jac->const_block_size != nblk) jac->const_block_size = 0;
663: }
664: }
665: DMCompositeRestoreAccessArray(pack, jac->vec_diag, nDMs, NULL, subX);
666: PetscFree(subX);
667: PetscFree(subDM);
668: jac->d_bid_eqOffset_k = new Kokkos::View<PetscInt*, Kokkos::LayoutRight>(Kokkos::create_mirror(Kokkos::DefaultExecutionSpace::memory_space(),h_block_offsets));
669: Kokkos::deep_copy (*jac->d_bid_eqOffset_k, h_block_offsets);
670: }
671: }
672: { // get jac->d_idiag_k (PC setup),
673: const PetscInt *d_ai=aijkok->i_device_data(), *d_aj=aijkok->j_device_data();
674: const PetscScalar *d_aa = aijkok->a_device_data();
675: const PetscInt conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp==0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
676: PetscInt *d_bid_eqOffset = jac->d_bid_eqOffset_k->data(), *r = jac->d_isrow_k->data(), *ic = jac->d_isicol_k->data();
677: PetscScalar *d_idiag = jac->d_idiag_k->data();
678: Kokkos::parallel_for("Diag", Kokkos::TeamPolicy<>(jac->nBlocks, team_size, PCBJKOKKOS_VEC_SIZE), KOKKOS_LAMBDA (const team_member team) {
679: const PetscInt blkID = team.league_rank();
680: Kokkos::parallel_for
681: (Kokkos::TeamThreadRange(team,d_bid_eqOffset[blkID],d_bid_eqOffset[blkID+1]),
682: [=] (const int rowb) {
683: const PetscInt rowa = ic[rowb], ai = d_ai[rowa], *aj = d_aj + ai; // grab original data
684: const PetscScalar *aa = d_aa + ai;
685: const PetscInt nrow = d_ai[rowa + 1] - ai;
686: int found;
687: Kokkos::parallel_reduce
688: (Kokkos::ThreadVectorRange (team, nrow),
689: [=] (const int& j, int &count) {
690: const PetscInt colb = r[aj[j]];
691: if (colb==rowb) {
692: d_idiag[rowb] = 1./aa[j];
693: count++;
694: }}, found);
695: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
696: if (found!=1) Kokkos::single (Kokkos::PerThread (team), [=] () {printf("ERROR: row %d, found = %d\n",rowb,found);});
697: #endif
698: });
699: });
700: }
701: }
702: return 0;
703: }
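// Note on the "Diag" kernel in PCSetUp_BJKOKKOS above: for each row of each block it scans
// the row's columns for the diagonal entry in the permuted ordering and stores its
// reciprocal in d_idiag; the batched solvers then apply d_idiag as the Jacobi preconditioner.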
705: /* Reset; also serves as the default cleanup if the PC has never been set up */
706: static PetscErrorCode PCReset_BJKOKKOS(PC pc)
707: {
708: PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS*)pc->data;
710: KSPDestroy(&jac->ksp);
711: VecDestroy(&jac->vec_diag);
712: if (jac->d_bid_eqOffset_k) delete jac->d_bid_eqOffset_k;
713: if (jac->d_idiag_k) delete jac->d_idiag_k;
714: if (jac->d_isrow_k) delete jac->d_isrow_k;
715: if (jac->d_isicol_k) delete jac->d_isicol_k;
716: jac->d_bid_eqOffset_k = NULL;
717: jac->d_idiag_k = NULL;
718: jac->d_isrow_k = NULL;
719: jac->d_isicol_k = NULL;
720: PetscObjectComposeFunction((PetscObject)pc,"PCBJKOKKOSGetKSP_C",NULL); // not published now (causes configure errors)
721: PetscObjectComposeFunction((PetscObject)pc,"PCBJKOKKOSSetKSP_C",NULL);
722: PetscFree(jac->dm_Nf);
723: jac->dm_Nf = NULL;
724: return 0;
725: }
727: static PetscErrorCode PCDestroy_BJKOKKOS(PC pc)
728: {
729: PCReset_BJKOKKOS(pc);
730: PetscFree(pc->data);
731: return 0;
732: }
734: static PetscErrorCode PCView_BJKOKKOS(PC pc,PetscViewer viewer)
735: {
736: PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS*)pc->data;
737: PetscBool iascii;
739: if (!jac->ksp) PCBJKOKKOSCreateKSP_BJKOKKOS(pc);
740: PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&iascii);
741: if (iascii) {
742: PetscViewerASCIIPrintf(viewer," Batched device linear solver: Krylov (KSP) method with Jacobi preconditioning\n");
743: PetscCall(PetscViewerASCIIPrintf(viewer,"\t\tnwork = %" PetscInt_FMT ", rel tol = %e, abs tol = %e, div tol = %e, max it =%" PetscInt_FMT ", type = %s\n",jac->nwork,jac->ksp->rtol,
744: jac->ksp->abstol, jac->ksp->divtol, jac->ksp->max_it,
745: ((PetscObject)jac->ksp)->type_name));
746: }
747: return 0;
748: }
750: static PetscErrorCode PCSetFromOptions_BJKOKKOS(PetscOptionItems *PetscOptionsObject,PC pc)
751: {
752: PetscOptionsHead(PetscOptionsObject,"PC BJKOKKOS options");
753: PetscOptionsTail();
754: return 0;
755: }
757: static PetscErrorCode PCBJKOKKOSSetKSP_BJKOKKOS(PC pc,KSP ksp)
758: {
759: PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS*)pc->data;
761: PetscObjectReference((PetscObject)ksp);
762: KSPDestroy(&jac->ksp);
763: jac->ksp = ksp;
764: return 0;
765: }
767: /*@C
768: PCBJKOKKOSSetKSP - Sets the KSP context for a PCBJKOKKOS preconditioner.
770: Collective on PC
772: Input Parameters:
773: + pc - the preconditioner context
774: - ksp - the KSP solver
776: Notes:
777: The PC and the KSP must have the same communicator
779: Level: advanced
781: @*/
782: PetscErrorCode PCBJKOKKOSSetKSP(PC pc,KSP ksp)
783: {
787: PetscTryMethod(pc,"PCBJKOKKOSSetKSP_C",(PC,KSP),(pc,ksp));
788: return 0;
789: }
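/* Example (a sketch, not taken from the PETSc test suite): supply a hand-configured inner
   KSP instead of relying on the -pc_bjkokkos_ options prefix.

     KSP inner;
     KSPCreate(PetscObjectComm((PetscObject)pc),&inner);
     KSPSetType(inner,KSPTFQMR);
     KSPSetTolerances(inner,1.e-8,PETSC_DEFAULT,PETSC_DEFAULT,50);
     PCBJKOKKOSSetKSP(pc,inner);
     KSPDestroy(&inner); // the PC took its own reference above
*/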
791: static PetscErrorCode PCBJKOKKOSGetKSP_BJKOKKOS(PC pc,KSP *ksp)
792: {
793: PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS*)pc->data;
795: if (!jac->ksp) PCBJKOKKOSCreateKSP_BJKOKKOS(pc);
796: *ksp = jac->ksp;
797: return 0;
798: }
800: /*@C
801: PCBJKOKKOSGetKSP - Gets the KSP context for a PCBJKOKKOS preconditioner.
803: Not Collective, but the returned KSP is parallel if the PC is parallel
805: Input Parameter:
806: . pc - the preconditioner context
808: Output Parameters:
809: . ksp - the KSP solver
811: Notes:
812: You must call KSPSetUp() before calling PCBJKOKKOSGetKSP().
814: If the PC is not a PCBJKOKKOS object it raises an error
816: Level: advanced
818: @*/
819: PetscErrorCode PCBJKOKKOSGetKSP(PC pc,KSP *ksp)
820: {
823: PetscUseMethod(pc,"PCBJKOKKOSGetKSP_C",(PC,KSP*),(pc,ksp));
824: return 0;
825: }
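/* Example (a sketch): retrieve the inner KSP to inspect or adjust it once the PC exists.

     KSP inner;
     PCBJKOKKOSGetKSP(pc,&inner);
     KSPSetTolerances(inner,1.e-6,PETSC_DEFAULT,PETSC_DEFAULT,PETSC_DEFAULT);
*/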
827: /* ----------------------------------------------------------------------------------*/
829: /*MC
830: PCBJKOKKOS - Defines a preconditioner that applies a batched Krylov solver with Jacobi preconditioning to each block of a SeqAIJ (Kokkos) matrix on the GPU.
832: Options Database Key:
833: . -pc_bjkokkos_ - options prefix for the inner KSP options (for example, -pc_bjkokkos_ksp_type)
835: Level: intermediate
837: Notes:
838: For use with -ksp_type preonly to bypass any CPU work
840: Developer Notes:
842: .seealso: PCCreate(), PCSetType(), PCType (for list of available types), PC,
843: PCSHELL, PCCOMPOSITE, PCSetUseAmat(), PCBJKOKKOSGetKSP()
845: M*/
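/* Example command-line usage (a sketch; option values are illustrative, not from a PETSc test):

     -ksp_type preonly -pc_type bjkokkos \
     -pc_bjkokkos_ksp_type tfqmr -pc_bjkokkos_ksp_rtol 1e-8 \
     -pc_bjkokkos_ksp_converged_reason -pc_bjkokkos_ksp_monitor
*/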
847: PETSC_EXTERN PetscErrorCode PCCreate_BJKOKKOS(PC pc)
848: {
849: PC_PCBJKOKKOS *jac;
851: PetscNewLog(pc,&jac);
852: pc->data = (void*)jac;
854: jac->ksp = NULL;
855: jac->vec_diag = NULL;
856: jac->d_bid_eqOffset_k = NULL;
857: jac->d_idiag_k = NULL;
858: jac->d_isrow_k = NULL;
859: jac->d_isicol_k = NULL;
860: jac->nBlocks = 1;
862: PetscMemzero(pc->ops,sizeof(struct _PCOps));
863: pc->ops->apply = PCApply_BJKOKKOS;
864: pc->ops->applytranspose = NULL;
865: pc->ops->setup = PCSetUp_BJKOKKOS;
866: pc->ops->reset = PCReset_BJKOKKOS;
867: pc->ops->destroy = PCDestroy_BJKOKKOS;
868: pc->ops->setfromoptions = PCSetFromOptions_BJKOKKOS;
869: pc->ops->view = PCView_BJKOKKOS;
871: PetscObjectComposeFunction((PetscObject)pc,"PCBJKOKKOSGetKSP_C",PCBJKOKKOSGetKSP_BJKOKKOS);
872: PetscObjectComposeFunction((PetscObject)pc,"PCBJKOKKOSSetKSP_C",PCBJKOKKOSSetKSP_BJKOKKOS);
873: return 0;
874: }