From 3815f8259b8d56e4afee3d7cf683d38163a3011a Mon Sep 17 00:00:00 2001 From: ProgramSnail Date: Fri, 7 Jul 2023 22:53:14 +0300 Subject: [PATCH] part of type deduction done, debugging in proces --- include/types.hpp | 27 +++++--- src/execute_visitor.cpp | 9 +++ src/type_check_visitor.cpp | 124 +++++++++++++++++++++++++++++++++++-- src/types.cpp | 69 +++++++++++++-------- src/utils.cpp | 2 + tests/test_code.lang | 40 ++++++------ 6 files changed, 212 insertions(+), 59 deletions(-) diff --git a/include/types.hpp b/include/types.hpp index 29c837e..5e7cb01 100644 --- a/include/types.hpp +++ b/include/types.hpp @@ -34,7 +34,8 @@ public: bool Require(const AbstractType& type) const; bool DeduceContext(const AbstractType& actual_type, - std::unordered_map>& context) const; + std::unordered_map>& context, + TypeManager& type_manager) const; std::optional GetFieldType(const std::string& name, const std::unordered_set& type_namespaces) const; @@ -74,7 +75,8 @@ public: bool Require(const DefinedType& type) const; bool DeduceContext(const DefinedType& actual_type, - std::unordered_map>& context) const; + std::unordered_map>& context, + TypeManager& type_manager) const; std::optional GetFieldType(const std::string& name, const std::unordered_set& type_namespaces) const; @@ -179,7 +181,8 @@ public: bool Require(const TupleType& type) const; bool DeduceContext(const TupleType& actual_type, - std::unordered_map>& context) const; + std::unordered_map>& context, + TypeManager& type_manager) const; std::optional GetFieldType(const std::string& name, const std::unordered_set& type_namespaces) const; @@ -208,7 +211,8 @@ public: bool Require(const VariantType& type) const; bool DeduceContext(const VariantType& actual_type, - std::unordered_map>& context) const; + std::unordered_map>& context, + TypeManager& type_manager) const; std::optional GetFieldType(const std::string& name, const std::unordered_set& type_namespaces) const; @@ -240,7 +244,8 @@ public: bool Require(const OptionalType& type) const; bool DeduceContext(const OptionalType& actual_type, - std::unordered_map>& context) const; + std::unordered_map>& context, + TypeManager& type_manager) const; std::optional GetFieldType(const std::string& name, const std::unordered_set& type_namespaces) const; @@ -273,7 +278,8 @@ public: bool Require(const ReferenceToType& type) const; bool DeduceContext(const ReferenceToType& actual_type, - std::unordered_map>& context) const; + std::unordered_map>& context, + TypeManager& type_manager) const; std::optional GetFieldType(const std::string& name, const std::unordered_set& type_namespaces) const; @@ -300,7 +306,8 @@ public: bool Require(const FunctionType& type) const; bool DeduceContext(const FunctionType& actual_type, - std::unordered_map>& context) const; + std::unordered_map>& context, + TypeManager& type_manager) const; std::optional GetFieldType(const std::string& name, const std::unordered_set& type_namespaces) const; @@ -325,7 +332,8 @@ public: bool Require(const ArrayType& type) const; bool DeduceContext(const ArrayType& actual_type, - std::unordered_map>& context) const; + std::unordered_map>& context, + TypeManager& type_manager) const; std::optional GetFieldType(const std::string& name, const std::unordered_set& type_namespaces) const; @@ -351,7 +359,8 @@ public: bool Require(const Type& type) const; // TODO: check abstract type requirements for not abstract types bool DeduceContext(const Type& actual_type, - std::unordered_map>& context) const; + std::unordered_map>& context, + TypeManager& type_manager) const; std::optional GetFieldType(const std::string& name, const std::unordered_set& type_namespaces) const; diff --git a/src/execute_visitor.cpp b/src/execute_visitor.cpp index ff2a219..16aded5 100644 --- a/src/execute_visitor.cpp +++ b/src/execute_visitor.cpp @@ -1088,6 +1088,15 @@ bool ExecuteVisitor::HandleBuiltinFunctionCall(FunctionCallExpression* node) { error_handling::HandleInternalError("Error function finished", "ExecuteVisitor.HandleBuiltinFunctionCall", &node->base); + } else if (node->name == "some") { // TODO: manage value type ?? + Visitor::Visit(node->arguments[0].second); + current_value_ = context_manager_.AddValue(info::value::OptionalValue(current_value_, + context_manager_.GetValueManager()), + utils::ValueType::Tmp); + } else if (node->name == "none") { // TODO: manage value type ?? + current_value_ = context_manager_.AddValue(info::value::OptionalValue(std::nullopt, + context_manager_.GetValueManager()), + utils::ValueType::Tmp); } else { return false; } diff --git a/src/type_check_visitor.cpp b/src/type_check_visitor.cpp index 6b676d9..fd2c7b7 100644 --- a/src/type_check_visitor.cpp +++ b/src/type_check_visitor.cpp @@ -764,6 +764,10 @@ void TypeCheckVisitor::Visit(FunctionCallExpression* node) { &node->base); } + if (node->name == "some") { + error_handling::DebugPrint("0"); + } + // try to find function declaration if (node->prefix.has_value()) { if (std::holds_alternative>(node->prefix.value())) { @@ -784,6 +788,10 @@ void TypeCheckVisitor::Visit(FunctionCallExpression* node) { maybe_function_declaration = FindFunctionAndUpdate(node); } + if (node->name == "some") { + error_handling::DebugPrint("1"); + } + // function declaration check if (!maybe_function_declaration.has_value()) { error_handling::HandleTypecheckError("No function declaration found for function in call expression", node->base); @@ -794,11 +802,10 @@ void TypeCheckVisitor::Visit(FunctionCallExpression* node) { // check & collect parmeters if (function_declaration->parameters.size() != node->parameters.size()) { - // if (node->parameters.size() != 0) { + if (node->parameters.size() != 0) { error_handling::HandleTypecheckError("Mismatched parameter count in function call expression", node->base); - // } - // deduce_parameters = true; - // TODO: paramters deduction + } + deduce_parameters = true; } for (size_t i = 0; i < node->parameters.size(); ++i) { Visit(node->parameters[i].get()); @@ -819,6 +826,12 @@ void TypeCheckVisitor::Visit(FunctionCallExpression* node) { function_declaration->parameters[i]->graph_id_, typeclass_graph_), utils::ValueType::Tmp)); + + // TODO: type requirements check needed ?? + } + + if (node->name == "some") { + error_handling::DebugPrint("2"); } size_t index_shift = (node->is_method_of_first_argument_ ? 1 : 0); @@ -830,6 +843,15 @@ void TypeCheckVisitor::Visit(FunctionCallExpression* node) { bool all_arguments = function_declaration->type->types.size() <= node->arguments.size() + 1 - index_shift; + std::unordered_map> deduced_context; + for (auto& parameter : function_declaration->parameters) { + deduced_context[parameter->type] = std::nullopt; + } + + if (node->name == "some") { + error_handling::DebugPrint("3"); + } + { size_t i = index_shift; // function call argument id size_t j = 0; // actual argument id @@ -851,6 +873,10 @@ void TypeCheckVisitor::Visit(FunctionCallExpression* node) { } } + if (node->name == "some") { + error_handling::DebugPrint("4"); + } + Visit(function_declaration->type->types[j].second.get()); utils::IdType argument_type = TypeInContext(current_type_, context); @@ -862,9 +888,50 @@ void TypeCheckVisitor::Visit(FunctionCallExpression* node) { } } + if (node->name == "some") { + error_handling::DebugPrint("5"); + } + Visitor::Visit(node->arguments[i].second); - if (!context_manager_.AddValueRequirement(current_type_, argument_type)) { - error_handling::HandleTypecheckError("Wrong argument type (argument " + std::to_string(j + 1) + ")", node->base); + + if (node->name == "some") { + error_handling::DebugPrint("6"); + } + + if (deduce_parameters) { + std::unordered_map> local_deduced_context; + + for (auto& parameter : function_declaration->parameters) { + local_deduced_context[parameter->type] = std::nullopt; + } + + // TODO: do this in context manager ?? + if (!context_manager_.GetAnyValue(current_type_)->DeduceContext(*context_manager_.GetAnyValue(argument_type), local_deduced_context, *context_manager_.GetValueManager())) { + error_handling::HandleTypecheckError("Can't deduce parameters (argument " + std::to_string(j + 1) + ")", node->base); + } + + for (auto& local_deduced_type : local_deduced_context) { + auto deduced_type = deduced_context[local_deduced_type.first]; + + if (!deduced_type.has_value()) { + deduced_context[local_deduced_type.first] = local_deduced_type.second; + continue; + } + + if (local_deduced_type.second.has_value() && + !context_manager_.EqualValues(local_deduced_type.second.value(), deduced_type.value())) { + error_handling::HandleTypecheckError("Different types deduced for one parameter in function call expression", node->base); + } + } + + } else { + if (!context_manager_.AddValueRequirement(current_type_, argument_type)) { + error_handling::HandleTypecheckError("Wrong argument type (argument " + std::to_string(j + 1) + ")", node->base); + } + } + + if (node->name == "some") { + error_handling::DebugPrint("7"); } ++i; @@ -887,6 +954,43 @@ void TypeCheckVisitor::Visit(FunctionCallExpression* node) { } } + if (node->name == "some") { + error_handling::DebugPrint("8"); + } + + if (deduce_parameters) { + for (size_t i = 0; i < function_declaration->parameters.size(); ++i) { + auto deduced_type = deduced_context[function_declaration->parameters[i]->type]; + + if (!deduced_type.has_value()) { + error_handling::HandleTypecheckError("Can't deduce parameters in function call", node->base); + } + + current_type_ = deduced_type.value(); + + if (context.count(function_declaration->parameters[i]->type) != 0) { + error_handling::HandleInternalError("Local abstract types with same name in one context", + "TypeCheckVisitor.FunctionCallExpresssion", + &node->base); + } + context[function_declaration->parameters[i]->type] = current_type_; + + context_manager_.DefineLocalType( + function_declaration->parameters[i]->type, + context_manager_.AddValue( + info::type::AbstractType(utils::AbstractTypeModifier::Abstract, + function_declaration->parameters[i]->graph_id_, + typeclass_graph_), + utils::ValueType::Tmp)); + + // TODO: check, that parameter type valid ?? + } + } + + if (node->name == "some") { + error_handling::DebugPrint("9"); + } + if (!utils::IsBuiltinFunction(node->name)) { if (node->function_id_.has_value()) { if (!global_info_.GetFunctionInfo(node->function_id_.value()).definition.has_value()) { @@ -913,12 +1017,20 @@ void TypeCheckVisitor::Visit(FunctionCallExpression* node) { } } + if (node->name == "some") { + error_handling::DebugPrint("10"); + } + Visit(function_declaration->type->types.back().second.get()); current_type_ = TypeInContext(current_type_, context); context_manager_.ExitContext(); node->base.type_ = context_manager_.ToModifiedValue(current_type_, utils::ValueType::Tmp); + + if (node->name == "some") { + error_handling::DebugPrint("11"); + } } void TypeCheckVisitor::Visit(TupleExpression* node) { diff --git a/src/types.cpp b/src/types.cpp index bc8827a..b2a8f8f 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -23,7 +23,8 @@ bool AbstractType::Require(const AbstractType& type) const { // TODO: cache Depe } bool AbstractType::DeduceContext(const AbstractType& actual_type, - std::unordered_map>& context) const { + std::unordered_map>& context, + TypeManager& type_manager) const { return typeclass_graph_.GetDependenciesSet(graph_id_).count(actual_type.graph_id_) != 0 || graph_id_ == actual_type.graph_id_; } @@ -55,9 +56,10 @@ bool DefinedType::Require(const DefinedType& type) const { } bool DefinedType::DeduceContext(const DefinedType& actual_type, - std::unordered_map>& context) const { + std::unordered_map>& context, + TypeManager& type_manager) const { return type_id_ == actual_type.type_id_ - && type_manager_->GetAnyValue(type_)->DeduceContext(*type_manager_->GetAnyValue(actual_type.type_), context); + && type_manager_->GetAnyValue(type_)->DeduceContext(*type_manager_->GetAnyValue(actual_type.type_), context, type_manager); } std::optional DefinedType::GetFieldType(const std::string& name, @@ -115,13 +117,14 @@ bool TupleType::Require(const TupleType& type) const { } bool TupleType::DeduceContext(const TupleType& actual_type, - std::unordered_map>& context) const { + std::unordered_map>& context, + TypeManager& type_manager) const { if (fields_.size() != actual_type.fields_.size()) { return false; } for (size_t i = 0; i < fields_.size(); ++i) { - if (!type_manager_->GetAnyValue(fields_[i].second)->DeduceContext(*type_manager_->GetAnyValue(actual_type.fields_[i].second), context)) { + if (!type_manager_->GetAnyValue(fields_[i].second)->DeduceContext(*type_manager_->GetAnyValue(actual_type.fields_[i].second), context, type_manager)) { return false; } } @@ -215,7 +218,8 @@ bool VariantType::Require(const VariantType& type) const { } bool VariantType::DeduceContext(const VariantType& actual_type, - std::unordered_map>& context) const { + std::unordered_map>& context, + TypeManager& type_manager) const { if (constructors_.size() != actual_type.constructors_.size()) { return false; } @@ -230,7 +234,7 @@ bool VariantType::DeduceContext(const VariantType& actual_type, } if (constructors_[i].second.has_value()) { - if (!constructors_[i].second.value().DeduceContext(actual_type.constructors_[i].second.value(), context)) { + if (!constructors_[i].second.value().DeduceContext(actual_type.constructors_[i].second.value(), context, type_manager)) { return false; } } @@ -286,8 +290,9 @@ bool OptionalType::Require(const OptionalType& type) const { } bool OptionalType::DeduceContext(const OptionalType& actual_type, - std::unordered_map>& context) const { - return type_manager_->GetAnyValue(type_)->DeduceContext(*type_manager_->GetAnyValue(actual_type.type_), context); + std::unordered_map>& context, + TypeManager& type_manager) const { + return type_manager_->GetAnyValue(type_)->DeduceContext(*type_manager_->GetAnyValue(actual_type.type_), context, type_manager); } std::optional OptionalType::GetFieldType(const std::string&, @@ -320,8 +325,9 @@ bool ReferenceToType::Require(const ReferenceToType& type) const { } bool ReferenceToType::DeduceContext(const ReferenceToType& actual_type, - std::unordered_map>& context) const { - return references_ == actual_type.references_ && type_manager_->GetAnyValue(type_)->DeduceContext(*type_manager_->GetAnyValue(actual_type.type_), context); + std::unordered_map>& context, + TypeManager& type_manager) const { + return references_ == actual_type.references_ && type_manager_->GetAnyValue(type_)->DeduceContext(*type_manager_->GetAnyValue(actual_type.type_), context, type_manager); } std::optional ReferenceToType::GetFieldType(const std::string& name, @@ -396,13 +402,14 @@ bool FunctionType::Require(const FunctionType& type) const { } bool FunctionType::DeduceContext(const FunctionType& actual_type, - std::unordered_map>& context) const { + std::unordered_map>& context, + TypeManager& type_manager) const { if (argument_types_.size() != actual_type.argument_types_.size()) { return false; } for (size_t i = 0; i < argument_types_.size(); ++i) { - if (!type_manager_->GetAnyValue(argument_types_[i])->DeduceContext(*type_manager_->GetAnyValue(actual_type.argument_types_[i]), context)) { + if (!type_manager_->GetAnyValue(argument_types_[i])->DeduceContext(*type_manager_->GetAnyValue(actual_type.argument_types_[i]), context, type_manager)) { return false; } } @@ -457,8 +464,9 @@ bool ArrayType::Require(const ArrayType& type) const { } bool ArrayType::DeduceContext(const ArrayType& actual_type, - std::unordered_map>& context) const { - return size_ == actual_type.size_ && type_manager_->GetAnyValue(elements_type_)->DeduceContext(*type_manager_->GetAnyValue(actual_type.elements_type_), context); + std::unordered_map>& context, + TypeManager& type_manager) const { + return size_ == actual_type.size_ && type_manager_->GetAnyValue(elements_type_)->DeduceContext(*type_manager_->GetAnyValue(actual_type.elements_type_), context, type_manager); } std::optional ArrayType::GetFieldType(const std::string&, @@ -570,15 +578,16 @@ bool Type::Require(const Type& type) const { // TODO: check abstract type requir // TODO: check abstract type requirements for not abstract types bool Type::DeduceContext(const Type& actual_type, - std::unordered_map>& context) const { + std::unordered_map>& context, + TypeManager& type_manager) const { size_t this_index = type_.index(); size_t type_index = actual_type.type_.index(); if (this_index == 0) { std::string type_name = std::get(type_).GetName(); if (context.count(type_name) != 0) { // is abstract type - // context[type_name] = // TODO: actual_type.id_; ?? - // TODO: fixes + context[type_name] = type_manager.AddAnyValue(Type(actual_type), utils::ValueType::Tmp); + // TODO: choose value type ?? (or should be set later ??) } } @@ -588,30 +597,38 @@ bool Type::DeduceContext(const Type& actual_type, switch (this_index) { case 0: return std::get(type_).DeduceContext(std::get(actual_type.type_), - context); + context, + type_manager); case 1: return std::get(type_).DeduceContext(std::get(actual_type.type_), - context); + context, + type_manager); case 2: return std::get(type_) == std::get(actual_type.type_); case 3: return std::get(type_).DeduceContext(std::get(actual_type.type_), - context); + context, + type_manager); case 4: return std::get(type_).DeduceContext(std::get(actual_type.type_), - context); + context, + type_manager); case 5: return std::get(type_).DeduceContext(std::get(actual_type.type_), - context); + context, + type_manager); case 6: return std::get(type_).DeduceContext(std::get(actual_type.type_), - context); + context, + type_manager); case 7: return std::get(type_).DeduceContext(std::get(actual_type.type_), - context); + context, + type_manager); case 8: return std::get(type_).DeduceContext(std::get(actual_type.type_), - context); + context, + type_manager); default: // error break; diff --git a/src/utils.cpp b/src/utils.cpp index 3254bf2..7edc862 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -85,6 +85,8 @@ bool IsBuiltinFunction(const std::string& name) { // optimize ?? builtin_functions.insert("show"); builtin_functions.insert("read"); builtin_functions.insert("error"); + builtin_functions.insert("some"); + builtin_functions.insert("none"); // builtin_functions.insert("debug_show"); // TODO return builtin_functions.count(name) != 0; diff --git a/tests/test_code.lang b/tests/test_code.lang index e963ab7..cbdc85a 100644 --- a/tests/test_code.lang +++ b/tests/test_code.lang @@ -16,6 +16,10 @@ decl random : -> \int // TODO decl error : \string -> \unit +decl some 'a : 'a -> 'a? + +decl none 'a : 'a -> 'a? + // decl not : \bool -> \bool @@ -264,24 +268,24 @@ decl scan-three : -> (& \string & \string & \string) def scan-three = & \io..scan: & \io..scan: & \io..scan: exec main { - var n = \int..read: (\io..scan:) +// var n = \int..read: (\io..scan:) +// +// if n <= 0 then error: "n can't be less then 1" +// +// // var x = (for _ in 0--n do scan-int:) +// var x = \array[int]..of: (for _ in 0--n do scan-int:) +// +// +// var k? = if n < 2 then n * 2 +. 3 in +// , print-anything:[string] "n < 2" +// , print-anything:[int] k +// +// ; print-anything:[int] n - if n <= 0 then error: "n can't be less then 1" + ; print-int-with-comment: ::i 123 (some: "comment") - // var x = (for _ in 0--n do scan-int:) - var x = \array[int]..of: (for _ in 0--n do scan-int:) - - - var k? = if n < 2 then n * 2 +. 3 in - , print-anything:[string] "n < 2" - , print-anything:[int] k - - ; print-anything:[int] n - - ; print-int-with-comment: ::i 123 - - var & a & b & c = scan-three-t: - ; \io..print: b - var & d & e & f = scan-three: - ; \io..print: e +// var & a & b & c = scan-three-t: +// ; \io..print: b +// var & d & e & f = scan-three: +// ; \io..print: e }