From 9448e2ac192bb54d861323783b176c2a3d6934ea Mon Sep 17 00:00:00 2001 From: programsnail Date: Sun, 21 Apr 2024 14:09:38 +0300 Subject: [PATCH] typecheck (without modes), unique mode --- CMakeLists.txt | 1 + README.md | 11 +++ include/mode_check.hpp | 77 ++++++++++++++++ include/parsing_tree.hpp | 105 ++++++++++++---------- include/type_check.hpp | 74 ++++++++++++++++ include/types.hpp | 165 ++++++++++++++++++++++++++++++++++ include/utils.hpp | 39 +++++++++ src/main.cpp | 56 +++++++++++- src/mode_check.cpp | 148 ++++--------------------------- src/type_check.cpp | 185 +++++++++++++++++++-------------------- src/types.cpp | 13 +++ tests/tests.cpp | 7 +- 12 files changed, 602 insertions(+), 279 deletions(-) create mode 100644 include/types.hpp create mode 100644 include/utils.hpp create mode 100644 src/types.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index e10ebd3..ad958c4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,5 +17,6 @@ include_directories( add_executable(lang src/main.cpp src/parsing_tree.cpp + src/types.cpp src/type_check.cpp src/mode_check.cpp) diff --git a/README.md b/README.md index 82058d7..9b5a0e1 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ ## Info https://blog.janestreet.com/oxidizing-ocaml-locality/ + https://blog.janestreet.com/oxidizing-ocaml-ownership/ - **locality:** *global* (default) or *local* (value should not be passed out of context, can be allocated on stach) @@ -23,3 +24,13 @@ https://blog.janestreet.com/oxidizing-ocaml-ownership/ ## Examples - *unique:* let f (unique x) = x * x in f;; -> error + +--- + +**bad design decisions:** + +- shared_ptr instead of unique_ptr +- using namespace std +- use of indicies instead of visitor for std::variant + +going to fix later (?) diff --git a/include/mode_check.hpp b/include/mode_check.hpp index 3f59c93..be1db18 100644 --- a/include/mode_check.hpp +++ b/include/mode_check.hpp @@ -1,2 +1,79 @@ #pragma once +#include "parsing_tree.hpp" + +#include +#include + +namespace mode_check { + +using namespace types; + +struct VarState { + VarState(Mode mode, size_t count = 0) : mode(mode), count(count) {} + + Mode mode; + size_t count = 0; +}; + +struct State { + friend struct Context; + + State() { vars_stack.emplace_back(); } + + std::optional get_var_state(const std::string &name, + bool last_context_only = false) { + for (auto vars_it = vars_stack.rbegin(); vars_it != vars_stack.rend(); + ++vars_it) { + auto var_it = vars_it->find(name); + if (var_it == vars_it->end()) { + if (last_context_only) { + break; + } + + continue; + } + return &var_it->second; + } + + utils::throw_error("NO_VAR"); + return std::nullopt; + } + + void add_var(std::string name, Mode mode = Mode()) { + vars_stack.back().insert({std::move(name), VarState{mode}}); + // TODO: check existance + } + +private: + void enter_context() { vars_stack.emplace_back(); } + + void exit_context() { vars_stack.pop_back(); } + +private: + vector> vars_stack; +}; + +struct Context { + Context(State &state) : state_(state) { state_.enter_context(); } + + ~Context() { state_.exit_context(); } + +private: + State &state_; +}; + +// struct ExclVarScope { +// ExclVarScope(std::string name, State &state) +// : name_(std::move(name)), state_(state) {} + +// ~ExclVarScope() {} + +// private: +// std::string name_; +// State &state_; +// }; + +void check_expr(nodes::ExprPtr expr, State &state); + +} // mode_check diff --git a/include/parsing_tree.hpp b/include/parsing_tree.hpp index 78a5cd6..a8b1268 100644 --- a/include/parsing_tree.hpp +++ b/include/parsing_tree.hpp @@ -6,81 +6,68 @@ #include #include -namespace types { - -using namespace std; - -struct Type; -using TypePtr = shared_ptr; - -struct ArrowType { - vector types; -}; - -struct BoolType {}; -struct IntType {}; -// struct UnitType {}; - -struct AnyType {}; - -struct Type { - static constexpr size_t ARROW_TYPE_INDEX = 0; - variant type; - - enum class Loc { GLOBAL, LOCAL } loc = Loc::GLOBAL; - enum class Uniq { SHARED, UNIQUE, EXCL } uniq = Uniq::SHARED; - enum class Lin { MANY, ONCE, SEP } lin = Lin::MANY; -}; - -template -Type make_type(Args&&... args) { - return Type{T{std::forward(args)...}}; -} - -} // namespace types +#include "utils.hpp" +#include "types.hpp" namespace nodes { using namespace std; -struct Node { - optional type = std::nullopt; +struct NodeInfo { + optional type = std::nullopt; + optional mode = std::nullopt; }; struct Expr; using ExprPtr = shared_ptr; using ExprPtrV = std::vector; -struct Arg : public Node { +struct Arg : public NodeInfo { + Arg(string name) : name(std::move(name)) {} + string name; }; -struct Const : public Node { +struct Const : public NodeInfo { + Const(int value) : value(value) {} + int value; }; -struct Var : public Node { +struct Var : public NodeInfo { + Var(string name) : name(std::move(name)) {} + string name; }; -struct Let : public Node { +struct Let : public NodeInfo { + Let(Arg name, ExprPtr body, ExprPtr where) + : name(std::move(name)), body(body), where(where) {} + Arg name; ExprPtr body; ExprPtr where; }; -struct Lambda : public Node { +struct Lambda : public NodeInfo { + Lambda(vector args, ExprPtr expr) : args(std::move(args)), expr(expr) {} vector args; ExprPtr expr; }; -struct Call : public Node { +struct Call : public NodeInfo { + Call(ExprPtr func, vector args) + : func(func), args(std::move(args)) {} + ExprPtr func; vector args; }; -struct Condition : public Node { +struct Condition : public NodeInfo { + Condition(ExprPtr condition, ExprPtr then_case, ExprPtr else_case) + : condition(condition), then_case(then_case), else_case(else_case) {} + ExprPtr condition; ExprPtr then_case; ExprPtr else_case; @@ -96,17 +83,39 @@ struct Expr { variant value; }; -template -ExprPtr make_expr(Args&&... args) { - return std::make_shared(T{std::forward(args)...}); +template ExprPtr make_expr(Args &&...args) { + return std::make_shared(T(std::forward(args)...)); } -static ExprPtr lambda1(string name, ExprPtr expr) { - return make_expr(vector{{name}}, std::move(expr)); +template inline T with_type(T node, types::Type type) { + node.type = std::move(type); + return node; } -ExprPtr operator_call(string name, ExprPtr left, ExprPtr right) { - return make_expr(make_expr(name), ExprPtrV{left, right}); +template inline T with_mode(T node, types::Mode mode) { + node.mode = std::move(mode); + return node; +} + +template inline T with_unique(T node) { + return with_mode(node, types::Mode(types::Mode::Uniq::UNIQUE)); +} + +inline ExprPtr make_var(std::string name, types::Mode mode = types::Mode()) { + return make_expr(with_mode(Var(std::move(name)), mode)); +} + +inline ExprPtr lambda1(string var, ExprPtr expr) { + return make_expr(vector{Arg(var)}, std::move(expr)); +} + +inline ExprPtr lambda1(Arg var, ExprPtr expr) { + return make_expr(vector{var}, std::move(expr)); +} + +inline ExprPtr operator_call(string name, ExprPtr left, ExprPtr right, + types::Mode mode = types::Mode()) { + return make_expr(make_var(name, mode), ExprPtrV{left, right}); } // TODO: all constructors diff --git a/include/type_check.hpp b/include/type_check.hpp index e69de29..97c1222 100644 --- a/include/type_check.hpp +++ b/include/type_check.hpp @@ -0,0 +1,74 @@ +#pragma once + +#include "parsing_tree.hpp" + +#include +#include + +namespace type_check { + +using namespace types; + +struct VarManager { + friend struct Context; + + VarManager() { vars_stack.emplace_back(); } + + optional get_var_type(const std::string &name, + bool last_context_only = false) { + for (auto vars_it = vars_stack.rbegin(); vars_it != vars_stack.rend(); + ++vars_it) { + auto var_it = vars_it->find(name); + if (var_it == vars_it->end()) { + if (last_context_only) { + break; + } + + continue; + } + return var_it->second; + } + + utils::throw_error("NO_VAR"); + return std::nullopt; + } + + void add_var(std::string name, TypeID type) { + vars_stack.back().insert({std::move(name), type}); + // TODO: check existance + } + +private: + void enter_context() { vars_stack.emplace_back(); } + + void exit_context() { vars_stack.pop_back(); } + +private: + vector> vars_stack; +}; + +struct Context { + Context(VarManager &manager) : manager_(manager) { manager_.enter_context(); } + + ~Context() { manager_.exit_context(); } + +private: + VarManager &manager_; +}; + +// --------------- + +struct State { + types::Storage type_storage; + VarManager manager; +}; + +// struct GenericVarContext { +// GenericVarContext() { /*introduce generic*/ } +// ~GenericVarContext() { /*resolve generic (two ways: as let, or as func +// arg)*/ } +// }; + +types::TypeID check_expr(nodes::ExprPtr expr, State &state); + +} // namespace type_check diff --git a/include/types.hpp b/include/types.hpp new file mode 100644 index 0000000..929683a --- /dev/null +++ b/include/types.hpp @@ -0,0 +1,165 @@ +#pragma once + +#include +#include +#include + +namespace types { + +using namespace std; + +struct Mode { + enum class Loc { GLOBAL, LOCAL } loc = Loc::GLOBAL; + enum class Uniq { SHARED, UNIQUE, EXCL } uniq = Uniq::SHARED; + enum class Lin { MANY, ONCE, SEP } lin = Lin::MANY; + + Mode with(Loc mode) const { + Mode copy = *this; + copy.loc = mode; + return copy; + } + Mode with(Uniq mode) const { + Mode copy = *this; + copy.uniq = mode; + return copy; + } + Mode with(Lin mode) const { + Mode copy = *this; + copy.lin = mode; + return copy; + } + + Mode() = default; + Mode(Loc mode) : loc(mode) {} + Mode(Uniq mode) : uniq(mode) {} + Mode(Lin mode) : lin(mode) {} +}; +using ModePtr = shared_ptr; + +struct Storage; +struct Type; +struct TypeID { + TypeID(size_t id, Storage *storage) : id(id), storage(storage) {} + + const Type &get() const; + Type &get(); + +private: + size_t id; + Storage *storage = nullptr; +}; +using TypeIDV = vector; + +struct ArrowType { + vector types; +}; + +struct BoolType {}; +struct IntType {}; +// struct UnitType {}; + +struct GenericType { + size_t id; +}; + +struct Type { + variant type; +}; + +template Type make_type(Args &&...args) { + return Type{T{std::forward(args)...}}; +} + +template Type make_func1(TypeID in, TypeID ret) { + return make_type(TypeIDV{in, ret}); +} + +inline Type make_operator(TypeID left, TypeID right, TypeID ret) { + return make_type(TypeIDV{left, right, ret}); +} + +struct Storage { + Storage() + : int_type(add(make_type())), + bool_type(add(make_type())) {} + + TypeID get_int_type() { return int_type; } + TypeID get_bool_type() { return bool_type; } + + Type &get_type(size_t id) { return types[id]; } + + const Type &get_type(size_t id) const { return types[id]; } + + TypeID introduce_new_generic() { + return add(make_type(first_unused_generic_id++)); + } + + TypeID add(Type type) { + types.push_back(std::move(type)); + return TypeID(types.size() - 1, this); + } + + // TODO: add modes ?? + bool unify(TypeID left_id, TypeID right_id) { + Type &left = left_id.get(); + Type &right = right_id.get(); + + if (const auto *left_generic = get_if(&left.type); + left_generic != nullptr) { + // TODO: check if other type contins generic + resolve(*left_generic, right); + return true; + } + + if (const auto *right_generic = get_if(&right.type); + right_generic != nullptr) { + // TODO: check if other type contins generic + resolve(*right_generic, left); + return true; + } + + if (left.type.index() != right.type.index()) { + return false; + } + + if (holds_alternative(left.type)) { + const auto &left_types = std::get(left.type).types; + const auto &right_types = std::get(right.type).types; + + if (left_types.size() != right_types.size()) { + return false; + } + + bool all_unify_passed = true; + for (size_t i = 0; i < left_types.size(); ++i) { + if (not unify(left_types[i], right_types[i])) { + all_unify_passed = false; + } + } + + return all_unify_passed; + } + + return true; + } + + void resolve(GenericType generic, const Type &replacement) { + for (auto &type : types) { + if (const auto *generic_type = get_if(&type.type); + generic_type != nullptr and generic_type->id == generic.id) { + type = replacement; + } + } + } + +private: + size_t first_unused_generic_id = 0; + + vector types; + + TypeID int_type; + TypeID bool_type; +}; + +} // namespace types + diff --git a/include/utils.hpp b/include/utils.hpp new file mode 100644 index 0000000..0509eba --- /dev/null +++ b/include/utils.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include +#include + +namespace utils { + +using namespace std; + +// C++ 23 +[[noreturn]] inline void unreachable() { + // Uses compiler specific extensions if possible. + // Even if no extension is used, undefined behavior is still raised by + // an empty function body and the noreturn attribute. +#if defined(_MSC_VER) && !defined(__clang__) // MSVC + __assume(false); +#else // GCC, Clang + __builtin_unreachable(); +#endif +} + +// ----------------- + +// visitor helper +template struct overloaded : Ts... { + using Ts::operator()...; +}; + +struct Error { + string message; + source_location location; +}; + +inline void throw_error(string message, + source_location location = source_location::current()) { + throw Error{std::move(message), location}; +} + +} // namespace utils diff --git a/src/main.cpp b/src/main.cpp index 8b13b77..b41823b 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,6 +1,58 @@ +#include "mode_check.hpp" #include "parsing_tree.hpp" -#include "typechecker.hpp" +#include "type_check.hpp" + +#include + +auto make_program() { + using namespace nodes; + return make_expr(Arg("f"), + lambda1(with_unique(Arg("x")), + operator_call("+", make_var("x", types::Mode::Uniq::UNIQUE), + make_var("x", types::Mode::Uniq::UNIQUE))), + make_var("f")); +} + +void add_builtin_functions_types(type_check::State &state) { + auto sum_type = state.type_storage.add(types::make_operator( + state.type_storage.get_int_type(), state.type_storage.get_int_type(), + state.type_storage.get_int_type())); + state.manager.add_var("+", sum_type); +} + +void add_builtin_functions_modes(mode_check::State &state) { + state.add_var("+"); +} + +void print_error(const std::string &general_message, + const utils::Error &error) { + std::cerr << general_message << " " + << "file: " << error.location.file_name() << "(" + << error.location.line() << ":" << error.location.column() << ") `" + << error.location.function_name() << "`: " << error.message; +} int main() { - return 0; + const auto program = make_program(); + + try { + type_check::State state; + + add_builtin_functions_types(state); + + type_check::check_expr(program, state); + + } catch (utils::Error error) { + print_error("TYPE CHECK ERROR:", error); + } + + try { + mode_check::State state; + + add_builtin_functions_modes(state); + + mode_check::check_expr(program, state); + } catch (utils::Error error) { + print_error("MODE CHECK ERROR:", error); + } } diff --git a/src/mode_check.cpp b/src/mode_check.cpp index 4eedeff..d9062bd 100644 --- a/src/mode_check.cpp +++ b/src/mode_check.cpp @@ -1,140 +1,24 @@ #include "mode_check.hpp" -#include "parsing_tree.hpp" - -#include -#include - namespace mode_check { -// C++ 23 -[[noreturn]] inline void unreachable() { - // Uses compiler specific extensions if possible. - // Even if no extension is used, undefined behavior is still raised by - // an empty function body and the noreturn attribute. -#if defined(_MSC_VER) && !defined(__clang__) // MSVC - __assume(false); -#else // GCC, Clang - __builtin_unreachable(); -#endif -} - -using namespace types; - -struct VarState { - Type type; - size_t count = 0; -}; - -struct ModeError { - enum Error { - UNKNOWN, - NO_VAR, - NO_VAR_TYPE, - NO_TYPE, - WRONG_TYPE, - LOCAL, - UNIQUE, - EXCL, - ONCE, - SEP - } error = UNKNOWN; - - source_location location; - - ModeError(Error error, source_location location) - : type(type), location(location) {} -}; - -struct State { - friend struct Context; - - State() { vars_stack.emplace_back(); } - - void set_error(ModeError::Error error, - source_location location = source_location::current()) { - if (first_error.has_value()) { - return; - } - - first_error = ModeError(std::move(error), location); - } - - std::optional get_var_state(const std::string &name, - bool last_context_only = false) { - for (auto vars_it = vars_stack.rbegin(); vars_it != vars_stack.rend(); - ++vars_it) { - auto var_it = vars_it->find(name); - if (var_it == vars_it->end()) { - if (last_context_only) { - break; - } - - continue; - } - return &var_it->second; - } - - set_error(ModeError::NO_VAR); - return std::nullopt; - } - - void add_var(std::string name, Type type) { - vars_stack.back().insert({std::move(name), VarState{std::move(type)}}); - // TODO: check existance - } - - std::optional get_first_error() { return first_error; } - -private: - void enter_context() { vars_stack.emplace_back(); } - - void exit_context() { vars_stack.pop_back(); } - -private: - vector> vars_stack; - std::optional first_error; -}; - -struct Context { - Context(State &state) : state_(state) { state_.enter_context(); } - - ~Context() { state_.exit_context(); } - -private: - State &state_; -}; - -// struct ExclVarScope { -// ExclVarScope(std::string name, State &state) -// : name_(std::move(name)), state_(state) {} - -// ~ExclVarScope() {} - -// private: -// std::string name_; -// State &state_; -// }; - -void check_expr(nodes::ExprPtr expr, State &state); - void check_const(const nodes::Const &, State &) {} void check_var(const nodes::Var &expr, State &state) { - if (not expr.type.has_value()) { - state.set_error(ModeError::NO_TYPE); + if (not expr.mode.has_value()) { + utils::throw_error("NO_MODE for " + expr.name); return; } - auto type = expr.type.value(); + auto mode = expr.mode.value(); if (auto maybe_var_state = state.get_var_state(expr.name); maybe_var_state.has_value()) { auto &var_state = *maybe_var_state.value(); - if (var_state.type.uniq == Type::Uniq::UNIQUE) { + if (var_state.mode.uniq == Mode::Uniq::UNIQUE) { ++var_state.count; - if (var_state.count > 1 || type.uniq != Type::Uniq::UNIQUE) { - state.set_error(ModeError::UNIQUE); + if (var_state.count > 1 || mode.uniq != Mode::Uniq::UNIQUE) { + utils::throw_error("UNIQUE for " + expr.name); return; } } @@ -151,10 +35,10 @@ void check_let(const nodes::Let &expr, State &state) { { Context context(state); - if (not expr.name.type.has_value()) { - state.set_error(ModeError::NO_VAR_TYPE); + if (not expr.name.mode.has_value()) { + utils::throw_error("NO_VAR_MODE"); } - state.add_var(expr.name.name, expr.name.type.value()); + state.add_var(expr.name.name, expr.name.mode.value()); check_expr(expr.where, state); } @@ -164,11 +48,11 @@ void check_lambda(const nodes::Lambda &expr, State &state) { Context context(state); for (const auto &arg : expr.args) { - if (not arg.type.has_value()) { - state.set_error(ModeError::NO_VAR_TYPE); + if (not arg.mode.has_value()) { + utils::throw_error("NO_VAR_MODE"); continue; } - state.add_var(arg.name, arg.type.value()); + state.add_var(arg.name, arg.mode.value()); } check_expr(expr.expr, state); @@ -176,20 +60,20 @@ void check_lambda(const nodes::Lambda &expr, State &state) { void check_call(const nodes::Call &expr, State &state) { // if (not expr.type.has_value()) { - // state.set_error(ModeError::NO_TYPE); + // utils::throw_error("NO_TYPE"); // return; // } // auto type = expr.type.value(); // if (not holds_alternative(type.type)) { - // state.set_error(ModeError::WRONG_TYPE); + // utils::throw_error("WRONG_TYPE"); // return; // } // const auto &arrow_type = get(type.type); // if (arrow_type.types.size() != expr.args.size() + 1) { - // state.set_error(ModeError::WRONG_TYPE); + // utils::throw_error("WRONG_TYPE"); // return; // } @@ -229,7 +113,7 @@ void check_expr(nodes::ExprPtr expr, State &state) { check_condition(std::get<5>(expr->value), state); break; default: - unreachable(); + utils::unreachable(); } } diff --git a/src/type_check.cpp b/src/type_check.cpp index 76b2464..e1ddf74 100644 --- a/src/type_check.cpp +++ b/src/type_check.cpp @@ -1,112 +1,107 @@ #include "type_check.hpp" -#include "parsing_tree.hpp" - -// TODO - namespace type_check { -using namespace types; - -// C++ 23 -[[noreturn]] inline void unreachable() { - // Uses compiler specific extensions if possible. - // Even if no extension is used, undefined behavior is still raised by - // an empty function body and the noreturn attribute. -#if defined(_MSC_VER) && !defined(__clang__) // MSVC - __assume(false); -#else // GCC, Clang - __builtin_unreachable(); -#endif +types::TypeID check_const(nodes::Const &expr, State &state) { + return (expr.type = state.type_storage.get_int_type()).value(); } -// bool eq(Type x, Type y) { -// if (x.type.index() != y.type.index()) { -// return false; -// } - -// if (x.type.index() != Type::ARROW_TYPE_INDEX) { -// return true; -// } - -// const auto &x_types = std::get(x.type); -// const auto &y_types = std::get(y.type); - -// if (x_types.size() != ) -// } - -// ----------------- - -template -using Typechecked = std::pair; - -Typechecked typecheck_expr(nodes::ExprPtr expr); - -Typechecked typecheck_const(nodes::Const expr) { - // TODO - return {std::move(expr), make_type()}; +types::TypeID check_var(nodes::Var &expr, State &state) { + if (auto maybe_var_type = state.manager.get_var_type(expr.name); + maybe_var_type.has_value()) { + return (expr.type = maybe_var_type).value(); + } + utils::unreachable(); } -Typechecked typecheck_var(nodes::Var expr) { - // TODO - return {std::move(expr), make_type()}; -} +types::TypeID check_let(nodes::Let &expr, State &state) { + Context context(state.manager); -Typechecked typecheck_let(nodes::Let expr) { - // TODO - return {std::move(expr), make_type()}; -} + types::TypeID new_type = state.type_storage.introduce_new_generic(); -Typechecked typecheck_lambda(nodes::Lambda expr) { - // TODO - return {std::move(expr), make_type()}; -} + state.manager.add_var(expr.name.name, new_type); -Typechecked typecheck_call(nodes::Call expr) { - // TODO - return {std::move(expr), make_type()}; -} + types::TypeID body_type = check_expr(expr.body, state); -Typechecked typecheck_condition(nodes::Condition expr) { - // const auto [condition_expr, condition_type] = typecheck_expr(expr.condition); - // expr.condition = std::move(condition_expr); - - // const auto [then_case_expr, then_case_type] = typecheck_expr(expr.then_case); - // expr.then_case = std::move(then_case_expr); - - // const auto [else_case_expr, else_type_type] = typecheck_expr(expr.else_case); - // expr.else_case = std::move(else_case_expr); - - return {std::move(expr), make_type()}; -} - -Typechecked typecheck_expr(nodes::ExprPtr expr) { - types::Type type; - - switch (expr->value.index()) { - case 0: // Const - std::tie(expr->value, type) = typecheck_const(std::get<0>(expr->value)); - break; - case 1: // Var - std::tie(expr->value, type) = typecheck_var(std::get<1>(expr->value)); - break; - case 2: // Let - std::tie(expr->value, type) = typecheck_let(std::get<2>(expr->value)); - break; - case 3: // Lambda - std::tie(expr->value, type) = typecheck_lambda(std::get<3>(expr->value)); - break; - case 4: // Call - std::tie(expr->value, type) = typecheck_call(std::get<4>(expr->value)); - break; - case 5: // Condition - std::tie(expr->value, type) = typecheck_condition(std::get<5>(expr->value)); - break; - default: - unreachable(); + if (not state.type_storage.unify(new_type, body_type)) { + utils::throw_error("DIFFERENT_TYPES"); } - return {std::move(expr), std::move(type)}; + types::TypeID where_type = check_expr(expr.where, state); + return (expr.type = where_type).value(); +} + +types::TypeID check_lambda(nodes::Lambda &expr, State &state) { + Context context(state.manager); + + for (const auto &arg : expr.args) { + types::TypeID new_type = state.type_storage.introduce_new_generic(); + state.manager.add_var(arg.name, new_type); + } + + types::TypeID lambda_type = check_expr(expr.expr, state); + return (expr.type = lambda_type).value(); +} + +types::TypeID check_call(nodes::Call &expr, State &state) { + types::TypeID func_type = check_expr(expr.func, state); + + if (auto *arrow_func_type = get_if(&func_type.get().type); + arrow_func_type != nullptr) { + + if (arrow_func_type->types.size() != expr.args.size() + 1) { + utils::throw_error("ARG_COUNT_MISMATCH"); + } + + for (size_t i = 0; i < expr.args.size(); ++i) { + types::TypeID arg_type = check_expr(expr.args[i], state); + if (not state.type_storage.unify(arrow_func_type->types[i], arg_type)) { + utils::throw_error("DIFFERENT_TYPES"); + } + } + + return (expr.type = arrow_func_type->types.back()).value(); + } + + utils::throw_error("FUNC_IS_NOT_ARROW_TYPE"); + utils::unreachable(); +} + +types::TypeID check_condition(nodes::Condition &expr, State &state) { + types::TypeID condition_type = check_expr(expr.condition, state); + + if (not state.type_storage.unify(condition_type, + state.type_storage.get_bool_type())) { + utils::throw_error("DIFFERENT_TYPES"); + } + + types::TypeID then_type = check_expr(expr.then_case, state); + types::TypeID else_type = check_expr(expr.else_case, state); + + if (not state.type_storage.unify(then_type, else_type)) { + utils::throw_error("DIFFERENT_TYPES"); + } + + return (expr.type = then_type).value(); +} + +types::TypeID check_expr(nodes::ExprPtr expr, State &state) { + switch (expr->value.index()) { + case 0: // Const + return check_const(std::get<0>(expr->value), state); + case 1: // Var + return check_var(std::get<1>(expr->value), state); + case 2: // Let + return check_let(std::get<2>(expr->value), state); + case 3: // Lambda + return check_lambda(std::get<3>(expr->value), state); + case 4: // Call + return check_call(std::get<4>(expr->value), state); + case 5: // Condition + return check_condition(std::get<5>(expr->value), state); + default: + utils::unreachable(); + } } } // namespace type_check diff --git a/src/types.cpp b/src/types.cpp new file mode 100644 index 0000000..2cc8c7b --- /dev/null +++ b/src/types.cpp @@ -0,0 +1,13 @@ +#include "types.hpp" + +namespace types { + +const Type& TypeID::get() const { + return storage->get_type(id); +} + +Type& TypeID::get() { + return storage->get_type(id); +} + +} // namespace types diff --git a/tests/tests.cpp b/tests/tests.cpp index 048667e..b4f69eb 100644 --- a/tests/tests.cpp +++ b/tests/tests.cpp @@ -1,11 +1,14 @@ #include "parsing_tree.hpp" +#include + using namespace nodes; int main() { const auto program = - Expr{Let{Arg{"f", {}}, + Expr(Let(Arg("f"), lambda1("x", operator_call("+", make_expr("x"), make_expr("x"))), - make_expr("f")}}; + make_expr("f"))); + }