Skip to content

Commit

Permalink
fixed fromRowsV
Browse files Browse the repository at this point in the history
  • Loading branch information
ocramz committed Apr 27, 2018
1 parent d4b8ee3 commit 67cef26
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 45 deletions.
91 changes: 53 additions & 38 deletions src/Control/Iterative.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ module Control.Iterative where

import Control.Applicative

import Control.Monad (when)
import Control.Monad (when, replicateM)
import Control.Monad.Reader (MonadReader(..), asks)
import Control.Monad.State.Strict (MonadState(..), get, put)
import Control.Monad.Trans.Class (MonadTrans(..), lift)
Expand Down Expand Up @@ -65,6 +65,16 @@ runIterativeT :: Handler m message -> IterativeT r message s m a -> r -> s -> m
runIterativeT lh m c x0 =
runLoggingT (runStateT (runReaderT (unIterativeT m) c) x0) lh

runNIterativeT :: Monad m =>
Handler m message
-> IterativeT r message s m a
-> r
-> Int -- ^ # of iterations
-> s
-> m ([a], s)
runNIterativeT lh m c n x0 =
runLoggingT (runStateT (replicateM n (runReaderT (unIterativeT m) c)) x0) lh

mkIterativeT :: Monad m =>
(a -> s -> message)
-> (s -> r -> (a, s))
Expand All @@ -78,54 +88,59 @@ mkIterativeT flog fs = IterativeT $ do
return a


-- | Configuration data and diagnostic functions for the iterative process
-- | Configuration data for the iterative process
data IterConfig s t = IterConfig {
icFunctionName :: String -- ^ Name of calling function, for logging purposes
, icLogLevelConvergence :: Maybe Severity
, icLogLevelDivergence :: Maybe Severity
, icNumIterationsMax :: Int -- ^ Max # of iterations
, icStateWindowLength :: Int -- ^ # of states used to assess convergence/divergence
, icStateSummary :: [s] -> t -- ^ Produce a summary from a list of states
} deriving (Eq, Show)

-- | Diagnostic functions for the iterative process
data IterDiagnostics s t = IterDiagnostics {
icStateSummary :: [s] -> t -- ^ Produce a summary from a list of states
, icStateConverging :: t -> Bool
, icStateDiverging :: t -> t -> Bool
, icStateFinal :: s -> Bool
}

-- modifyInspectGuardedM_IterT :: MonadThrow m => IterativeT ()
modifyInspectGuardedM_IterT itc@(IterConfig fname llconv lldiv nitermax lwindow sf qconverg qdiverg qfinal) lh f x0 =
when (nitermax <= 0) $ throwM (NonNegError fname nitermax)
runIterativeT lh (go 0 []) itc x0
where
checkConvergStatus y i ll
| length ll < lwindow = BufferNotReady
| qdiverg qi qt && not (qconverg qi) = Diverging qi qt
| qconverg qi || qfinal (pf y) = Converged qi
| i == nitermax - 1 = NotConverged
| otherwise = Converging
where llf = pf <$> ll
qi = sf $ init llf -- summary of latest 2 states
qt = sf $ tail llf -- " " previous 2 states
go i ll = do
x <- get
y <- lift $ f x
-- when (printDebugInfo config) $ do
-- logMessage $ unwords ["Iteration", show i]
case checkConvergStatus y i ll of
BufferNotReady -> do
put y
let ll' = y : ll -- cons current state to buffer
go (i + 1) ll'
Converged qi -> put y
Diverging qi qt -> do
put y
throwM (DivergingE fname i qi qt)
Converging -> do
put y
let ll' = init (y : ll) -- rolling state window
go (i + 1) ll'
NotConverged -> do
put y
throwM (NotConvergedE fname nitermax y)
-- -- modifyInspectGuardedM_IterT :: MonadThrow m => IterativeT ()
-- modifyInspectGuardedM_IterT itc@(IterConfig fname llconv lldiv nitermax lwindow sf qconverg qdiverg qfinal) lh f x0 =
-- when (nitermax <= 0) $ throwM (NonNegError fname nitermax)
-- runIterativeT lh (go 0 []) itc x0
-- where
-- checkConvergStatus y i ll
-- | length ll < lwindow = BufferNotReady
-- | qdiverg qi qt && not (qconverg qi) = Diverging qi qt
-- | qconverg qi || qfinal (pf y) = Converged qi
-- | i == nitermax - 1 = NotConverged
-- | otherwise = Converging
-- where llf = pf <$> ll
-- qi = sf $ init llf -- summary of latest 2 states
-- qt = sf $ tail llf -- " " previous 2 states
-- go i ll = do
-- x <- get
-- y <- lift $ f x
-- -- when (printDebugInfo config) $ do
-- -- logMessage $ unwords ["Iteration", show i]
-- case checkConvergStatus y i ll of
-- BufferNotReady -> do
-- put y
-- let ll' = y : ll -- cons current state to buffer
-- go (i + 1) ll'
-- Converged qi -> put y
-- Diverging qi qt -> do
-- put y
-- throwM (DivergingE fname i qi qt)
-- Converging -> do
-- put y
-- let ll' = init (y : ll) -- rolling state window
-- go (i + 1) ll'
-- NotConverged -> do
-- put y
-- throwM (NotConvergedE fname nitermax y)




Expand Down
14 changes: 7 additions & 7 deletions src/Data/Sparse/Common.hs
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ mapKeysSV fk (SV d sv) = SV d $ I.mapKeys fk sv

-- | Insert row , using the provided row index transformation function
insertRowWith :: (IxCol -> IxCol) -> SpMatrix a -> SpVector a -> IM.Key -> SpMatrix a
insertRowWith fj (SM (m,n) im) (SV d sv) i
| not (inBounds0 m i) = error "insertRowSM : index out of bounds"
insertRowWith fj (SM (m, n) im) (SV d sv) i
| not (inBounds0 m i) = error "insertRowWith : index out of bounds"
| n >= d = SM (m,n) $ I.insert i (insertOrUnion i sv' im) im
| otherwise = error $ "insertRowSM : incompatible dimensions " ++ show (n, d)
| otherwise = error $ "insertRowWith : incompatible dimensions " ++ show (n, d)
where sv' = I.mapKeys fj sv
insertOrUnion i' sv' im' = maybe sv' (I.union sv') (I.lookup i' im')

Expand All @@ -87,9 +87,9 @@ insertRow = insertRowWith id
-- | Insert column, using the provided row index transformation function
insertColWith :: (IxRow -> IxRow) -> SpMatrix a -> SpVector a -> IxCol -> SpMatrix a
insertColWith fi smm sv j
| not (inBounds0 n j) = error "insertColSM : index out of bounds"
| not (inBounds0 n j) = error "insertColWith : index out of bounds"
| m >= mv = insIM2 smm vl j
| otherwise = error $ "insertColSM : incompatible dimensions " ++ show (m,mv) where
| otherwise = error $ "insertColWith : incompatible dimensions " ++ show (m,mv) where
(m, n) = dim smm
mv = dim sv
vl = toListSV sv
Expand Down Expand Up @@ -321,8 +321,8 @@ fromColsV qv = V.ifoldl' ins (zeroSM m n) qv where
-- | Pack a V.Vector of SpVectors as rows of an SpMatrix
fromRowsV :: V.Vector (SpVector a) -> SpMatrix a
fromRowsV qv = V.ifoldl' ins (zeroSM m n) qv where
n = V.length qv
m = svDim $ V.head qv
m = V.length qv
n = dim $ V.head qv
ins mm i c = insertRow mm c i


Expand Down
8 changes: 8 additions & 0 deletions test/LibSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1160,3 +1160,11 @@ b10 = aa10 #> x10
-- m2 = toCSR 4 4 $ V.fromList [(0,0,1), (0,2,5), (2,0,4), (2,3,1), (3,2,2)]
-- m3 = toCSR 4 4 $ V.fromList [(1,0,5), (1,1,8), (2,2,3), (3,1,6)]



-- | Test data for issue #42

x42, y42 :: SpVector Double
x42 = fromListSV 4 [(2,3)]
y42 = fromListSV 4 [(0,3)]

0 comments on commit 67cef26

Please sign in to comment.