Skip to content

Commit

Permalink
added template function trmm for all relevant types
Browse files Browse the repository at this point in the history
  • Loading branch information
ActiveAnalytics committed Apr 6, 2017
1 parent e3aa4cf commit 48ee11a
Show file tree
Hide file tree
Showing 3 changed files with 315 additions and 24 deletions.
37 changes: 37 additions & 0 deletions source/dblas/examples.d
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,43 @@ import std.complex: Complex, complex;
** to make sure that they are running properly.
*/


/* C function for trmm */
extern (C){
void cblas_dtrmm (in CBLAS_ORDER Order, in CBLAS_SIDE Side, in CBLAS_UPLO Uplo, in CBLAS_TRANSPOSE TransA,
in CBLAS_DIAG Diag, in int M, in int N, in double alpha, in double* A, in int lda, double* B, in int ldb);
}

/* Testing for trmm */
void test_trmm(){

CBLAS_ORDER order = CblasRowMajor;
CBLAS_SIDE side = CblasLeft;
CBLAS_UPLO uplo = CblasUpper;
CBLAS_TRANSPOSE trans = CblasNoTrans;
CBLAS_DIAG diag = CblasNonUnit;

int m = 2;
int n = 3;
double alpha = -0.3;
double[] a = [0.174, -0.308, 0.997, -0.484];
double[] b = [-0.256, -0.178, 0.098, 0.004, 0.97, -0.408];
int lda = 2, ldb = 3;
//double B_expected[] = { 0.0137328, 0.0989196, -0.0428148, 5.808e-04, 0.140844, -0.0592416 };
trmm(order, side, uplo, trans, diag, m, n, alpha, a.ptr, lda, b.ptr, ldb);
writeln("trmm: ", b);

Complex!double alphac = complex(0, 0);
Complex!double[] ac = [complex(0.463, 0.033), complex(-0.929, 0.949), complex(0.864, 0.986), complex(0.393, 0.885)];
Complex!double[] bc = [complex(-0.321, -0.852), complex(-0.337, -0.175), complex(0.607, -0.613),
complex(0.688, 0.973), complex(-0.331, -0.35), complex(0.719, -0.553)];
//Complex!double[] B_expected = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
trmm(order, side, uplo, trans, diag, m, n, alphac, ac.ptr, lda, bc.ptr, ldb);
writeln("trmm (Complex!double): ", bc);
}



/* C function for syr2k */
extern (C){
void cblas_dsyr2k (in CBLAS_ORDER Order, in CBLAS_UPLO Uplo, in CBLAS_TRANSPOSE Trans, in int N, in int K,
Expand Down
256 changes: 256 additions & 0 deletions source/dblas/l3.d
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,262 @@ void syr2k(N, X)(in CBLAS_ORDER order, in CBLAS_UPLO uplo, in CBLAS_TRANSPOSE tr



/**
* @title trmm Computes a matrix-matrix product where one input matrix is triangular.
*
* @description The trmm routines compute a scalar-matrix-matrix product with one triangular matrix. The operation is
* defined as:
* B := alpha*op(A)*B,
* or
* B := alpha*B*op(A),
* where
*
* alpha is a scalar,
* B is an m-by-n matrix,
* A is a unit, or non-unit, upper or lower triangular matrix
* op(A) is one of op(A) = A, or op(A) = A' , or op(A) = conjg(A') .
*
* Input Parameters:
*
* @param order: Specifies whether two-dimensional array storage is row-major
* (CblasRowMajor) or column-major (CblasColMajor).
*
* @param side: Specifies whether op(A) appears on the left or right of B in the operation:
* if side = CblasLeft , then B := alpha*op(A)*B ;
* if side = CblasRight , then B := alpha*B*op(A) .
* @param uplo: Specifies whether the matrix A is upper or lower triangular:
* uplo = CblasUpper
* if uplo = CblasLower , then the matrix is low triangular.
* @param transa: Specifies the form of op(A) used in the matrix multiplication:
* if transa = CblasNoTrans , then op(A) = A ;
* if transa = CblasTrans , then op(A) = A' ;
* if transa = CblasConjTrans , then op(A) = conjg(A') .
* @param diag: Specifies whether the matrix A is unit triangular:
* if diag = CblasUnit then the matrix is unit triangular;
* if diag = CblasNonUnit , then the matrix is not unit triangular.
* @param m: Specifies the number of rows of B. The value of m must be at least zero.
* @param n: Specifies the number of columns of B. The value of n must be at least zero.
* @param alpha: Specifies the scalar alpha.
* When alpha is zero, then a is not referenced and b need not be set before
* entry.
* @param a: Array, size lda by k, where k is m when side = CblasLeft and is n when
* side = CblasRight . Before entry with uplo = CblasUpper , the leading k
* by k upper triangular part of the array a must contain the upper triangular
* matrix and the strictly lower triangular part of a is not referenced.
* Before entry with uplo = CblasLower , the leading k by k lower triangular
* part of the array a must contain the lower triangular matrix and the strictly
* upper triangular part of a is not referenced.
* When diag = CblasUnit , the diagonal elements of a are not referenced
* either, but are assumed to be unity.
* @param lda: Specifies the leading dimension of a as declared in the calling
* (sub)program. When side = CblasLeft , then lda must be at least max(1,
* m) , when side = CblasRight , then lda must be at least max(1, n) .
* @param b: For Layout = CblasColMajor : array, size ldb*n . Before entry, the leading
* m-by-n part of the array b must contain the matrix B.
* For Layout = CblasRowMajor : array, size ldb*m . Before entry, the leading
* n-by-m part of the array b must contain the matrix B.
* @param ldb: Specifies the leading dimension of b as declared in the calling
* (sub)program. When Layout = CblasColMajor , ldb must be at least
* max(1, m); otherwise, ldb must be at least max(1, n).
*
* Output Parameters:
*
* @param b: Overwritten by the transformed matrix.
*
*/
void trmm(N, X)(in CBLAS_ORDER order, in CBLAS_SIDE side, in CBLAS_UPLO uplo, in CBLAS_TRANSPOSE transA,
in CBLAS_DIAG diag, in N m, in N n, in X alpha, in X* a, in N lda, X* b, in N ldb)
{
N i, j, k;
N n1, n2;
X zero = X(0), one = X(1);

const N nonunit = (diag == CblasNonUnit);
N side_, uplo_, trans;

if (order == CblasRowMajor) {
n1 = m;
n2 = n;
side_ = side;
uplo_ = uplo;
trans = (transA == CblasConjTrans) ? CblasTrans : transA;
} else {
n1 = n;
n2 = m;
side_ = (side == CblasLeft) ? CblasRight : CblasLeft;
uplo_ = (uplo == CblasUpper) ? CblasLower : CblasUpper;
trans = (transA == CblasConjTrans) ? CblasTrans : transA;
}

if (side_ == CblasLeft && uplo_ == CblasUpper && trans == CblasNoTrans) {
/* form B := alpha * TriU(A)*B */
for (i = 0; i < n1; i++) {
for (j = 0; j < n2; j++) {
X temp = zero;

if (nonunit) {
temp = a[i * lda + i] * b[i * ldb + j];
} else {
temp = b[i * ldb + j];
}

for (k = i + 1; k < n1; k++) {
temp += a[lda * i + k] * b[k * ldb + j];
}

b[ldb * i + j] = alpha * temp;
}
}
} else if (side_ == CblasLeft && uplo_ == CblasUpper && trans == CblasTrans) {
/* form B := alpha * (TriU(A))' *B */
for (i = n1; i > 0 && i--;) {
for (j = 0; j < n2; j++) {
X temp = zero;

for (k = 0; k < i; k++) {
temp += a[lda * k + i] * b[k * ldb + j];
}

if (nonunit) {
temp += a[i * lda + i] * b[i * ldb + j];
} else {
temp += b[i * ldb + j];
}

b[ldb * i + j] = alpha * temp;
}
}
} else if (side_ == CblasLeft && uplo_ == CblasLower && trans == CblasNoTrans) {

/* form B := alpha * TriL(A)*B */

for (i = n1; i > 0 && i--;) {
for (j = 0; j < n2; j++) {
X temp = zero;

for (k = 0; k < i; k++) {
temp += a[lda * i + k] * b[k * ldb + j];
}

if (nonunit) {
temp += a[i * lda + i] * b[i * ldb + j];
} else {
temp += b[i * ldb + j];
}

b[ldb * i + j] = alpha * temp;
}
}

} else if (side_ == CblasLeft && uplo_ == CblasLower && trans == CblasTrans) {
/* form B := alpha * TriL(A)' *B */
for (i = 0; i < n1; i++) {
for (j = 0; j < n2; j++) {
X temp = zero;

if (nonunit) {
temp = a[i * lda + i] * b[i * ldb + j];
} else {
temp = b[i * ldb + j];
}

for (k = i + 1; k < n1; k++) {
temp += a[lda * k + i] * b[k * ldb + j];
}

b[ldb * i + j] = alpha * temp;
}
}
} else if (side_ == CblasRight && uplo_ == CblasUpper && trans == CblasNoTrans) {

/* form B := alpha * B * TriU(A) */

for (i = 0; i < n1; i++) {
for (j = n2; j > 0 && j--;) {
X temp = zero;

for (k = 0; k < j; k++) {
temp += a[lda * k + j] * b[i * ldb + k];
}

if (nonunit) {
temp += a[j * lda + j] * b[i * ldb + j];
} else {
temp += b[i * ldb + j];
}

b[ldb * i + j] = alpha * temp;
}
}

} else if (side_ == CblasRight && uplo_ == CblasUpper && trans == CblasTrans) {

/* form B := alpha * B * (TriU(A))' */

for (i = 0; i < n1; i++) {
for (j = 0; j < n2; j++) {
X temp = zero;

if (nonunit) {
temp = a[j * lda + j] * b[i * ldb + j];
} else {
temp = b[i * ldb + j];
}

for (k = j + 1; k < n2; k++) {
temp += a[lda * j + k] * b[i * ldb + k];
}

b[ldb * i + j] = alpha * temp;
}
}

} else if (side_ == CblasRight && uplo_ == CblasLower && trans == CblasNoTrans) {
/* form B := alpha *B * TriL(A) */
for (i = 0; i < n1; i++) {
for (j = 0; j < n2; j++) {
X temp = zero;

if (nonunit) {
temp = a[j * lda + j] * b[i * ldb + j];
} else {
temp = b[i * ldb + j];
}

for (k = j + 1; k < n2; k++) {
temp += a[lda * k + j] * b[i * ldb + k];
}

b[ldb * i + j] = alpha * temp;
}
}
} else if (side_ == CblasRight && uplo_ == CblasLower && trans == CblasTrans) {
/* form B := alpha * B * TriL(A)' */
for (i = 0; i < n1; i++) {
for (j = n2; j > 0 && j--;) {
X temp = zero;

for (k = 0; k < j; k++) {
temp += a[lda * j + k] * b[i * ldb + k];
}

if (nonunit) {
temp += a[j * lda + j] * b[i * ldb + j];
} else {
temp += b[i * ldb + j];
}

b[ldb * i + j] = alpha * temp;
}
}
} else {
assert("unrecognized operation");
}
}







46 changes: 22 additions & 24 deletions source/dblas/package.d
Original file line number Diff line number Diff line change
Expand Up @@ -17,39 +17,37 @@ import std.complex: Complex, complex;
/* To compile: */
/* dub build dblas # or dub run ... */

/* C function for syr2k */
/* C function for trmm */
extern (C){
void cblas_dsyr2k (in CBLAS_ORDER Order, in CBLAS_UPLO Uplo, in CBLAS_TRANSPOSE Trans, in int N, in int K,
in double alpha, in double *A, in int lda, in double *B, in int ldb, in double beta, double *C, in int ldc);
void cblas_dtrmm (in CBLAS_ORDER Order, in CBLAS_SIDE Side, in CBLAS_UPLO Uplo, in CBLAS_TRANSPOSE TransA,
in CBLAS_DIAG Diag, in int M, in int N, in double alpha, in double* A, in int lda, double* B, in int ldb);
}

/* Testing for syr2k */
/* Testing for trmm */
void main(){

CBLAS_ORDER order = CblasRowMajor;
CBLAS_SIDE side = CblasLeft;
CBLAS_UPLO uplo = CblasUpper;
CBLAS_TRANSPOSE trans = CblasNoTrans;
int n = 1;
int k = 2;
double alpha = 0.1;
double beta = 0;
double[] a = [-0.225, 0.857];
double[] b = [-0.933, 0.994];
double[] c = [0.177];
int lda = 2, ldb = 2, ldc = 1;

//double C_expected[] = { 0.2123566 };
syr2k(order, uplo, trans, n, k, alpha, a.ptr, lda, b.ptr, ldb, beta, c.ptr, ldc);
writeln("syr2k: ", c);
CBLAS_DIAG diag = CblasNonUnit;

int m = 2;
int n = 3;
double alpha = -0.3;
double[] a = [0.174, -0.308, 0.997, -0.484];
double[] b = [-0.256, -0.178, 0.098, 0.004, 0.97, -0.408];
int lda = 2, ldb = 3;
//double B_expected[] = { 0.0137328, 0.0989196, -0.0428148, 5.808e-04, 0.140844, -0.0592416 };
trmm(order, side, uplo, trans, diag, m, n, alpha, a.ptr, lda, b.ptr, ldb);
writeln("trmm: ", b);

Complex!double alphac = complex(0, 0);
Complex!double betac = complex(-0.3, 0.1);
Complex!double[] ac = [complex(-0.315, 0.03), complex(0.281, 0.175)];
Complex!double[] bc = [complex(-0.832, -0.964), complex(0.291, 0.476)];
Complex!double[] cc = [complex(-0.341, 0.743)];
//double C_expected[] = { 0.028, -0.257 };

syr2k(order, uplo, trans, n, k, alphac, ac.ptr, lda, bc.ptr, ldb, betac, cc.ptr, ldc);
writeln("syr2k (Complex!double): ", cc);
Complex!double[] ac = [complex(0.463, 0.033), complex(-0.929, 0.949), complex(0.864, 0.986), complex(0.393, 0.885)];
Complex!double[] bc = [complex(-0.321, -0.852), complex(-0.337, -0.175), complex(0.607, -0.613),
complex(0.688, 0.973), complex(-0.331, -0.35), complex(0.719, -0.553)];
//Complex!double[] B_expected = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
trmm(order, side, uplo, trans, diag, m, n, alphac, ac.ptr, lda, bc.ptr, ldb);
writeln("trmm (Complex!double): ", bc);
}

0 comments on commit 48ee11a

Please sign in to comment.