diff --git a/README.md b/README.md index 9b5a0e1..cccc06f 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ https://blog.janestreet.com/oxidizing-ocaml-ownership/ - *local* check - *once* check - *separated* check -- type check (optional) +- type check + ## Examples diff --git a/include/parsing_tree.hpp b/include/parsing_tree.hpp index a8b1268..0abb6fb 100644 --- a/include/parsing_tree.hpp +++ b/include/parsing_tree.hpp @@ -15,7 +15,6 @@ using namespace std; struct NodeInfo { optional type = std::nullopt; - optional mode = std::nullopt; }; struct Expr; @@ -23,9 +22,11 @@ using ExprPtr = shared_ptr; using ExprPtrV = std::vector; struct Arg : public NodeInfo { - Arg(string name) : name(std::move(name)) {} + Arg(string name, types::Mode mode_hint = {}) : name(std::move(name)), mode_hint(mode_hint) {} string name; + + types::Mode mode_hint; }; struct Const : public NodeInfo { @@ -92,17 +93,13 @@ template inline T with_type(T node, types::Type type) { return node; } -template inline T with_mode(T node, types::Mode mode) { - node.mode = std::move(mode); - return node; +inline Arg with_mode_hint(Arg arg, types::Mode mode) { + arg.mode_hint = mode; + return arg; } -template inline T with_unique(T node) { - return with_mode(node, types::Mode(types::Mode::Uniq::UNIQUE)); -} - -inline ExprPtr make_var(std::string name, types::Mode mode = types::Mode()) { - return make_expr(with_mode(Var(std::move(name)), mode)); +inline Arg with_unique_hint(Arg arg) { + return with_mode_hint(std::move(arg), types::Mode(types::Mode::Uniq::UNIQUE)); } inline ExprPtr lambda1(string var, ExprPtr expr) { @@ -113,9 +110,8 @@ inline ExprPtr lambda1(Arg var, ExprPtr expr) { return make_expr(vector{var}, std::move(expr)); } -inline ExprPtr operator_call(string name, ExprPtr left, ExprPtr right, - types::Mode mode = types::Mode()) { - return make_expr(make_var(name, mode), ExprPtrV{left, right}); +inline ExprPtr operator_call(string name, ExprPtr left, ExprPtr right) { + return make_expr(make_expr(name), ExprPtrV{left, right}); } // TODO: all constructors diff --git a/include/type_check.hpp b/include/type_check.hpp index 97c1222..3161e34 100644 --- a/include/type_check.hpp +++ b/include/type_check.hpp @@ -29,7 +29,7 @@ struct VarManager { return var_it->second; } - utils::throw_error("NO_VAR"); + utils::throw_error("NO_VAR for " + name); return std::nullopt; } diff --git a/include/types.hpp b/include/types.hpp index 929683a..a6807e3 100644 --- a/include/types.hpp +++ b/include/types.hpp @@ -1,17 +1,20 @@ #pragma once +#include #include -#include #include +#include + +#include namespace types { using namespace std; struct Mode { - enum class Loc { GLOBAL, LOCAL } loc = Loc::GLOBAL; - enum class Uniq { SHARED, UNIQUE, EXCL } uniq = Uniq::SHARED; - enum class Lin { MANY, ONCE, SEP } lin = Lin::MANY; + enum class Loc { LOCAL = 0, GLOBAL = 1 } loc = Loc::GLOBAL; + enum class Uniq { UNIQUE = 0, EXCL = 1, SHARED = 2 } uniq = Uniq::SHARED; + enum class Lin { ONCE = 0, SEP = 1, MANY = 2 } lin = Lin::MANY; Mode with(Loc mode) const { Mode copy = *this; @@ -33,6 +36,25 @@ struct Mode { Mode(Loc mode) : loc(mode) {} Mode(Uniq mode) : uniq(mode) {} Mode(Lin mode) : lin(mode) {} + + auto operator<=>(const Mode &other) const { + return tie(loc, uniq, lin) <=> tie(other.loc, other.uniq, other.lin); + } + + bool is_submode(const Mode &other) const { + return loc <= other.loc and uniq <= other.uniq and lin <= other.lin; + } + + static Mode choose_min(const Mode &left, const Mode &right) { + Mode ans; + ans.loc = static_cast(std::min(static_cast(left.loc), + static_cast(right.loc))); + ans.uniq = static_cast(std::min(static_cast(left.uniq), + static_cast(right.uniq))); + ans.lin = static_cast(std::min(static_cast(left.lin), + static_cast(right.lin))); + return ans; + } }; using ModePtr = shared_ptr; @@ -44,6 +66,8 @@ struct TypeID { const Type &get() const; Type &get(); + TypeID with_mode(Mode new_mode) const; + private: size_t id; Storage *storage = nullptr; @@ -60,38 +84,75 @@ struct IntType {}; struct GenericType { size_t id; + std::string name; }; struct Type { + template Type(T type, Mode mode = {}) : type(type), mode(mode) {} + + Type with_mode(Mode new_mode) const { + Type copy = *this; + copy.mode = new_mode; + return copy; + } + variant type; + Mode mode; }; template Type make_type(Args &&...args) { return Type{T{std::forward(args)...}}; } -template Type make_func1(TypeID in, TypeID ret) { - return make_type(TypeIDV{in, ret}); +template +Type make_moded_type(Mode mode, Args &&...args) { + return make_type(std::forward(args)...).with_mode(mode); } -inline Type make_operator(TypeID left, TypeID right, TypeID ret) { - return make_type(TypeIDV{left, right, ret}); +template +Type make_func1(TypeID in, TypeID ret, Mode mode = {}) { + return make_moded_type(mode, TypeIDV{in, ret}); } +inline Type make_operator(TypeID left, TypeID right, TypeID ret, + Mode mode = {}) { + return make_moded_type(mode, TypeIDV{left, right, ret}); +} + +enum class UnifyModePolicy { + Ignore, // all mode differences ignored + ApplyStrongest, // unique > shared, modes changed + CheckLeftIsSubmode, // only check is performed +}; + struct Storage { - Storage() - : int_type(add(make_type())), - bool_type(add(make_type())) {} + Storage() {} - TypeID get_int_type() { return int_type; } - TypeID get_bool_type() { return bool_type; } + TypeID get_int_type(Mode mode = {}) { + auto it = int_types.find(mode); + if (it != int_types.end()) { + return it->second; + } + return int_types.insert({mode, add(make_moded_type(mode))}) + .first->second; + } + + TypeID get_bool_type(Mode mode = {}) { + auto it = bool_types.find(mode); + if (it != bool_types.end()) { + return it->second; + } + return bool_types.insert({mode, add(make_moded_type(mode))}) + .first->second; + } Type &get_type(size_t id) { return types[id]; } const Type &get_type(size_t id) const { return types[id]; } - TypeID introduce_new_generic() { - return add(make_type(first_unused_generic_id++)); + TypeID introduce_new_generic(std::string name, Mode mode = {}) { + return add(make_moded_type(mode, first_unused_generic_id++, + std::move(name))); } TypeID add(Type type) { @@ -99,22 +160,37 @@ struct Storage { return TypeID(types.size() - 1, this); } - // TODO: add modes ?? - bool unify(TypeID left_id, TypeID right_id) { + bool unify(TypeID left_id, TypeID right_id, UnifyModePolicy policy) { Type &left = left_id.get(); Type &right = right_id.get(); + switch (policy) { + case UnifyModePolicy::Ignore: + break; + case UnifyModePolicy::ApplyStrongest: + left.mode = Mode::choose_min(left.mode, right.mode); + right.mode = left.mode; + break; + case UnifyModePolicy::CheckLeftIsSubmode: + if (not left.mode.is_submode(right.mode)) { + return false; + } + break; + } + if (const auto *left_generic = get_if(&left.type); left_generic != nullptr) { - // TODO: check if other type contins generic - resolve(*left_generic, right); + // TODO: check if other type contains generic + std::clog << "left is resolved with policy <" << static_cast(policy) << ">\n"; + resolve(*left_generic, right, left.mode); return true; } if (const auto *right_generic = get_if(&right.type); right_generic != nullptr) { - // TODO: check if other type contins generic - resolve(*right_generic, left); + // TODO: check if other type contains generic + std::clog << "right is resolved with policy <" << static_cast(policy) << ">\n"; + resolve(*right_generic, left, right.mode); return true; } @@ -132,7 +208,7 @@ struct Storage { bool all_unify_passed = true; for (size_t i = 0; i < left_types.size(); ++i) { - if (not unify(left_types[i], right_types[i])) { + if (not unify(left_types[i], right_types[i], policy)) { all_unify_passed = false; } } @@ -143,11 +219,14 @@ struct Storage { return true; } - void resolve(GenericType generic, const Type &replacement) { + void resolve(GenericType generic, const Type &replacement, Mode mode = {}) { + std::clog << "generic type " << generic.name << " is resolved with mode==UNIQUE: <" + << (replacement.mode.uniq == Mode::Uniq::UNIQUE) << ">\n"; for (auto &type : types) { if (const auto *generic_type = get_if(&type.type); generic_type != nullptr and generic_type->id == generic.id) { type = replacement; + type.mode = mode; } } } @@ -157,9 +236,8 @@ private: vector types; - TypeID int_type; - TypeID bool_type; + map int_types; + map bool_types; }; } // namespace types - diff --git a/src/main.cpp b/src/main.cpp index b41823b..abfce2a 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -4,24 +4,25 @@ #include -auto make_program() { +auto make_program(bool uniq) { using namespace nodes; - return make_expr(Arg("f"), - lambda1(with_unique(Arg("x")), - operator_call("+", make_var("x", types::Mode::Uniq::UNIQUE), - make_var("x", types::Mode::Uniq::UNIQUE))), - make_var("f")); + return make_expr( + Arg("f"), + lambda1(Arg("x", uniq ? types::Mode(types::Mode::Uniq::UNIQUE) : types::Mode()), + operator_call("+", make_expr("x"), make_expr("x"))), + make_expr("f")); } -void add_builtin_functions_types(type_check::State &state) { +void add_builtin_functions_types(type_check::State &state, bool uniq) { auto sum_type = state.type_storage.add(types::make_operator( - state.type_storage.get_int_type(), state.type_storage.get_int_type(), + state.type_storage.get_int_type(uniq ? types::Mode(types::Mode::Uniq::UNIQUE) : types::Mode()), + state.type_storage.get_int_type(uniq ? types::Mode(types::Mode::Uniq::UNIQUE) : types::Mode()), state.type_storage.get_int_type())); state.manager.add_var("+", sum_type); } void add_builtin_functions_modes(mode_check::State &state) { - state.add_var("+"); + state.add_var("+"); // mode ?? } void print_error(const std::string &general_message, @@ -29,21 +30,22 @@ void print_error(const std::string &general_message, std::cerr << general_message << " " << "file: " << error.location.file_name() << "(" << error.location.line() << ":" << error.location.column() << ") `" - << error.location.function_name() << "`: " << error.message; + << error.location.function_name() << "`: " << error.message << std::endl; } -int main() { - const auto program = make_program(); +int run_example(bool arg_uniq, bool sum_uniq) { + const auto program = make_program(arg_uniq); try { type_check::State state; - add_builtin_functions_types(state); + add_builtin_functions_types(state, sum_uniq); type_check::check_expr(program, state); } catch (utils::Error error) { print_error("TYPE CHECK ERROR:", error); + return 1; } try { @@ -54,5 +56,25 @@ int main() { mode_check::check_expr(program, state); } catch (utils::Error error) { print_error("MODE CHECK ERROR:", error); + return 1; + } + + std::cout << "CHECK DONE\n"; + + return 0; +} + +int main() { + int n = 0; + + while(true) { + std::cout << "--- START TEST ---\n"; + std::cout << "TEST ID (0 - 3): "; + std::cin >> n; + if (n < 0 or n >= 4) { + break; + } + run_example(n % 2 == 1, n / 4 == 1); + std::cout << "--- END TEST ---\n"; } } diff --git a/src/mode_check.cpp b/src/mode_check.cpp index d9062bd..a9972c3 100644 --- a/src/mode_check.cpp +++ b/src/mode_check.cpp @@ -5,11 +5,11 @@ namespace mode_check { void check_const(const nodes::Const &, State &) {} void check_var(const nodes::Var &expr, State &state) { - if (not expr.mode.has_value()) { - utils::throw_error("NO_MODE for " + expr.name); + if (not expr.type.has_value()) { + utils::throw_error("NO_TYPE for " + expr.name); return; } - auto mode = expr.mode.value(); + auto mode = expr.type.value().get().mode; if (auto maybe_var_state = state.get_var_state(expr.name); maybe_var_state.has_value()) { @@ -27,32 +27,27 @@ void check_var(const nodes::Var &expr, State &state) { } void check_let(const nodes::Let &expr, State &state) { - { - Context context(state); - check_expr(expr.body, state); + Context context(state); + + check_expr(expr.body, state); + + if (not expr.name.type.has_value()) { + utils::throw_error("NO_VAR_TYPE for " + expr.name.name); } + state.add_var(expr.name.name, expr.name.type.value().get().mode); - { - Context context(state); - - if (not expr.name.mode.has_value()) { - utils::throw_error("NO_VAR_MODE"); - } - state.add_var(expr.name.name, expr.name.mode.value()); - - check_expr(expr.where, state); - } + check_expr(expr.where, state); } void check_lambda(const nodes::Lambda &expr, State &state) { Context context(state); for (const auto &arg : expr.args) { - if (not arg.mode.has_value()) { - utils::throw_error("NO_VAR_MODE"); + if (not arg.type.has_value()) { + utils::throw_error("NO_VAR_TYPE for " + arg.name); continue; } - state.add_var(arg.name, arg.mode.value()); + state.add_var(arg.name, arg.type.value().get().mode); } check_expr(expr.expr, state); diff --git a/src/type_check.cpp b/src/type_check.cpp index e1ddf74..938e99f 100644 --- a/src/type_check.cpp +++ b/src/type_check.cpp @@ -17,14 +17,16 @@ types::TypeID check_var(nodes::Var &expr, State &state) { types::TypeID check_let(nodes::Let &expr, State &state) { Context context(state.manager); - types::TypeID new_type = state.type_storage.introduce_new_generic(); - + types::TypeID new_type = + state.type_storage.introduce_new_generic(expr.name.name, expr.name.mode_hint); + expr.name.type = new_type; state.manager.add_var(expr.name.name, new_type); types::TypeID body_type = check_expr(expr.body, state); - if (not state.type_storage.unify(new_type, body_type)) { - utils::throw_error("DIFFERENT_TYPES"); + if (not state.type_storage.unify(new_type, body_type, + UnifyModePolicy::CheckLeftIsSubmode)) { + utils::throw_error("DIFFERENT_TYPES_OR_MODES"); } types::TypeID where_type = check_expr(expr.where, state); @@ -34,8 +36,9 @@ types::TypeID check_let(nodes::Let &expr, State &state) { types::TypeID check_lambda(nodes::Lambda &expr, State &state) { Context context(state.manager); - for (const auto &arg : expr.args) { - types::TypeID new_type = state.type_storage.introduce_new_generic(); + for (auto &arg : expr.args) { + types::TypeID new_type = state.type_storage.introduce_new_generic(arg.name, arg.mode_hint); + arg.type = new_type; state.manager.add_var(arg.name, new_type); } @@ -55,8 +58,9 @@ types::TypeID check_call(nodes::Call &expr, State &state) { for (size_t i = 0; i < expr.args.size(); ++i) { types::TypeID arg_type = check_expr(expr.args[i], state); - if (not state.type_storage.unify(arrow_func_type->types[i], arg_type)) { - utils::throw_error("DIFFERENT_TYPES"); + if (not state.type_storage.unify(arrow_func_type->types[i], arg_type, + UnifyModePolicy::CheckLeftIsSubmode)) { + utils::throw_error("DIFFERENT_TYPES_OR_MODES"); } } @@ -71,14 +75,16 @@ types::TypeID check_condition(nodes::Condition &expr, State &state) { types::TypeID condition_type = check_expr(expr.condition, state); if (not state.type_storage.unify(condition_type, - state.type_storage.get_bool_type())) { + state.type_storage.get_bool_type(), + UnifyModePolicy::Ignore)) { utils::throw_error("DIFFERENT_TYPES"); } types::TypeID then_type = check_expr(expr.then_case, state); types::TypeID else_type = check_expr(expr.else_case, state); - if (not state.type_storage.unify(then_type, else_type)) { + if (not state.type_storage.unify(then_type, else_type, + UnifyModePolicy::ApplyStrongest)) { utils::throw_error("DIFFERENT_TYPES"); } diff --git a/src/types.cpp b/src/types.cpp index 2cc8c7b..447ccf1 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -10,4 +10,8 @@ Type& TypeID::get() { return storage->get_type(id); } +TypeID TypeID::with_mode(Mode new_mode) const { + return storage->add(get().with_mode(new_mode)); +} + } // namespace types