Actual source code: baijsolvtrann.c

  1: #include <../src/mat/impls/baij/seq/baij.h>
  2: #include <petsc/private/kernels/blockinvert.h>

  4: /* ----------------------------------------------------------- */
  5: PetscErrorCode MatSolveTranspose_SeqBAIJ_N_inplace(Mat A,Vec bb,Vec xx)
  6: {
  7:   Mat_SeqBAIJ       *a   =(Mat_SeqBAIJ*)A->data;
  8:   IS                iscol=a->col,isrow=a->row;
  9:   const PetscInt    *r,*c,*rout,*cout,*ai=a->i,*aj=a->j,*vi;
 10:   PetscInt          i,nz,j;
 11:   const PetscInt    n  =a->mbs,bs=A->rmap->bs,bs2=a->bs2;
 12:   const MatScalar   *aa=a->a,*v;
 13:   PetscScalar       *x,*t,*ls;
 14:   const PetscScalar *b;

 16:   VecGetArrayRead(bb,&b);
 17:   VecGetArray(xx,&x);
 18:   t    = a->solve_work;

 20:   ISGetIndices(isrow,&rout); r = rout;
 21:   ISGetIndices(iscol,&cout); c = cout;

 23:   /* copy the b into temp work space according to permutation */
 24:   for (i=0; i<n; i++) {
 25:     for (j=0; j<bs; j++) {
 26:       t[i*bs+j] = b[c[i]*bs+j];
 27:     }
 28:   }

 30:   /* forward solve the upper triangular transpose */
 31:   ls = a->solve_work + A->cmap->n;
 32:   for (i=0; i<n; i++) {
 33:     PetscArraycpy(ls,t+i*bs,bs);
 34:     PetscKernel_w_gets_transA_times_v(bs,ls,aa+bs2*a->diag[i],t+i*bs);
 35:     v  = aa + bs2*(a->diag[i] + 1);
 36:     vi = aj + a->diag[i] + 1;
 37:     nz = ai[i+1] - a->diag[i] - 1;
 38:     while (nz--) {
 39:       PetscKernel_v_gets_v_minus_transA_times_w(bs,t+bs*(*vi++),v,t+i*bs);
 40:       v += bs2;
 41:     }
 42:   }

 44:   /* backward solve the lower triangular transpose */
 45:   for (i=n-1; i>=0; i--) {
 46:     v  = aa + bs2*ai[i];
 47:     vi = aj + ai[i];
 48:     nz = a->diag[i] - ai[i];
 49:     while (nz--) {
 50:       PetscKernel_v_gets_v_minus_transA_times_w(bs,t+bs*(*vi++),v,t+i*bs);
 51:       v += bs2;
 52:     }
 53:   }

 55:   /* copy t into x according to permutation */
 56:   for (i=0; i<n; i++) {
 57:     for (j=0; j<bs; j++) {
 58:       x[bs*r[i]+j]   = t[bs*i+j];
 59:     }
 60:   }

 62:   ISRestoreIndices(isrow,&rout);
 63:   ISRestoreIndices(iscol,&cout);
 64:   VecRestoreArrayRead(bb,&b);
 65:   VecRestoreArray(xx,&x);
 66:   PetscLogFlops(2.0*(a->bs2)*(a->nz) - A->rmap->bs*A->cmap->n);
 67:   return 0;
 68: }

 70: PetscErrorCode MatSolveTranspose_SeqBAIJ_N(Mat A,Vec bb,Vec xx)
 71: {
 72:   Mat_SeqBAIJ       *a   =(Mat_SeqBAIJ*)A->data;
 73:   IS                iscol=a->col,isrow=a->row;
 74:   const PetscInt    *r,*c,*rout,*cout;
 75:   const PetscInt    n=a->mbs,*ai=a->i,*aj=a->j,*vi,*diag=a->diag;
 76:   PetscInt          i,j,nz;
 77:   const PetscInt    bs =A->rmap->bs,bs2=a->bs2;
 78:   const MatScalar   *aa=a->a,*v;
 79:   PetscScalar       *x,*t,*ls;
 80:   const PetscScalar *b;

 82:   VecGetArrayRead(bb,&b);
 83:   VecGetArray(xx,&x);
 84:   t    = a->solve_work;

 86:   ISGetIndices(isrow,&rout); r = rout;
 87:   ISGetIndices(iscol,&cout); c = cout;

 89:   /* copy the b into temp work space according to permutation */
 90:   for (i=0; i<n; i++) {
 91:     for (j=0; j<bs; j++) {
 92:       t[i*bs+j] = b[c[i]*bs+j];
 93:     }
 94:   }

 96:   /* forward solve the upper triangular transpose */
 97:   ls = a->solve_work + A->cmap->n;
 98:   for (i=0; i<n; i++) {
 99:     PetscArraycpy(ls,t+i*bs,bs);
100:     PetscKernel_w_gets_transA_times_v(bs,ls,aa+bs2*diag[i],t+i*bs);
101:     v  = aa + bs2*(diag[i] - 1);
102:     vi = aj + diag[i] - 1;
103:     nz = diag[i] - diag[i+1] - 1;
104:     for (j=0; j>-nz; j--) {
105:       PetscKernel_v_gets_v_minus_transA_times_w(bs,t+bs*(vi[j]),v,t+i*bs);
106:       v -= bs2;
107:     }
108:   }

110:   /* backward solve the lower triangular transpose */
111:   for (i=n-1; i>=0; i--) {
112:     v  = aa + bs2*ai[i];
113:     vi = aj + ai[i];
114:     nz = ai[i+1] - ai[i];
115:     for (j=0; j<nz; j++) {
116:       PetscKernel_v_gets_v_minus_transA_times_w(bs,t+bs*(vi[j]),v,t+i*bs);
117:       v += bs2;
118:     }
119:   }

121:   /* copy t into x according to permutation */
122:   for (i=0; i<n; i++) {
123:     for (j=0; j<bs; j++) {
124:       x[bs*r[i]+j]   = t[bs*i+j];
125:     }
126:   }

128:   ISRestoreIndices(isrow,&rout);
129:   ISRestoreIndices(iscol,&cout);
130:   VecRestoreArrayRead(bb,&b);
131:   VecRestoreArray(xx,&x);
132:   PetscLogFlops(2.0*(a->bs2)*(a->nz) - A->rmap->bs*A->cmap->n);
133:   return 0;
134: }