Skip to content

Commit

Permalink
Adding implementations of sparse-direct LeastSquares/Ridge/Tikhonov f…
Browse files Browse the repository at this point in the history
…or the case where width(A) > height(A).
  • Loading branch information
Jack Poulson committed Nov 8, 2014
1 parent 2cd9eb3 commit 5dc20f1
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 28 deletions.
55 changes: 41 additions & 14 deletions src/sparse_direct/numeric/LeastSquares.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,32 +24,59 @@ void LeastSquares
if( orientation != NORMAL && A.Width() != Y.Height() )
LogicError("Width of A and height of Y must match");
)
if( A.Width() > A.Height() )
LogicError("LeastSquares currently assumes height(A) >= width(A)");
const Int m = A.Height();
const Int n = A.Width();
DistSparseMatrix<F> C(A.Comm());
X.SetComm( Y.Comm() );
if( orientation == NORMAL )
{
const Int n = A.Width();
Herk( LOWER, ADJOINT, Base<F>(1), A, C );
MakeHermitian( LOWER, C );
X.SetComm( Y.Comm() );
Zeros( X, n, Y.Width() );
Multiply( ADJOINT, F(1), A, Y, F(0), X );
if( m >= n )
{
Herk( LOWER, ADJOINT, Base<F>(1), A, C );
MakeHermitian( LOWER, C );

Multiply( ADJOINT, F(1), A, Y, F(0), X );
HermitianSolve( C, X, ctrl );
}
else
{
Herk( LOWER, NORMAL, Base<F>(1), A, C );
MakeHermitian( LOWER, C );

DistMultiVec<F> YCopy(Y.Comm());
YCopy = Y;
HermitianSolve( C, YCopy, ctrl );
Multiply( ADJOINT, F(1), A, YCopy, F(0), X );
}
}
else if( orientation == ADJOINT || !IsComplex<F>::val )
{
const Int n = A.Height();
Herk( LOWER, NORMAL, Base<F>(1), A, C );
MakeHermitian( LOWER, C );
X.SetComm( Y.Comm() );
Zeros( X, n, Y.Width() );
Multiply( NORMAL, F(1), A, Y, F(0), X );
Zeros( X, m, Y.Width() );
if( m >= n )
{
Herk( LOWER, NORMAL, Base<F>(1), A, C );
MakeHermitian( LOWER, C );

Multiply( NORMAL, F(1), A, Y, F(0), X );
HermitianSolve( C, X, ctrl );
}
else
{
Herk( LOWER, ADJOINT, Base<F>(1), A, C );
MakeHermitian( LOWER, C );

DistMultiVec<F> YCopy(Y.Comm());
YCopy = Y;
HermitianSolve( C, YCopy, ctrl );
Multiply( NORMAL, F(1), A, YCopy, F(0), X );
}
}
else
{
LogicError("Complex transposed option not yet supported");
}
HermitianSolve( C, X, ctrl );

}

#define PROTO(F) \
Expand Down
29 changes: 22 additions & 7 deletions src/sparse_direct/numeric/Ridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,32 @@ void Ridge
if( A.Height() != Y.Height() )
LogicError("Heights of A and Y must match");
)
if( A.Width() > A.Height() )
LogicError("Ridge currently assumes height(A) >= width(A)");
const Int m = A.Height();
const Int n = A.Width();
DistSparseMatrix<F> C(A.Comm());
Herk( LOWER, ADJOINT, Base<F>(1), A, C );
UpdateDiagonal( C, F(alpha*alpha) );
MakeHermitian( LOWER, C );

X.SetComm( Y.Comm() );
Zeros( X, n, Y.Width() );
Multiply( ADJOINT, F(1), A, Y, F(0), X );
HermitianSolve( C, X, ctrl );
if( m >= n )
{
Herk( LOWER, ADJOINT, Base<F>(1), A, C );
UpdateDiagonal( C, F(alpha*alpha) );
MakeHermitian( LOWER, C );

Multiply( ADJOINT, F(1), A, Y, F(0), X );
HermitianSolve( C, X, ctrl );
}
else
{
Herk( LOWER, NORMAL, Base<F>(1), A, C );
UpdateDiagonal( C, F(alpha*alpha) );
MakeHermitian( LOWER, C );

DistMultiVec<F> YCopy(Y.Comm());
YCopy = Y;
HermitianSolve( C, YCopy, ctrl );
Multiply( ADJOINT, F(1), A, YCopy, F(0), X );
}
}

#define PROTO(F) \
Expand Down
29 changes: 22 additions & 7 deletions src/sparse_direct/numeric/Tikhonov.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,32 @@ void Tikhonov
if( A.Height() != Y.Height() )
LogicError("Heights of A and Y must match");
)
if( A.Width() > A.Height() )
LogicError("Tikhonov currently assumes height(A) >= width(A)");
const Int m = A.Height();
const Int n = A.Width();
DistSparseMatrix<F> C(A.Comm());
Herk( LOWER, ADJOINT, Base<F>(1), A, C );
Herk( LOWER, ADJOINT, Base<F>(1), Gamma, Base<F>(1), C );
MakeHermitian( LOWER, C );

X.SetComm( Y.Comm() );
Zeros( X, n, Y.Width() );
Multiply( ADJOINT, F(1), A, Y, F(0), X );
HermitianSolve( C, X, ctrl );
if( m >= n )
{
Herk( LOWER, ADJOINT, Base<F>(1), A, C );
Herk( LOWER, ADJOINT, Base<F>(1), Gamma, Base<F>(1), C );
MakeHermitian( LOWER, C );

Multiply( ADJOINT, F(1), A, Y, F(0), X );
HermitianSolve( C, X, ctrl );
}
else
{
Herk( LOWER, NORMAL, Base<F>(1), A, C );
Herk( LOWER, NORMAL, Base<F>(1), Gamma, Base<F>(1), C );
MakeHermitian( LOWER, C );

DistMultiVec<F> YCopy(Y.Comm());
YCopy = Y;
HermitianSolve( C, YCopy, ctrl );
Multiply( ADJOINT, F(1), A, YCopy, F(0), X );
}
}

#define PROTO(F) \
Expand Down

0 comments on commit 5dc20f1

Please sign in to comment.