Actual source code: mpimattransposematmult.c
/*
  Defines the matrix-matrix product routines
          C = A^T * B
  for an MPIAIJ matrix A and an MPIDense matrix B.
  The routines are slightly modified from MatTransposeMatMultxxx_SeqAIJ_SeqDense().
*/
7: #include <../src/mat/impls/aij/seq/aij.h>
8: #include <../src/mat/impls/aij/mpi/mpiaij.h>
9: #include <../src/mat/impls/dense/mpi/mpidense.h>
11: PetscErrorCode MatDestroy_MPIDense_MatTransMatMult(void *data)
12: {
13: Mat_MatTransMatMult *atb = (Mat_MatTransMatMult*)data;
15: MatDestroy(&atb->mA);
16: VecDestroy(&atb->bt);
17: VecDestroy(&atb->ct);
18: PetscFree(atb);
19: return 0;
20: }
22: static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat,Mat,Mat);
24: PETSC_INTERN PetscErrorCode MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A,Mat B,PetscReal fill,Mat C)
25: {
26: Mat_MatTransMatMult *atb;
27: PetscBool cisdense;
29: MatCheckProduct(C,4);
32: /* create output dense matrix C = A^T*B */
33: MatSetSizes(C,A->cmap->n,B->cmap->n,A->cmap->N,B->cmap->N);
34: PetscObjectTypeCompareAny((PetscObject)C,&cisdense,MATMPIDENSE,MATMPIDENSECUDA,"");
35: if (!cisdense) {
36: MatSetType(C,((PetscObject)B)->type_name);
37: }
38: MatSetUp(C);
40: /* create additional data structure for the product */
41: PetscNew(&atb);
42: if (B->cmap->N) {
43: MatCreateMAIJ(A,B->cmap->N,&atb->mA);
44: if (!atb->mA->assembled) {
45: MatAssemblyBegin(atb->mA,MAT_FINAL_ASSEMBLY);
46: MatAssemblyEnd(atb->mA,MAT_FINAL_ASSEMBLY);
47: }
48: MatCreateVecs(atb->mA,&atb->ct,&atb->bt);
49: }
50: C->product->data = atb;
51: C->product->destroy = MatDestroy_MPIDense_MatTransMatMult;
53: C->ops->transposematmultnumeric = MatTransposeMatMultNumeric_MPIAIJ_MPIDense;
54: return 0;
55: }
57: static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A,Mat B,Mat C)
58: {
59: const PetscScalar *Barray,*ctarray;
60: PetscScalar *Carray,*btarray;
61: PetscInt i,j,m=A->rmap->n,n=A->cmap->n,ldb,BN=B->cmap->N,ldc;
62: Mat_MatTransMatMult *atb;
63: Vec bt,ct;
65: MatCheckProduct(C,3);
66: atb = (Mat_MatTransMatMult *)C->product->data;
68: if (!BN) {
69: MatAssemblyBegin(C,MAT_FINAL_ASSEMBLY);
70: MatAssemblyEnd(C,MAT_FINAL_ASSEMBLY);
71: return 0;
72: }
73: bt = atb->bt;
74: ct = atb->ct;
76: /* transpose local array of B, then copy it to vector bt */
77: MatDenseGetArrayRead(B,&Barray);
78: MatDenseGetLDA(B,&ldb);
79: VecGetArray(bt,&btarray);
80: for (j=0; j<BN; j++)
81: for (i=0; i<m; i++)
82: btarray[i*BN + j] = Barray[j*ldb + i];
83: VecRestoreArray(bt,&btarray);
84: MatDenseRestoreArrayRead(B,&Barray);
86: /* compute ct = mA^T * cb */
87: MatMultTranspose(atb->mA,bt,ct);
89: /* transpose local array of ct to matrix C */
90: MatDenseGetArray(C,&Carray);
91: MatDenseGetLDA(C,&ldc);
92: VecGetArrayRead(ct,&ctarray);
93: for (j=0; j<BN; j++)
94: for (i=0; i<n; i++)
95: Carray[j*ldc + i] = ctarray[i*BN + j];
96: VecRestoreArrayRead(ct,&ctarray);
97: MatDenseRestoreArray(C,&Carray);
98: MatAssemblyBegin(C,MAT_FINAL_ASSEMBLY);
99: MatAssemblyEnd(C,MAT_FINAL_ASSEMBLY);
100: return 0;
101: }