From 7f4266821c8346edd420269613022a05e6d810fc Mon Sep 17 00:00:00 2001 From: ProgramSnail Date: Tue, 23 May 2023 11:54:15 +0300 Subject: [PATCH] fixes, new examples --- include/interpreter_tree.hpp | 2 +- include/type_check_visitor.hpp | 20 +++- include/types.hpp | 24 +++-- lang-parser | 2 +- src/build_visitor.cpp | 6 +- src/execute_visitor.cpp | 4 +- src/main.cpp | 20 +++- src/print_visitor.cpp | 2 +- src/type_check_visitor.cpp | 36 ++++--- src/typed_print_visitor.cpp | 2 +- src/types.cpp | 18 ++-- src/visitor.cpp | 2 +- tests/arrays.lang | 18 ++-- tests/stdlib.lang | 188 +++++++++++++++++++++++++++++++++ tests/test_code.lang | 52 ++++++--- 15 files changed, 322 insertions(+), 74 deletions(-) diff --git a/include/interpreter_tree.hpp b/include/interpreter_tree.hpp index 26e22bd..4f1baf8 100644 --- a/include/interpreter_tree.hpp +++ b/include/interpreter_tree.hpp @@ -498,7 +498,7 @@ struct ReferenceExpression { BaseNode base; utils::ReferenceModifier reference; - std::unique_ptr expression; + SubExpressionToken expression; }; struct AccessExpression { diff --git a/include/type_check_visitor.hpp b/include/type_check_visitor.hpp index 2746a8b..69e3556 100644 --- a/include/type_check_visitor.hpp +++ b/include/type_check_visitor.hpp @@ -210,7 +210,10 @@ private: } void VisitDefinedType(info::definition::AnyType* defined_type, - const std::unordered_map& context) { + const std::unordered_map& context, + utils::ValueType modifier) { + context_manager_.EnterContext(); + AddTypeParameterLocalTypes(defined_type); Visitor::Visit(defined_type->node->value); current_type_ = TypeInContext(current_type_, context); current_type_ = @@ -218,7 +221,20 @@ private: current_type_, defined_type->modifier, context_manager_.GetValueManager()), - utils::ValueType::Tmp); + modifier); + context_manager_.ExitContext(); + } + + void AddTypeParameterLocalTypes(info::definition::AnyType* type_info) { + for (auto& parameter : type_info->node->definition->parameters) { + context_manager_.DefineLocalType( + parameter->type, + context_manager_.AddValue( + info::type::AbstractType(utils::AbstractTypeModifier::Abstract, + parameter->graph_id_, + typeclass_graph_), + utils::ValueType::Tmp)); + } } private: diff --git a/include/types.hpp b/include/types.hpp index 3756856..61d4b76 100644 --- a/include/types.hpp +++ b/include/types.hpp @@ -43,7 +43,7 @@ public: return graph_id == graph_id_ || typeclass_graph_.GetDependenciesSet(graph_id_).count(graph_id) != 0; } - std::string ToString() { + std::string ToString() const { return "Abstract " + std::to_string(graph_id_); } private: @@ -81,7 +81,7 @@ public: return class_modifier_; } - std::string ToString() { + std::string ToString() const { return "Defined"; } private: @@ -178,7 +178,7 @@ public: return fields_; } - std::string ToString(); + std::string ToString() const; private: std::optional name_; std::vector, utils::IdType>> fields_; @@ -209,7 +209,7 @@ public: current_constructor_ = constructor; } - std::string ToString(); + std::string ToString() const; private: std::optional name_; std::vector constructors_; @@ -231,7 +231,7 @@ public: std::optional GetFieldType(const std::string& name, const std::unordered_set& type_namespaces) const; - std::string ToString(); + std::string ToString() const; private: utils::IdType type_; TypeManager* type_manager_ = nullptr; @@ -243,7 +243,11 @@ public: ReferenceToType(const std::vector& references, utils::IdType type, TypeManager* type_manager) - : references_(references), type_(type), type_manager_(type_manager) {} + : references_(references), type_(type), type_manager_(type_manager) { + if (references.empty()) { + error_handling::HandleInternalError("ReferenceToType with 0 references", "Type.ReferenceToType", std::nullopt); + } + } std::optional InContext(const std::unordered_map& context); bool Same(const ReferenceToType& type) const; @@ -253,7 +257,7 @@ public: std::optional GetFieldType(const std::string& name, const std::unordered_set& type_namespaces) const; - std::string ToString(); + std::string ToString() const; private: std::vector references_; utils::IdType type_; @@ -278,7 +282,7 @@ public: std::optional GetFieldType(const std::string& name, const std::unordered_set& type_namespaces) const; - std::string ToString(); + std::string ToString() const; private: std::vector argument_types_; utils::IdType return_type_; @@ -305,7 +309,7 @@ public: return elements_type_; } - std::string ToString(); + std::string ToString() const; private: size_t size_; // = 0 for dynamic utils::IdType elements_type_; @@ -339,7 +343,7 @@ public: return type_; } - std::string ToString(); + std::string ToString() const; private: std::variantexpression = std::make_unique(); - Visit(node->expression.get()); + Visit(node->expression); current_node_ = parse_node; } @@ -1053,7 +1052,8 @@ void BuildVisitor::Visit(TypeConstructorParameter* node) { size_t child_count = parse_node.NamedChildCount(); if (child_count > 1) { - node->name = parse_node.ChildByFieldName("name").GetValue(); + current_node_ = parse_node.ChildByFieldName("name"); + node->name = current_node_.GetValue(); std::string assignment_modifier = current_node_.NextSibling().GetValue(); if (assignment_modifier == "=") { diff --git a/src/execute_visitor.cpp b/src/execute_visitor.cpp index 9f0e74a..a033d0f 100644 --- a/src/execute_visitor.cpp +++ b/src/execute_visitor.cpp @@ -386,7 +386,7 @@ void ExecuteVisitor::Visit(LoopControlExpression& node) { void ExecuteVisitor::Visit(ReferenceExpression* node) { // TODO: check, that there is no references to "Tmp"?? - Visit(node->expression.get()); + Visitor::Visit(node->expression); utils::ValueType value_type = context_manager_.GetValueType(current_value_); @@ -748,7 +748,7 @@ void ExecuteVisitor::Visit(TupleName* node) { std::optional maybe_tuple_value = context_manager_.GetValue(value); - if (maybe_tuple_value.has_value()) { + if (!maybe_tuple_value.has_value()) { error_handling::HandleRuntimeError("Mismatched value types in tuple variable definition", node->base); } diff --git a/src/main.cpp b/src/main.cpp index 6b4a7d0..40123fe 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -59,9 +59,17 @@ int main(int argc, char** argv) { // TODO, only test version // std::cout << "\n---------------------------------- Untyped -------------------------------------\n\n"; // print_visitor.VisitSourceFile(source_file.get()); - find_symbols_visitor.VisitSourceFile(source_file.get()); - link_symbols_visitor.VisitSourceFile(source_file.get()); - type_check_visitor.VisitSourceFile(source_file.get()); + try { + find_symbols_visitor.VisitSourceFile(source_file.get()); + } catch (...) { error_handling::HandleInternalError("find_symbols_visitor exception", "main", std::nullopt); } + + try { + link_symbols_visitor.VisitSourceFile(source_file.get()); + } catch (...) { error_handling::HandleInternalError("link_symbols_visitor exception", "main", std::nullopt); } + + try { + type_check_visitor.VisitSourceFile(source_file.get()); + } catch (...) { error_handling::HandleInternalError("type_check_visitor exception", "main", std::nullopt); } std::optional maybe_main_partition_id = global_info.FindPartition({"main"}); @@ -72,13 +80,15 @@ int main(int argc, char** argv) { // TODO, only test version const info::GlobalInfo::PartitionInfo& main_partition = global_info.GetPartitionInfo(maybe_main_partition_id.value()); - std::cout << "\n---------------------------------- Execution -------------------------------------\n\n"; + // std::cout << "\n---------------------------------- Execution -------------------------------------\n\n"; interpreter::ExecuteVisitor execute_visitor(global_info, type_context_manager, context_manager); - execute_visitor.ExecutePartition(main_partition.node); + try { + execute_visitor.ExecutePartition(main_partition.node); + } catch (...) { error_handling::HandleInternalError("execute_visitor exception", "main", std::nullopt); } // std::cout << "\n---------------------------------- Typed -------------------------------------\n\n"; // typed_print_visitor.VisitSourceFile(source_file.get()); diff --git a/src/print_visitor.cpp b/src/print_visitor.cpp index 5fb9861..b4df798 100644 --- a/src/print_visitor.cpp +++ b/src/print_visitor.cpp @@ -398,7 +398,7 @@ void PrintVisitor::Visit(ReferenceExpression* node) { break; } out_ << "] ("; - Visit(node->expression.get()); + Visitor::Visit(node->expression); out_ << ')'; } diff --git a/src/type_check_visitor.cpp b/src/type_check_visitor.cpp index c01835c..b36592e 100644 --- a/src/type_check_visitor.cpp +++ b/src/type_check_visitor.cpp @@ -58,13 +58,11 @@ void TypeCheckVisitor::Visit(Namespace* node) { info::definition::AnyType* type_info = maybe_type_info.value(); - Visitor::Visit(type_info->node->value); - utils::IdType type = context_manager_.AddValue( - info::type::DefinedType(node->link_type_id_.value(), - current_type_, - type_info->modifier, - context_manager_.GetValueManager()), - ClassInternalsModifierToValueType(node->modifier)); + // make parameter local types + + + VisitDefinedType(type_info, {}, ClassInternalsModifierToValueType(node->modifier)); + utils::IdType type = current_type_; if (node->modifier != utils::ClassInternalsModifier::Static) { context_manager_.DefineVariable(utils::ClassInternalVarName, type); @@ -248,7 +246,9 @@ void TypeCheckVisitor::Visit(FunctionDefinitionStatement* node) { if (!returned_type_.has_value()) { returned_type_ = current_type_; } + if (!context_manager_.EqualValues(returned_type, returned_type_.value())) { + // error_handling::DebugPrint(context_manager_.GetAnyValue(returned_type)->ToString() + " : " + context_manager_.GetAnyValue(returned_type_.value())->ToString()); error_handling::HandleTypecheckError("Wrong function return type", node->base); } returned_type_ = std::nullopt; @@ -435,9 +435,7 @@ void TypeCheckVisitor::Visit(TypeConstructorPattern* node) { // TODO: match name } } - Visitor::Visit(type_info.node->value); - current_type_ = TypeInContext(current_type_, context); - current_type_ = context_manager_.ToModifiedValue(current_type_, utils::ValueType::Tmp); + VisitDefinedType(&type_info, context, utils::ValueType::Tmp); node->base.type_ = current_type_; } @@ -696,7 +694,7 @@ void TypeCheckVisitor::Visit(LoopControlExpression&) { // enum // Operators void TypeCheckVisitor::Visit(ReferenceExpression* node) { - Visit(node->expression.get()); + Visitor::Visit(node->expression); current_type_ = context_manager_.AddValue( info::type::ReferenceToType({node->reference}, @@ -983,6 +981,8 @@ void TypeCheckVisitor::Visit(TypeConstructor* node) { } } + // TODO: replace with VisitDefinedType ?? + AddTypeParameterLocalTypes(&type_info); Visitor::Visit(type_info.node->value); std::optional maybe_variant_type = @@ -1293,7 +1293,7 @@ void TypeCheckVisitor::Visit(TypeExpression* node) { &node->base); } - VisitDefinedType(maybe_type_info.value(), context); + VisitDefinedType(maybe_type_info.value(), context, utils::ValueType::Tmp); } else { error_handling::HandleTypecheckError("Type not found", node->base); } @@ -1316,9 +1316,11 @@ void TypeCheckVisitor::Visit(TypeExpression* node) { void TypeCheckVisitor::Visit(ExtendedScopedAnyType* node) { Visitor::Visit(node->type); - current_type_ = context_manager_.AddValue( - info::type::ReferenceToType(node->references, current_type_, context_manager_.GetValueManager()), - utils::ValueType::Tmp); + if (!node->references.empty()) { + current_type_ = context_manager_.AddValue( + info::type::ReferenceToType(node->references, current_type_, context_manager_.GetValueManager()), + utils::ValueType::Tmp); + } node->base.type_ = current_type_; } @@ -1540,7 +1542,7 @@ std::optional &node->base); } - VisitDefinedType(maybe_type_info.value(), context); + VisitDefinedType(maybe_type_info.value(), context, utils::ValueType::Tmp); maybe_function_declaration = FindDefinedTypeFunctionAndUpdate(node, maybe_type_info.value(), @@ -1580,7 +1582,7 @@ std::optional TypeCheckVisitor::FindFunctionAndUpdate(Func &node->base); } - VisitDefinedType(maybe_type_info.value(), {}); // TODO: context ?? + VisitDefinedType(maybe_type_info.value(), {}, utils::ValueType::Tmp); // TODO: context ?? maybe_function_declaration = FindDefinedTypeFunctionAndUpdate(node, maybe_type_info.value(), diff --git a/src/typed_print_visitor.cpp b/src/typed_print_visitor.cpp index 0a7204e..29d7c84 100644 --- a/src/typed_print_visitor.cpp +++ b/src/typed_print_visitor.cpp @@ -557,7 +557,7 @@ void TypedPrintVisitor::Visit(ReferenceExpression* node) { break; } out_ << "] ("; - Visit(node->expression.get()); + Visitor::Visit(node->expression); out_ << ')'; } diff --git a/src/types.cpp b/src/types.cpp index 9bf19af..93f085f 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -120,13 +120,13 @@ std::optional TupleType::GetFieldType(const std::string& name, return std::nullopt; } -std::string TupleType::ToString() { +std::string TupleType::ToString() const { std::string result; result += "("; for (auto& field : fields_) { - result += "& "; + result += " & "; result += type_manager_->GetAnyValue(field.second)->ToString(); } @@ -185,7 +185,7 @@ std::optional VariantType::GetFieldType(const std::string& name, return std::nullopt; } -std::string VariantType::ToString() { +std::string VariantType::ToString() const { std::string result; result += "("; @@ -229,7 +229,7 @@ std::optional OptionalType::GetFieldType(const std::string&, return std::nullopt; } -std::string OptionalType::ToString() { +std::string OptionalType::ToString() const { return "Optional " + type_manager_->GetAnyValue(type_)->ToString(); } @@ -263,7 +263,7 @@ std::optional ReferenceToType::GetFieldType(const std::string& na } -std::string ReferenceToType::ToString() { +std::string ReferenceToType::ToString() const { std::string result; for (auto& reference : references_) { @@ -337,7 +337,7 @@ std::optional FunctionType::GetFieldType(const std::string&, return std::nullopt; } -std::string FunctionType::ToString() { +std::string FunctionType::ToString() const { std::string result; result += "("; @@ -387,7 +387,7 @@ std::optional ArrayType::GetFieldType(const std::string&, return std::nullopt; } -std::string ArrayType::ToString() { +std::string ArrayType::ToString() const { return "Array (" + std::to_string(size_) + ") " + type_manager_->GetAnyValue(elements_type_)->ToString(); } @@ -554,7 +554,7 @@ std::string Type::GetTypeName() const { return ""; // ?? } -std::string Type::ToString() { +std::string Type::ToString() const { size_t index = type_.index(); switch (index) { @@ -562,7 +562,7 @@ std::string Type::ToString() { return std::get(type_).ToString(); case 1: return std::get(type_).ToString(); - case 2: + case 2: // ?? return ::info::type::ToString(std::get(type_)); case 3: return std::get(type_).ToString(); diff --git a/src/visitor.cpp b/src/visitor.cpp index 86450a7..ce268ef 100644 --- a/src/visitor.cpp +++ b/src/visitor.cpp @@ -466,7 +466,7 @@ void Visitor::Visit(LoopControlExpression&) {} // enum // Operators void Visitor::Visit(ReferenceExpression* node) { - Visit(node->expression.get()); + Visit(node->expression); } void Visitor::Visit(AccessExpression* node) { diff --git a/tests/arrays.lang b/tests/arrays.lang index 734d713..23974ce 100644 --- a/tests/arrays.lang +++ b/tests/arrays.lang @@ -9,13 +9,13 @@ def test_arrays = { var arr6 <- String._new_array: 10 var arr6_reference = ^arr6 - - const elem1 = arr1`0 - var elem2 = arr1`2 - const ref1 = ^arr1`1 - var ref2 = ^arr1`3 - ; arr1`1 = 123 - - ; ~ref1 = arr1`2 // set value - ; ref1 = ref2 // set pointer / reference +// +// const elem1 = arr1`0 +// var elem2 = arr1`2 +// const ref1 = ^arr1`1 +// var ref2 = ^arr1`3 +// ; arr1`1 = 123 +// +// ; ~ref1 = arr1`2 // set value +// ; ref1 = ref2 // set pointer / reference } diff --git a/tests/stdlib.lang b/tests/stdlib.lang index e69de29..2950e11 100644 --- a/tests/stdlib.lang +++ b/tests/stdlib.lang @@ -0,0 +1,188 @@ +basic (Float : #Ord #Div #Str) +basic (Int : #Ord #IDiv #Str) +basic (String : #Ord #Str #CharContainer #Copy) +basic (Char : #Ord #Str #Copy) +basic (Bool : #Ord #Str #Copy) +basic (Unit : #Str #Copy) + +// + +decl not : Bool -> Bool +def not : x = + (match x with + | true -> false + | false -> true) + +decl ( && ) : Bool -> Bool -> Bool +def ( && ) : x y = + match x with + | true -> ( + match y with + | true -> true + | false -> false + ) + | false -> false + +decl ( || ) : Bool -> Bool -> Bool +def ( || ) : x y = + match x with + | true -> true + | false -> ( + match y with + | true -> true + | false -> false + ) + +// + +typeclass CharContainer = + & var size : -> Int + & var at : Int -> Char + +// + +typeclass Move = // TODO + & var ( <- ) : Move -> Unit + +typeclass Copy = + & var ( = ) : Copy -> Unit + +// + +typeclass (Sum : #Copy) = + & var ( += ) : Sum -> Unit + & var ( -= ) : Sum -> Unit + & var ( + ) : Sum -> Sum + & var ( - ) : Sum -> Sum + & zero : -> Sum + +namespace var Sum { + def ( + ) : x = { + var ans = self + ; ans += x + return ans + } + + def ( - ) : x = { + var ans = self + ; ans -= x + return ans + } +} + +typeclass (Mult : #Sum) = + & var ( *= ) : Mult -> Unit + & var ( * ) : Mult -> Mult + +namespace var Mult { + def ( * ) : x = { + var ans = self + ; ans *= x + return ans + } +} + +typeclass (IDiv : #Mult) = + & var div : IDiv -> IDiv + & var mod : IDiv -> IDiv + +namespace var IDiv { + def mod : x = self -. x * self.div: x +} + +typeclass (Div : #Mult) = + & var ( /= ) : Div -> Unit + & var ( / ) : Div -> Div + +namespace var Div { + def ( / ) : x = { + var ans = self + ; ans /= x + return ans + } +} + +// + +typeclass Eq = + & var ( == ) : Eq -> Bool + & var ( != ) : Eq -> Bool + +namespace var Eq { + def ( != ) : x = not: (self == x) +} + +// + +struct Order = + | EQ + | LT + | GT + +typeclass (Ord : #Eq) = + & var compare : Ord -> Order + & var ( < ) : Ord -> Bool + & var ( >= ) : Ord -> Bool + & var ( > ) : Ord -> Bool + & var ( <= ) : Ord -> Bool + +decl min ('A : #Ord) : 'A -> 'A -> 'A +def min : x y = if x < y then x else y + +decl max ('A : #Ord) : 'A -> 'A -> 'A +def max : x y = if x < y then y else x + +namespace var Ord { + def compare : x = + if self == x then $EQ + elif self < x then $LT + else $GT + + def ( >= ) : x = not: (self < x) + def ( > ) : x = x < self + def ( <= ) : x = not: (x < self) +} + +// + +typeclass Show = + & var show : -> String + +typeclass Read = + & read : String -> Read + +typeclass (Str : #Show #Read) + +// typeclass DebugShow = // TODO +// & debug_show : -> String + +// + +typeclass Default = + & default : -> Default + +// + +typeclass Bounded = + & min_bound : -> Bounded + & max_bound : -> Bounded + & var is_max_bound : -> Bool + & var is_min_bound : -> Bool + +// + +typeclass Enum = + & var succ : -> (Optional Enum) + & var pred : -> (Optional Enum) + & to_enum : Int -> Enum + & var from_enum : -> Int + +// + +namespace IO { + decl print : String -> Unit + decl scan : -> String + decl random : -> Int // TODO +} + +// diff --git a/tests/test_code.lang b/tests/test_code.lang index 01a3951..2ce0496 100644 --- a/tests/test_code.lang +++ b/tests/test_code.lang @@ -41,8 +41,8 @@ typeclass CharContainer = // -typeclass Move = - & var ( <- ) : Move -> Unit // TODO +typeclass Move = // TODO + & var ( <- ) : Move -> Unit typeclass Copy = & var ( = ) : Copy -> Unit @@ -201,9 +201,6 @@ namespace IO { // -decl ret_one : -> Int -def ret_one = 1 - decl ( -- ) : Int -> Int -> Int_0 def ( -- ) : begin end = { var current = begin @@ -242,17 +239,48 @@ def scan_anything = 'A.read: (IO.scan:) decl print_anything ('A : #Show) : 'A -> Unit def print_anything : x = IO.print: (x.show:) +// decl sorted ('A : #Ord #Copy): 'A_0 -> Int -> 'A_0 +// def sorted : a sz = { +// var a_copy = a +// if sz == 2 then { +// if a_copy`0 > a_copy`1 then { +// var x = a_copy`0 +// a_copy`0 = a_copy`1 +// a_copy`1 = x +// } +// return a_copy +// } +// +// var center = sz.div: 2 +// +// var a_left = for i in 0--center do a`i +// var a_right = for i in center-sz do a`i +// +// return a_copy +// } /* struct Array 'A = & data : 'A_0 namespace Array { - decl construct : 'A_0 -> Array - def construct: x = $(Array 'A) & data = x -} // TODO: construct decl + default def -> segfault -*/ + decl of : 'A_0 -> Array + def of: x = $(Array 'A) & data = x +}*/ + +struct ThreeTuple = & String & String & String + +decl scan_three_t : -> ThreeTuple +def scan_three_t = $ThreeTuple & IO.scan: & IO.scan: & IO.scan: + +decl scan_three : -> (& String & String & String) +def scan_three = & IO.scan: & IO.scan: & IO.scan: + +// var n = scan_anything Int: +// var a = $(Array Int) & data = (for _ in 0--n do scan_int:) +// ; print_anything Int: n exec main { - var n = scan_anything Int: - var a = for _ in 0--n do scan_int: - ; print_anything Int: n + var & a & b & c = scan_three_t: + ; IO.print: b + var & d & e & f = scan_three: + ; IO.print: e }