-- see https://markkarpov.com/tutorial/megaparsec.html
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE OverloadedStrings #-}

module Main.Parser.Megaparsec (parse) where

import Control.Monad.Combinators.Expr
import Data.Functor.Identity qualified
import Data.Text
import Data.Void (Void)
import Main.Types qualified as M
import Text.Megaparsec as MP hiding (parse)
import Text.Megaparsec qualified as MP
import Text.Megaparsec.Char qualified as C
import Text.Megaparsec.Char.Lexer qualified as L

-- | Parser type used throughout this module: no custom error component
-- ('Void') and strict 'Text' input.
type Parser = Parsec Void Text

-- | Run a token parser, then skip any whitespace that follows it, so the
-- next token starts on a non-space character.
lexeme :: Parser a -> Parser a
lexeme tokenP = L.lexeme skipTrailingSpace tokenP
  where
    skipTrailingSpace = C.space

-- | Match the exact text @sym@, then skip any whitespace after it.
-- Returns the matched text.
symbol :: Text -> Parser Text
symbol sym = L.symbol skipTrailingSpace sym
  where
    skipTrailingSpace = C.space

-- | Parse a decimal integer literal with an optional leading sign, then
-- skip trailing whitespace. No whitespace is permitted between the sign and
-- the digits.
int :: Parser Int
int = lexeme signedDecimal
  where
    signedDecimal = L.signed noSpaceAfterSign L.decimal
    noSpaceAfterSign = pure ()

-- | Match the exact text given, WITHOUT skipping any whitespace afterwards.
-- Thin alias for 'C.string' specialised to 'Parser'.
string :: Text -> Parser Text
string expected = C.string expected

-- | @container open close p@ parses the symbol @open@, then @p@, then the
-- symbol @close@, and returns @p@'s result. Both delimiters are parsed with
-- 'symbol', so whitespace after each of them is skipped.
container :: Text -> Text -> Parser a -> Parser a
container open close inner = symbol open *> inner <* symbol close

-- | Parse @p@ enclosed in round brackets @(@ ... @)@.
parens :: Parser a -> Parser a
parens p = container "(" ")" p

-- | An atomic integer expression: an integer literal or a parenthesised
-- integer expression. Infix operators are layered on top of this by
-- 'intExpr'.
intExprTerm :: Parser M.Int
intExprTerm =
  choice
    [ M.Int <$> int,
      parens intExpr
    ]

-- | Operator table for integer expressions, listed from highest to lowest
-- precedence as 'makeExprParser' expects: @*@ and @/@ bind tighter than
-- @+@ and @-@. Every operator is left-associative (built by 'binaryOp').
intExprTable :: [[Operator Parser M.Int]]
intExprTable = [multiplicative, additive]
  where
    multiplicative = [arith "*" M.Mul, arith "/" M.Div]
    additive = [arith "+" M.Add, arith "-" M.Sub]
    arith sym op = binaryOp sym (M.IntArith op)

-- | A full integer expression: atoms from 'intExprTerm' combined with the
-- infix operators in 'intExprTable'.
intExpr :: Parser M.Int
intExpr = makeExprParser atom operators
  where
    atom = intExprTerm
    operators = intExprTable

-- | Parse a comparison between two integer expressions, e.g. @1 + 2 <= 3@.
-- Returns @(operator, left operand, right operand)@.
--
-- Two-character operators must be tried before their one-character
-- prefixes. A failed 'symbol' match consumes nothing, but a SUCCESSFUL
-- match of @>@ on the input @>=@ would commit to 'M.GT' and leave the @=@
-- for the right-hand 'intExpr', which then fails.
intOrdCmpExpr :: Parser (M.OrdCmpOp, M.Int, M.Int)
intOrdCmpExpr = do
  lhs <- intExpr
  op <-
    choice
      [ M.GTE <$ symbol ">=",
        M.GT <$ symbol ">",
        -- M.Eq <$ string "==",
        -- M.Neq <$ string "!=",
        M.LTE <$ symbol "<=",
        M.LT <$ symbol "<"
      ]
  rhs <- intExpr
  pure (op, lhs, rhs)

-- | Apply a curried three-argument function to the components of a triple.
-- This is the three-argument analogue of 'uncurry'.
uncurry3 :: (a -> b -> c -> d) -> (a, b, c) -> d
uncurry3 f triple =
  case triple of
    (first, second, third) -> f first second third

-- | An atomic boolean expression: an integer comparison, one of the
-- literals @true@ / @false@, or a parenthesised boolean term.
--
-- The comparison alternative is wrapped in 'try' because it and the
-- parenthesised alternative can both start with @(@. On input such as
-- @(1 < 2)@ the comparison parser consumes @(1@ (via 'parens' in
-- 'intExprTerm') before failing at @<@. Without 'try', that consumed input
-- would stop 'choice' from falling through to @parens boolExprTerm@.
--
-- The literals use 'symbol' so that, like every other token, they skip
-- trailing whitespace (e.g. @printBool(true );@).
boolExprTerm :: Parser M.Bool
boolExprTerm =
  choice
    [ try (uncurry3 M.IntOrdCmp <$> intOrdCmpExpr),
      M.Bool True <$ symbol "true",
      M.Bool False <$ symbol "false",
      parens boolExprTerm
    ]

-- boolExprTable :: [[Operator Parser M.Bool]]
-- boolExprTable =
--   [ [ binaryOp "<" (M.IntOrdCmp M.LT),
--       binaryOp "<=" (M.IntOrdCmp M.LTE),
--       binaryOp ">" (M.IntOrdCmp M.GT),
--       binaryOp ">=" (M.IntOrdCmp M.GTE)
--     ],
--     [ binaryOp "==" (M.IntOrdCmp M.Eq),
--       binaryOp "!=" (M.IntOrdCmp M.Neq)
--     ]
--   ]

-- boolExpr :: Parser M.Bool
-- boolExpr = makeExprParser boolExprTerm boolExprTable

-- | Build a left-associative infix operator for 'makeExprParser'. The
-- operator is recognised by the exact text @name@ (trailing whitespace is
-- skipped via 'symbol') and combines its two operands with @f@.
binaryOp :: Text -> (a -> a -> a) -> Operator Parser a
binaryOp name f = InfixL (f <$ symbol name)

-- | One statement terminated by @;@: either @printInt(<intExpr>)@ or
-- @printBool(<boolExprTerm>)@.
--
-- Keywords are parsed with 'symbol' rather than 'string', so whitespace
-- between a keyword and its opening parenthesis (e.g. @printInt (1);@) is
-- accepted. This matches how every other token in the grammar is lexed.
-- A failed keyword match consumes no input, so @printInt@ being tried
-- first does not prevent @printBool@ from matching.
statement :: Parser M.Statement
statement = body <* symbol ";"
  where
    body =
      choice
        [ M.PrintInt <$> (symbol "printInt" *> parens intExpr),
          M.PrintBool <$> (symbol "printBool" *> parens boolExprTerm)
        ]

-- | Parse a complete program: optional leading whitespace, zero or more
-- statements, then end of input. Returns the megaparsec error bundle on
-- failure. The source name handed to megaparsec is empty.
parseStatements :: Text -> Either (ParseErrorBundle Text Void) [M.Statement]
parseStatements input = MP.parse program "" input
  where
    program = C.space *> many statement <* eof

-- | Parse a whole program into its statements.
--
-- This function is partial: on malformed input it calls 'error' with
-- megaparsec's pretty-printed diagnostic (position, unexpected and
-- expected tokens). Previously it crashed with an uninformative
-- non-exhaustive-pattern failure.
parse :: Text -> [M.Statement]
parse t =
  case parseStatements t of
    Left bundle -> error (errorBundlePretty bundle)
    Right statements -> statements

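-- TODO: add proper error handling. 'parse' is partial on malformed input;
-- consider exposing the 'Either' result of 'parseStatements' to callers.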