prog_synthesis/escher.hs
2025-10-22 12:42:46 +03:00

557 lines
25 KiB
Haskell

import Control.Applicative
import Control.Monad (foldM, foldM_, guard, liftM, when)
import Control.Monad.State as State
import Data.List (inits, tails, transpose)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe, isJust)
import Data.Set (Set, delete, insert)
import qualified Data.Set as Set
-- | Runtime values manipulated by synthesized programs.
--
-- Constructor order is significant: the derived 'Ord' compares constructors
-- in declaration order before comparing fields.
data Value
  = BoolV Bool
  | IntV Int
  | ListV [Value]
  | TreeV Tree
  deriving (Read, Show, Eq, Ord)

-- | Binary tree; inner nodes and leaves both carry a 'Value'.
data Tree
  = TNode
      { treeLeft :: Tree
      , treeRoot :: Value
      , treeRight :: Tree
      }
  | TLeaf Value
  deriving (Read, Show, Eq, Ord)

-- | Type tags, one per 'Value' constructor.
data Type
  = BoolT
  | IntT
  | ListT
  | TreeT
  deriving (Read, Show, Eq, Ord)
-- | Expression language of the synthesizer.
--
-- 'Hole' marks an argument position of a pattern that is still to be filled.
-- Constructor order is significant for the derived 'Ord'.
data Expr
  -- Bool
  = Expr :&&: Expr
  | Expr :||: Expr
  | NotE Expr
  | Expr :=: Expr
  | Leq0 Expr
  | IsEmptyE Expr
  -- Int
  | Expr :+: Expr
  | Expr :-: Expr
  | IncE Expr
  | DecE Expr
  | ZeroE
  | Div2E Expr
  -- List
  | TailE Expr
  | HeadE Expr
  | Expr :++: Expr -- ^ concatenation
  | Expr ::: Expr  -- ^ cons
  | EmptyListE
  -- Tree
  | IsLeafE Expr
  | TreeValE Expr
  | TreeLeftE Expr
  | TreeRightE Expr
  | CreateNodeE
      { nodeLeft :: Expr
      , nodeRoot :: Expr
      , nodeRight :: Expr
      }
  | CreateLeafE Expr
  -- Control
  | IfE
      { ifCond :: Expr
      , ifDoThen :: Expr
      , ifDoElse :: Expr
      }
  | SelfE Expr
  | InputE Expr
  | Hole
  deriving (Read, Show, Eq, Ord)
-- | Evaluation context passed to 'eval'.
data Conf = Conf
  { confInput :: [Value]      -- ^ arguments of the call being evaluated ('InputE' indexes them)
  , confOracle :: Oracle      -- ^ asked by 'SelfE' for inputs outside 'confExamples'
  , confProg :: Expr          -- ^ program re-entered by 'SelfE' on known example inputs
  , confExamples :: [[Value]] -- ^ example inputs currently known
  }
------------
-- | Outcome of evaluating an expression on one input.
data Result a
  = Result a                       -- ^ evaluation succeeded
  | NewExamples [([Value], Value)] -- ^ evaluation needs these (input, output) examples first
  | Error String                   -- ^ evaluation failed
  deriving (Read, Show, Eq)

instance Functor Result where
  fmap = liftM

-- | Kept consistent with the 'Monad' instance (@(<*>) = ap@): the left
-- operand is inspected first, and a 'NewExamples' or 'Error' there
-- short-circuits. Previously two 'NewExamples' were concatenated and
-- @NewExamples _ <*> Error _@ gave 'Error', which differed from '>>='.
instance Applicative Result where
  pure = Result
  Result f <*> r = fmap f r
  NewExamples es <*> _ = NewExamples es
  Error err <*> _ = Error err

instance Monad Result where
  Result x >>= f = f x
  NewExamples es >>= _ = NewExamples es
  Error err >>= _ = Error err
  return = pure

-- | Left-biased choice: only an 'Error' on the left falls through to the right.
instance Alternative Result where
  empty = Error "empty"
  Error _ <|> y = y
  r <|> _ = r

-- | Failed pattern matches in @do@ blocks become a generic 'Error'.
instance MonadFail Result where
  fail _ = Error "failure"
-- TODO: Alternative laws (e.g. with NewExamples on the left) are still unchecked
------------
-- | Type tag of a value.
typeOf :: Value -> Type
typeOf (BoolV {}) = BoolT
typeOf (IntV {}) = IntT
typeOf (ListV {}) = ListT
typeOf (TreeV {}) = TreeT

-- | Predicates testing the type tag of a value.
isBool, isInt, isList, isTree :: Value -> Bool
isBool = (== BoolT) . typeOf
isInt = (== IntT) . typeOf
isList = (== ListT) . typeOf
isTree = (== TreeT) . typeOf
-- | Number of levels of a tree; a lone leaf has height 1.
treeHeight :: Tree -> Int
treeHeight tree = case tree of
  TLeaf _ -> 1
  TNode l _ r -> 1 + max (treeHeight l) (treeHeight r)
-- | Strict "smaller than" ordering used to accept recursive 'SelfE' calls.
-- Values of different constructors are never related.
--
-- TODO: check
structuralLess :: Value -> Value -> Bool
structuralLess a b = case (a, b) of
  -- False < True
  (BoolV x, BoolV y) -> not x && y
  -- NOTE(review): `x > 0` rejects recursing down to 0; `x >= 0` may have been
  -- intended (original marked this with "??") — confirm.
  (IntV x, IntV y) -> x > 0 && x < y
  -- TODO: require same elems ?
  (ListV xs, ListV ys) -> length xs < length ys
  -- TODO: require subtree ?
  (TreeV s, TreeV t) -> treeHeight s < treeHeight t
  _ -> False
-- | Evaluate an expression in a configuration (current input, oracle,
-- whole program for 'SelfE', known example inputs).
--
-- Operand type mismatches (e.g. 'HeadE' of a non-list or of an empty list)
-- fail the pattern in the @do@ binding; the 'MonadFail' instance turns that
-- into @Error "failure"@.
eval :: Conf -> Expr -> Result Value
-- Bool
eval conf (left :&&: right) = do
  BoolV leftB <- eval conf left
  BoolV rightB <- eval conf right
  return $ BoolV $ leftB && rightB
eval conf (left :||: right) = do
  BoolV leftB <- eval conf left
  BoolV rightB <- eval conf right
  return $ BoolV $ leftB || rightB
eval conf (NotE e) = do
  BoolV b <- eval conf e
  return $ BoolV $ not b
-- structural equality of arbitrary values
eval conf (left :=: right) = do
  leftV <- eval conf left
  rightV <- eval conf right
  return $ BoolV $ leftV == rightV
eval conf (Leq0 e) = do
  IntV i <- eval conf e
  return $ BoolV $ i <= 0
eval conf (IsEmptyE e) = do
  v <- eval conf e
  case v of
    ListV [] -> return $ BoolV True
    ListV _ -> return $ BoolV False
    -- NOTE(review): message wording looks garbled and lacks a space before
    -- the value; it is compared via Result's Eq, so change it deliberately.
    _ -> Error $ "Can't take empty not from list" ++ show v
-- Int
eval conf (left :+: right) = do
  IntV leftI <- eval conf left
  IntV rightI <- eval conf right
  return $ IntV $ leftI + rightI
eval conf (left :-: right) = do
  IntV leftI <- eval conf left
  IntV rightI <- eval conf right
  return $ IntV $ leftI - rightI
eval conf (IncE e) = do
  IntV i <- eval conf e
  return $ IntV $ i + 1
eval conf (DecE e) = do
  IntV i <- eval conf e
  return $ IntV $ i - 1
eval conf ZeroE = return $ IntV 0
eval conf (Div2E e) = do
  IntV i <- eval conf e
  return $ IntV $ i `div` 2
-- List (empty list in TailE/HeadE fails the pattern)
eval conf (TailE e) = do
  ListV (_ : t) <- eval conf e
  return $ ListV t
eval conf (HeadE e) = do
  ListV (h : _) <- eval conf e
  return h
eval conf (left :++: right) = do
  ListV leftL <- eval conf left
  ListV rightL <- eval conf right
  return $ ListV $ leftL ++ rightL
eval conf (left ::: right) = do
  leftV <- eval conf left
  ListV rightL <- eval conf right
  return $ ListV $ leftV : rightL
eval conf EmptyListE = return $ ListV []
-- Tree
eval conf (IsLeafE e) = do
  TreeV t <- eval conf e
  return $ BoolV $ case t of
    TNode {} -> False
    TLeaf {} -> True
eval conf (TreeValE e) = do
  TreeV t <- eval conf e
  return $ case t of
    n@TNode {} -> treeRoot n
    TLeaf e -> e
-- left/right subtrees exist only for inner nodes; a leaf fails the pattern
eval conf (TreeLeftE e) = do
  TreeV n@(TNode {}) <- eval conf e
  return $ TreeV $ treeLeft n
eval conf (TreeRightE e) = do
  TreeV n@(TNode {}) <- eval conf e
  return $ TreeV $ treeRight n
eval conf (CreateNodeE {nodeLeft, nodeRoot, nodeRight}) = do
  TreeV treeLeft <- eval conf nodeLeft
  treeRoot <- eval conf nodeRoot
  TreeV treeRight <- eval conf nodeRight
  return $ TreeV $ TNode { treeLeft, treeRoot, treeRight }
eval conf (CreateLeafE e) = do
  v <- eval conf e
  return $ TreeV $ TLeaf v
-- Control
eval conf (IfE {ifCond, ifDoThen, ifDoElse}) = do
  BoolV condB <- eval conf ifCond
  if condB then eval conf ifDoThen else eval conf ifDoElse
-- Recursive call. The argument must evaluate to the full argument list, with
-- the same arity as the current call and every argument structurally smaller.
-- An input outside the known examples is answered by the oracle as a
-- NewExamples request; a known input re-evaluates confProg on it.
eval conf (SelfE e) = do
  ListV newInput <- eval conf e
  -- NOTE: replaced guards for better errors description
  -- guard $ length newInput == length (confInput conf)
  -- guard $ and $ zipWith structuralLess newInput (confInput conf)
  if length newInput /= length (confInput conf)
    then Error $ "self call different length, new=" ++ show newInput ++ " old=" ++ show (confInput conf)
    else do
      if not $ and $ zipWith structuralLess newInput (confInput conf)
        then Error $ "self call on >= exprs, new=" ++ show newInput ++ " old=" ++ show (confInput conf)
        else do
          if newInput `notElem` confExamples conf then
              (case confOracle conf newInput of
                 Just expectedV -> NewExamples [(newInput, expectedV)]
                 Nothing -> Error $ "no oracle output on " ++ show newInput) -- TODO: ???
            else eval conf{ confInput = newInput } (confProg conf)
-- i-th argument of the current call, bounds-checked
eval conf (InputE e) = do
  IntV i <- eval conf e
  if i < 0 || i >= length (confInput conf) -- NOTE: replaced guard for better errors description
    then Error $ "can't access input " ++ show (confInput conf) ++ " by id " ++ show i
    else return $ confInput conf !! i -- use !? instead (?)
eval _ Hole = Error "can't eval hole"
------------
-- | Reference behaviour of the target program: its output on an argument
-- list, or 'Nothing' where it is undefined.
type Oracle = [Value] -> Maybe Value

-- | Expected outputs, one per example input; 'Nothing' marks an example
-- whose output is unimportant for this goal.
--
-- Goals and 'Resolver's form a bipartite graph whose root is 'syntRoot'.
newtype Goal = Goal [Maybe Value]
  deriving (Read, Show, Eq, Ord)

-- | Decomposition of 'resolverGoal' into
-- @if resolverCond then resolverThen else resolverElse@.
-- NOTE: goals are stored by value rather than by id (original: "ids ??").
data Resolver = Resolver
  { resolverGoal :: Goal
  , resolverCond :: Goal
  , resolverThen :: Goal
  , resolverElse :: Goal
  }

-- | Synthesizer state.
data Synt = Synt
  { syntExprs :: [(Expr, [Maybe Value])] -- ^ enumerated expressions, newest first (second component currently always [])
  , syntSolvedGoals :: Map Goal Expr     -- ^ solved goal -> solving expression
  , syntUnsolvedGoals :: Set Goal
  , syntResolvers :: [Resolver]
  , syntExamples :: [[Value]]            -- ^ example inputs, in the order of the goal entries
  , syntOracle :: Oracle
  , syntRoot :: Goal                     -- ^ expected outputs of the whole program
  }

-- | State monad the synthesis steps run in.
type SyntState a = State Synt a
------------
-- | Fill the top-level holes of a one-level pattern (see patterns0 ..
-- patterns3) with the given sub-expressions, left to right.
--
-- Calls 'error', naming the pattern and the argument count, when the pattern
-- is not one of the known shapes or the argument count does not match its
-- arity (previously an uninformative 'undefined').
fillHoles :: Expr -> [Expr] -> Expr
fillHoles (Hole :&&: Hole) [left, right] = left :&&: right
fillHoles (Hole :||: Hole) [left, right] = left :||: right
fillHoles (NotE Hole) [e] = NotE e
fillHoles (Hole :=: Hole) [left, right] = left :=: right
fillHoles (Leq0 Hole) [e] = Leq0 e
fillHoles (IsEmptyE Hole) [e] = IsEmptyE e
fillHoles (Hole :+: Hole) [left, right] = left :+: right
fillHoles (Hole :-: Hole) [left, right] = left :-: right
fillHoles (IncE Hole) [e] = IncE e
fillHoles (DecE Hole) [e] = DecE e
fillHoles ZeroE [] = ZeroE
fillHoles (Div2E Hole) [e] = Div2E e
fillHoles (TailE Hole) [e] = TailE e
fillHoles (HeadE Hole) [e] = HeadE e
fillHoles (Hole :++: Hole) [left, right] = left :++: right
fillHoles (Hole ::: Hole) [left, right] = left ::: right
fillHoles EmptyListE [] = EmptyListE
fillHoles (IsLeafE Hole) [e] = IsLeafE e
fillHoles (TreeValE Hole) [e] = TreeValE e
fillHoles (TreeLeftE Hole) [e] = TreeLeftE e
fillHoles (TreeRightE Hole) [e] = TreeRightE e
fillHoles (CreateNodeE {nodeLeft = Hole, nodeRoot = Hole, nodeRight = Hole})
          [nodeLeft, nodeRoot, nodeRight] = CreateNodeE {nodeLeft, nodeRoot, nodeRight}
fillHoles (CreateLeafE Hole) [e] = CreateLeafE e
fillHoles (IfE {ifCond = Hole, ifDoThen = Hole, ifDoElse = Hole})
          [ifCond, ifDoThen, ifDoElse] = IfE {ifCond, ifDoThen, ifDoElse}
fillHoles (SelfE Hole) [e] = SelfE e
fillHoles (InputE Hole) [e] = InputE e
fillHoles pat args =
  error $ "fillHoles: can't fill pattern " ++ show pat
    ++ " with " ++ show (length args) ++ " argument(s)"
-- | Context for running @expr@ as the whole program on one example @input@,
-- using the synthesizer's oracle and current example set.
confBySynt :: [Value] -> Expr -> Synt -> Conf
confBySynt input expr st =
  -- positional: confInput, confOracle, confProg, confExamples
  Conf input (syntOracle st) expr (syntExamples st)
-- | Whether @expr@, run as the whole program on each example, yields the
-- goal's expected output wherever one is given ('Nothing' matches anything).
-- Any non-'Result' outcome fails a concrete expectation.
--
-- Uses 'all' instead of the previous hand-rolled short-circuiting 'foldl'.
matchGoal :: Goal -> Synt -> Expr -> Bool
matchGoal (Goal goal) st expr = all matchesOn $ zip (syntExamples st) goal
  where
    matchesOn (input, expected) =
      matchValue (eval (confBySynt input expr st) expr) expected -- TODO
    matchValue (Result x) (Just y) = x == y
    matchValue _ Nothing = True
    matchValue _ _ = False
------ syntesis steps
-- | Outcome of running @expr@ as the whole program on every example input,
-- in example order. Reads the state only.
calcExprOutputs :: Expr -> SyntState [Result Value]
calcExprOutputs expr = do
  st <- get
  return [eval (confBySynt input expr st) expr | input <- syntExamples st]
-- | Whether some already-enumerated expression produces exactly @outputs@ on
-- the current examples (observational equivalence, used to drop copies).
-- 'False' when nothing has been enumerated yet.
--
-- Fix: the previous fold started at 'True' and kept the flag only while
-- every expression matched. It returned 'True' for an empty list, so the
-- very first candidate, and therefore every candidate, was rejected as a copy.
matchAnyOutputs :: [Result Value] -> SyntState Bool
matchAnyOutputs outputs = do
  exprs <- gets syntExprs
  foldM step False $ map fst exprs
  where
    step :: Bool -> Expr -> SyntState Bool
    step True _ = return True -- a match was already found
    step False expr = do
      exprOutputs <- calcExprOutputs expr
      return $ outputs == exprOutputs
-- | Instantiate pattern @comp@ with @args@. If an already-enumerated
-- expression has the same outputs on all examples, return 'Nothing'.
-- Otherwise prepend the new expression to 'syntExprs' and return it.
forwardStep :: Expr -> [Expr] -> SyntState (Maybe Expr)
forwardStep comp args = do
  let candidate = fillHoles comp args
  candidateOutputs <- calcExprOutputs candidate
  isCopy <- matchAnyOutputs candidateOutputs
  if isCopy
    then return Nothing
    else do
      modify $ \st -> st { syntExprs = (candidate, []) : syntExprs st }
      return (Just candidate)
-- | Split a goal with a boolean selector (one entry per example):
--
-- * the cond goal expects @BoolV b@ for each selector entry @b@;
-- * the then goal keeps the goal's outputs where the selector is 'True';
-- * the else goal keeps them where it is 'False'.
--
-- Positions that are not kept become 'Nothing' (unimportant).
-- Calls 'error' with both lengths when selector and goal differ in length
-- (previously a bare non-exhaustive guard failure).
splitGoal :: Goal -> [Bool] -> Resolver
splitGoal resolverGoal@(Goal outputs) selector
  | length outputs == length selector =
      Resolver { resolverGoal, resolverCond, resolverThen, resolverElse }
  | otherwise =
      error $ "splitGoal: goal has " ++ show (length outputs)
        ++ " outputs but selector has " ++ show (length selector) ++ " entries"
  where
    resolverCond = Goal $ map (Just . BoolV) selector
    resolverThen = Goal $ zipWith (\v b -> if b then v else Nothing) outputs selector
    resolverElse = Goal $ zipWith (\v b -> if b then Nothing else v) outputs selector
-- | Split @goal@ by @selector@ (see 'splitGoal'), register the resulting
-- resolver, and add its cond/then/else goals to the unsolved set. Goals
-- already present are deduplicated by the 'Set'.
splitGoalStep :: Goal -> [Bool] -> SyntState Resolver
splitGoalStep goal selector = do
  let resolver = splitGoal goal selector
      subGoals = [resolverCond resolver, resolverThen resolver, resolverElse resolver]
  modify $ \st ->
    st { syntUnsolvedGoals = foldr Set.insert (syntUnsolvedGoals st) subGoals
       , syntResolvers = resolver : syntResolvers st
       }
  return resolver
-- | If @expr@ satisfies @goal@ (see 'matchGoal'), record it as the goal's
-- solution and remove the goal from the unsolved set. Returns whether it did.
--
-- TODO: use expr evaluated outputs ?
trySolveGoal :: Expr -> Goal -> SyntState Bool
trySolveGoal expr goal = do
  solves <- gets $ \st -> matchGoal goal st expr
  when solves $ modify $ \st ->
    st { syntSolvedGoals = Map.insert goal expr (syntSolvedGoals st)
       , syntUnsolvedGoals = Set.delete goal (syntUnsolvedGoals st)
       }
  return solves
-- | Whether a solution has been recorded for @goal@.
isGoalSolved :: Goal -> SyntState Bool
isGoalSolved goal = isJust <$> goalSolution goal

-- | The expression recorded as solving @goal@, if any.
goalSolution :: Goal -> SyntState (Maybe Expr)
goalSolution goal = gets $ \st -> Map.lookup goal (syntSolvedGoals st)
-- | Record @IfE cond then else@ as the solution of the resolver's goal: mark
-- the goal solved, remove it from the unsolved set, and prepend the new
-- conditional to 'syntExprs'.
--
-- NOTE: the caller is expected to have solved the cond/then/else goals and
-- to pass their solutions as the triple.
resolveStep :: (Expr, Expr, Expr) -> Resolver -> SyntState ()
resolveStep (ifCond, ifDoThen, ifDoElse) r = modify record
  where
    goal = resolverGoal r
    conditional = IfE { ifCond, ifDoThen, ifDoElse }
    record st =
      st { syntSolvedGoals = Map.insert goal conditional (syntSolvedGoals st)
         , syntUnsolvedGoals = Set.delete goal (syntUnsolvedGoals st)
         , syntExprs = (conditional, []) : syntExprs st
         }
-- | If the cond, then and else goals of @r@ are all solved, solve its goal
-- with the corresponding 'IfE' (see 'resolveStep'). Returns whether it did.
tryResolve :: Resolver -> SyntState Bool
tryResolve r = do
  parts <- mapM goalSolution [resolverCond r, resolverThen r, resolverElse r]
  case parts of
    [Just condExpr, Just thenExpr, Just elseExpr] -> do
      resolveStep (condExpr, thenExpr, elseExpr) r
      return True
    _ -> return False
-- | Restart goal search with extra examples. The state is rebuilt from the
-- new (input, output) pairs followed by the old examples with their root
-- outputs; only the enumerated expressions ('syntExprs') are carried over.
--
-- The root goal always holds concrete outputs (it is built from @Just@
-- values by createSynt). A missing one is reported with an explicit message
-- instead of 'undefined'.
remakeSynt :: [[Value]] -> [Value] -> SyntState ()
remakeSynt newInputs newOutputs = do
  st <- get
  let Goal oldOutputs = syntRoot st
      rootOutput = fromMaybe (error "remakeSynt: root goal has no output for some example")
      goals = zip (newInputs ++ syntExamples st)
                  (newOutputs ++ map rootOutput oldOutputs)
  initSynt (syntOracle st) goals
  modify (\st' -> st' { syntExprs = syntExprs st })
-- | Saturation: run @expr@ as the whole program on every example. The first
-- example whose evaluation requests new examples (a recursive call on an
-- input outside the example set) triggers 'remakeSynt' with them, which
-- clears the goal tree down to the root.
--
-- Returns 'True' iff new examples were found and the state was rebuilt.
-- Fix: the flag used to be @null newInputs@, i.e. inverted. It rebuilt the
-- state when nothing was found and reported success, which made
-- stepOnAddedExpr replay all expressions forever.
saturateStep :: Expr -> SyntState Bool
saturateStep expr = do
  st <- get
  let requested =
        [ exs
        | input <- syntExamples st
        , NewExamples exs@(_ : _) <- [eval (confBySynt input expr st) expr]
        ]
      -- only the first non-empty request is used, as before
      (newInputs, newOutputs) = unzip $ concat $ take 1 requested
      isExFound = not (null newInputs)
  when isExFound $ remakeSynt newInputs newOutputs
  return isExFound
-- | Termination check: @Just expr@ when @expr@, run as the whole program,
-- satisfies the root goal on every example; 'Nothing' otherwise.
terminateStep :: Expr -> SyntState (Maybe Expr)
terminateStep expr = gets check
  where
    check st
      | matchGoal (syntRoot st) st expr = Just expr
      | otherwise = Nothing
------ patterns
-- | Hole-free leaf expressions; they form the size-1 layer.
patterns0 :: [Expr]
patterns0 = [ZeroE, EmptyListE]

-- | One-hole patterns, in enumeration order.
patterns1 :: [Expr]
patterns1 =
  map ($ Hole)
    [ NotE, Leq0
    , IsEmptyE, IncE
    , DecE, Div2E
    , TailE, HeadE
    , IsLeafE, TreeValE
    , TreeLeftE, TreeRightE
    , CreateLeafE, SelfE
    , InputE
    ]

-- | Two-hole (binary operator) patterns, in enumeration order.
patterns2 :: [Expr]
patterns2 =
  [ Hole `op` Hole
  | op <- [(:&&:), (:||:), (:=:), (:+:), (:-:), (:++:), (:::)]
  ]

-- | Three-hole patterns (positional fields: left/root/right, cond/then/else).
patterns3 :: [Expr]
patterns3 = [CreateNodeE Hole Hole Hole, IfE Hole Hole Hole]
------ generation
-- | Interleave lists round-robin: the first elements of all lists, then all
-- second elements, and so on; exhausted (or empty) lists are skipped.
--
-- >>> concatShuffle [[1, 2, 3], [4], [5, 6]]
-- [1,4,5,2,6,3]
--
-- 'transpose' already skips exhausted rows, so this replaces the previous
-- hand-rolled recursion over partial 'head'/'tail'.
concatShuffle :: [[a]] -> [a]
concatShuffle = concat . transpose
-- | Operands for one-hole patterns. Layers are stored newest first and the
-- newest has size n, so filling one hole gives size n + 1.
-- Total: an empty layer list gives no operands instead of a 'head' crash.
genNext1 :: [[Expr]] -> [Expr]
genNext1 = concat . take 1
-- | Operand pairs for two-hole patterns. @layers@ is newest first,
-- @[l_n, l_{n-1}, ..., l_1]@ (layer @l_i@ holds expressions of size @i@).
-- Sizes 1..n-1 are paired with n-1..1, so each pair sums to n and the
-- filled pattern has size n + 1. Pairs of different size splits are
-- interleaved via 'concatShuffle'.
--
-- Uses 'drop 1' (total) instead of 'tail' and drops an unused local.
genNext2 :: [[Expr]] -> [(Expr, Expr)]
genNext2 layers =
  let smaller = drop 1 layers -- [l_{n-1} .. l_1]
  in concatShuffle $
       zipWith (\xs ys -> [(x, y) | x <- xs, y <- ys]) smaller (reverse smaller)
-- | Operand triples for three-hole patterns. @layers@ is newest first,
-- @[l_n, ..., l_1]@ (layer @l_i@ holds expressions of size @i@). Every
-- triple's sizes sum to n, so the filled pattern has size n + 1.
--
-- The suffix @[l_m .. l_1]@ of the smaller layers gives, via 'genNext2',
-- the pairs of total size m; they are combined with a third operand from
-- @l_{n-m}@.
--
-- Fix: the previous version fed 'inits' of the newest-first list to
-- 'genNext2', which expects lists ending at layer 1. Triple sizes were then
-- inconsistent, and no triple whose last element comes from l_1 was ever
-- produced (e.g. a size-4 'IfE' of three leaves).
genNext3 :: [[Expr]] -> [(Expr, Expr, Expr)]
genNext3 layers =
  let smaller = drop 1 layers              -- [l_{n-1} .. l_1]
      pairsBySize = map genNext2 (tails smaller) -- pair sizes n-1, n-2, ..., 1
      lastOperands = reverse smaller       -- [l_1 .. l_{n-1}]
  in concatShuffle $
       zipWith (\xys zs -> [(x, y, z) | (x, y) <- xys, z <- zs]) pairsBySize lastOperands
-- | Candidate (pattern, operands) pairs for the next forward step, given the
-- layers enumerated so far (newest first). With no layers, only the
-- hole-free leaves are produced. Unary, binary and ternary candidates are
-- interleaved.
genStep :: [[Expr]] -> [(Expr, [Expr])]
genStep [] = [(p, []) | p <- patterns0]
genStep layers = concatShuffle [unary, binary, ternary]
  where
    unary = [(p, [x]) | p <- patterns1, x <- genNext1 layers]
    binary = [(p, [x, y]) | p <- patterns2, (x, y) <- genNext2 layers]
    ternary = [(p, [x, y, z]) | p <- patterns3, (x, y, z) <- genNext3 layers]
------ algorithm
-- | Fresh synthesizer state for the given (input, output) examples. The root
-- goal expects exactly those outputs and is the only unsolved goal.
createSynt :: Oracle -> [([Value], Value)] -> Synt
createSynt oracle goals =
  Synt
    { syntExprs = []
    , syntSolvedGoals = Map.empty
    , syntUnsolvedGoals = Set.singleton root
    , syntResolvers = []
    , syntExamples = inputs
    , syntOracle = oracle
    , syntRoot = root
    }
  where
    (inputs, outputs) = unzip goals
    root = Goal (map Just outputs)

-- | Replace the whole state with 'createSynt' for the given examples.
initSynt :: Oracle -> [([Value], Value)] -> SyntState ()
initSynt oracle = put . createSynt oracle
-- | Process one freshly enumerated expression @expr@:
--
-- 1. saturate: if running @expr@ requested new examples, the state was
--    rebuilt, so every enumerated expression (this one included) is replayed;
-- 2. terminate: return @Just expr@ if it satisfies the root goal;
-- 3. solve every unsolved goal @expr@ satisfies, then retry every resolver;
-- 4. split every unsolved goal that @expr@ matches on at least one example
--    with a concrete expected output, solving the "then" part with @expr@.
--
-- Returns 'Just' a program once one satisfies the root goal.
stepOnAddedExpr :: Expr -> SyntState (Maybe Expr)
stepOnAddedExpr expr = do
  exFound <- saturateStep expr
  if exFound
    then do
      -- redo prev exprs (including current) against the rebuilt goals
      exprs <- gets (map fst . syntExprs)
      stepOnAddedExprs exprs
    else do
      maybeResult <- terminateStep expr
      if isJust maybeResult
        then return maybeResult
        else do
          exprOutputs <- calcExprOutputs expr
          -- TODO: drop expr from syntExprs when an older expr has the same
          -- outputs on all examples.
          -- Solve existing goals. The actions must be run here: fetching
          -- them with `gets` (as before) discarded them unexecuted.
          unsolved <- gets (Set.toList . syntUnsolvedGoals)
          mapM_ (trySolveGoal expr) unsolved
          -- Resolve existing goals whose cond/then/else parts are now solved.
          resolvers <- gets syntResolvers
          mapM_ tryResolve resolvers
          modify $ \st ->
            foldl (splitGoalsFold expr exprOutputs) st (Set.toList (syntUnsolvedGoals st))
          return Nothing
  where
    splitGoalsFold e outputs st goal@(Goal expected) =
      let matches = zipWith matchResult outputs expected
          -- Unimportant ('Nothing') positions always "match". Splitting on
          -- them alone gives an empty then-goal and an else-goal equal to the
          -- goal itself, so require a match on a concrete expected output.
          concreteMatch = or [m | (m, Just _) <- zip matches expected]
      in if not concreteMatch
           then st
           -- TODO: always solve goal
           else execState (splitGoalStep goal matches >>= trySolveGoal e . resolverThen) st
    matchResult (Result x) (Just y) = x == y
    matchResult (NewExamples {}) _ = False
    matchResult _ Nothing = True
    -- 'Error' against a concrete expected output: this case was missing and
    -- crashed with a non-exhaustive pattern failure.
    matchResult _ _ = False
-- | Run 'stepOnAddedExpr' on each expression in order, stopping at (and
-- returning) the first program found; later expressions are not processed.
stepOnAddedExprs :: [Expr] -> SyntState (Maybe Expr)
stepOnAddedExprs [] = return Nothing
stepOnAddedExprs (e : rest) = do
  found <- stepOnAddedExpr e
  case found of
    Just _ -> return found
    Nothing -> stepOnAddedExprs rest
-- | Instantiate pattern @comp@ with @args@ and, unless the result duplicates
-- an enumerated expression (see 'forwardStep'), process it with
-- 'stepOnAddedExpr'.
--
-- TODO: throw away exprs with Errors (?)
stepOnNewExpr :: Expr -> [Expr] -> SyntState (Maybe Expr)
stepOnNewExpr comp args =
  forwardStep comp args >>= maybe (return Nothing) stepOnAddedExpr
-- stages:
-- init state
-- 1. gen new step exprs
-- 2. process exprs by one
-- 3. try terminate / saturate
-- 4. try to solve existing goals
-- 5. make resolutions if goals solved
-- 6. split goals, where expr partially matched

-- | Run up to @steps@ enumeration rounds. Each round builds one layer of
-- expressions one size larger than the newest of @prevExprs@ (newest first)
-- and processes its candidates until a program is found.
-- The oracle should be defined on the provided examples.
--
-- A non-positive step count stops immediately; previously a negative count
-- never reached the @0@ base case and recursed forever.
syntesisStep :: Int -> [[Expr]] -> SyntState (Maybe Expr)
syntesisStep steps prevExprs
  | steps <= 0 = return Nothing
  | otherwise = do
      let currentExprs = genStep prevExprs
      result <- foldM step Nothing currentExprs
      if isJust result
        then return result
        else syntesisStep (steps - 1) (map (uncurry fillHoles) currentExprs : prevExprs)
  where
    step res@(Just {}) _ = return res
    step Nothing expr = uncurry stepOnNewExpr expr
-- | 'syntesis' starting from already-enumerated layers @exprs@ (newest first).
-- The oracle must be defined on every example input. Violations are now
-- reported with the offending input instead of 'undefined'.
syntesis' :: [[Expr]] -> Int -> Oracle -> [[Value]] -> Maybe Expr
syntesis' exprs steps oracle inputs =
  evalState (syntesisStep steps exprs) (createSynt oracle $ zip inputs outputs)
  where
    outputs = map expectedOutput inputs
    expectedOutput input =
      fromMaybe (error $ "syntesis: oracle has no output for example " ++ show input)
                (oracle input)

-- | Search for a program agreeing with @oracle@ on the example @inputs@,
-- running at most @steps@ enumeration rounds.
syntesis :: Int -> Oracle -> [[Value]] -> Maybe Expr
syntesis = syntesis' []
------ examples
-- | Oracle for list reversal: defined only on a single list argument.
reverseOracle :: Oracle
reverseOracle args = case args of
  [ListV xs] -> Just (ListV (reverse xs))
  _ -> Nothing

-- | Example inputs for 'reverseOracle'.
reverseExamples :: [[Value]]
reverseExamples = [[ListV (map IntV [1, 2, 3])]]
---
-- | Oracle for "stutter" (duplicate every list element), defined only on a
-- single list argument.
stutterOracle :: Oracle
stutterOracle args = case args of
  [ListV []] -> Just (ListV [])
  [ListV (x : xs)] ->
    stutterOracle [ListV xs] >>= \rest -> case rest of
      ListV xs' -> Just (ListV (x : x : xs'))
      _ -> Nothing
  _ -> Nothing

-- | Example inputs for 'stutterOracle': successive suffixes of @[1, 2, 3]@.
stutterExamples :: [[Value]]
stutterExamples = [[ListV (map IntV xs)] | xs <- [[1, 2, 3], [2, 3], [3], []]]

-- | Hand-written stutter program:
-- @if null x then [] else head x : head x : self [tail x]@ with @x@ the first input.
stutterExpr :: Expr
stutterExpr =
  IfE { ifCond = IsEmptyE arg
      , ifDoThen = EmptyListE
      , ifDoElse = HeadE arg ::: (HeadE arg ::: SelfE (TailE arg ::: EmptyListE))
      }
  where
    arg = InputE ZeroE

-- | Context for evaluating 'stutterExpr' on the first stutter example.
stutterConf :: Conf
stutterConf = Conf (head stutterExamples) stutterOracle stutterExpr stutterExamples
-- TODO: examples