diff --git a/include/type_check_visitor.hpp b/include/type_check_visitor.hpp index 7096543..5174b45 100644 --- a/include/type_check_visitor.hpp +++ b/include/type_check_visitor.hpp @@ -163,6 +163,15 @@ private: info::definition::AnyType* defined_type, bool is_method); + void ResetReturnedAndBroughtTypes() { + if (returned_type_.has_value()) { + all_branches_returned_value_ = false; + } + if (brought_type_.has_value()) { + all_branches_brought_value_ = false; + } + } + // bool HandleBuiltinFunctionCall(FunctionCallExpression* node); @@ -177,8 +186,13 @@ private: std::unordered_set type_namespaces_; utils::IdType current_type_; + std::optional returned_type_; + bool all_branches_returned_value_ = true; + std::optional brought_type_; + bool all_branches_brought_value_ = true; + std::optional is_const_definition_; bool is_in_statement_ = false; diff --git a/src/type_check_visitor.cpp b/src/type_check_visitor.cpp index 3811182..98aaf3b 100644 --- a/src/type_check_visitor.cpp +++ b/src/type_check_visitor.cpp @@ -258,8 +258,13 @@ void TypeCheckVisitor::Visit(FunctionDefinitionStatement* node) { utils::IdType returned_type = current_type_; returned_type_ = std::nullopt; + all_branches_returned_value_ = true; Visitor::Visit(node->value); + if (!all_branches_returned_value_) { + error_handling::HandleTypecheckError("Not all branches return value", node->base); + } + if (!returned_type_.has_value()) { returned_type_ = current_type_; } @@ -479,6 +484,8 @@ void TypeCheckVisitor::Visit(Match* node) { // TODO: move value to match // TODO: several matches with one statement typecheck <- check proposed solution std::optional nearest_statement; for (ssize_t i = (ssize_t)node->matches.size() - 1; i >= 0; --i) { // TODO: internal contexts ?? + ResetReturnedAndBroughtTypes(); + current_type_ = value_type; context_manager_.EnterContext(); @@ -530,23 +537,27 @@ void TypeCheckVisitor::Visit(Condition* node) { utils::IdType type; for (size_t i = 0; i < node->conditions.size(); ++i) { - Visitor::Visit(node->conditions[i]); - if (!context_manager_.EqualValues(context_manager_.AddValue(info::type::InternalType::Bool, utils::ValueType::Tmp), current_type_)) { - error_handling::HandleTypecheckError("Condition statement condition is not bool expression", node->base); - } + ResetReturnedAndBroughtTypes(); - Visitor::Visit(node->statements[i]); + Visitor::Visit(node->conditions[i]); + if (!context_manager_.EqualValues(context_manager_.AddValue(info::type::InternalType::Bool, utils::ValueType::Tmp), current_type_)) { + error_handling::HandleTypecheckError("Condition statement condition is not bool expression", node->base); + } - if (i == 0) { - type = current_type_; - } else { - if (!context_manager_.EqualValues(type, current_type_)) { - error_handling::HandleTypecheckError("Condition statement cases have different types", node->base); - } + Visitor::Visit(node->statements[i]); + + if (i == 0) { + type = current_type_; + } else { + if (!context_manager_.EqualValues(type, current_type_)) { + error_handling::HandleTypecheckError("Condition statement cases have different types", node->base); } + } } if (node->statements.size() > node->conditions.size()) { + ResetReturnedAndBroughtTypes(); + Visitor::Visit(node->statements[node->conditions.size()]); if (!context_manager_.EqualValues(type, current_type_)) { @@ -646,56 +657,27 @@ void TypeCheckVisitor::Visit(LoopLoop* node) { // Statements, expressions, blocks, etc. ----------------- -// TODO: check, that last statement in function definition has return type void TypeCheckVisitor::Visit(Block* node) { // TODO: types can be different in statement - std::optional brought_type; - std::optional returned_type = returned_type_; + + all_branches_brought_value_ = true; context_manager_.EnterContext(); for (auto& statement : node->statements) { - brought_type_ = std::nullopt; - returned_type_ = std::nullopt; - Visitor::Visit(statement); - - if (brought_type_.has_value()) { - if (!brought_type.has_value()) { - brought_type = brought_type_.value(); - } else { - if (!context_manager_.EqualValues(brought_type.value(), brought_type_.value())) { - error_handling::HandleTypecheckError("Different brought types in block", node->base); - } - } - } - - if (returned_type_.has_value()) { - if (!returned_type.has_value()) { - returned_type = returned_type_.value(); - } else { - if (!context_manager_.EqualValues(returned_type.value(), returned_type_.value())) { - error_handling::HandleTypecheckError("Different returned types in block", node->base); - } - } - } } - context_manager_.EnterContext(); + context_manager_.ExitContext(); - if (brought_type.has_value() - && !context_manager_.EqualValues(brought_type.value(), - context_manager_.AddValue(info::type::InternalType::Unit, - utils::ValueType::Tmp)) - && !brought_type_.has_value()) { - error_handling::HandleTypecheckError("Different brought types in block (no return at end)", node->base); + if (!all_branches_brought_value_) { + error_handling::HandleTypecheckError("Different brought types in block", node->base); } - if (brought_type.has_value()) { - current_type_ = brought_type.value(); + if (brought_type_.has_value()) { + current_type_ = brought_type_.value(); } else { current_type_ = context_manager_.AddValue(info::type::InternalType::Unit, utils::ValueType::Tmp); } - returned_type_ = returned_type; brought_type_ = std::nullopt; node->base.type_ = current_type_; @@ -894,9 +876,23 @@ void TypeCheckVisitor::Visit(VariantExpression* node) { void TypeCheckVisitor::Visit(ReturnExpression* node) { Visitor::Visit(node->expression); if (node->is_from_definition) { - returned_type_ = current_type_; + if (returned_type_.has_value()) { + if (!context_manager_.EqualValues(returned_type_.value(), current_type_)) { + error_handling::HandleTypecheckError("Different returned types", node->base); + } + } else { + returned_type_ = current_type_; + } + all_branches_returned_value_ = true; } else { - brought_type_ = current_type_; + if (brought_type_.has_value()) { + if (!context_manager_.EqualValues(brought_type_.value(), current_type_)) { + error_handling::HandleTypecheckError("Different brought types", node->base); + } + } else { + brought_type_ = current_type_; + } + all_branches_brought_value_ = true; } current_type_ = context_manager_.AddValue(info::type::InternalType::Unit, utils::ValueType::Tmp); diff --git a/tests/test_code.lang b/tests/test_code.lang index ee2a19b..db8efe8 100644 --- a/tests/test_code.lang +++ b/tests/test_code.lang @@ -142,5 +142,8 @@ def func : s = { exec main { for i in (,0 ,1 ,2 ,3) do func: "abacaba" - ; print: ({ bring read: }) + ; print: ({ + if true then bring read: else () + bring "nothing" + }) }