diff --git a/include/definitions.hpp b/include/definitions.hpp index 0d72526..9d3de5b 100644 --- a/include/definitions.hpp +++ b/include/definitions.hpp @@ -78,7 +78,8 @@ struct Function { struct Typeclass { std::vector parameters; - std::vector requirements; + std::vector function_requirements; + std::vector method_requirements; }; struct Import { diff --git a/include/interpreter_tree.hpp b/include/interpreter_tree.hpp index d2b45a7..1cedf8d 100644 --- a/include/interpreter_tree.hpp +++ b/include/interpreter_tree.hpp @@ -355,7 +355,8 @@ struct TypeclassDefinitionStatement { BaseNode base; std::unique_ptr definition; - std::vector> requirements; + std::vector> method_requirements; + std::vector> function_requirements; utils::IdType typeclass_id_; }; diff --git a/include/parse_tree.hpp b/include/parse_tree.hpp index 2cb5c33..f8b6a0d 100644 --- a/include/parse_tree.hpp +++ b/include/parse_tree.hpp @@ -71,6 +71,7 @@ public: Node NthChild(size_t n) { return Node(ts_node_child(node_, n), source_); } + size_t ChildCount() { return ts_node_child_count(node_); } @@ -86,6 +87,14 @@ public: return Node(ts_node_child_by_field_name(node_, name.c_str(), name.size()), source_); } + Node PreviousSibling() { + return Node(ts_node_prev_sibling(node_), source_); + } + + Node PreviousNamedSibling() { + return Node(ts_node_prev_named_sibling(node_), source_); + } + Node NextSibling() { return Node(ts_node_next_sibling(node_), source_); } diff --git a/lang-parser b/lang-parser index 62c8b61..3610504 160000 --- a/lang-parser +++ b/lang-parser @@ -1 +1 @@ -Subproject commit 62c8b6193437e0a10d5b605dc002c7db5e3fc256 +Subproject commit 3610504b4ce142fe30f47851709fef6bc2fc53f1 diff --git a/src/build_visitor.cpp b/src/build_visitor.cpp index e7ffee7..3d8014c 100644 --- a/src/build_visitor.cpp +++ b/src/build_visitor.cpp @@ -269,14 +269,15 @@ void BuildVisitor::Visit(TypeclassDefinitionStatement* node) { size_t child_count = parse_node.NamedChildCount(); - if (child_count > 1) { - node->requirements.resize(child_count - 1); - - for (size_t i = 0; i + 1 < child_count; ++i) { - current_node_ = parse_node.NthNamedChild(i + 1); - node->requirements[i] = std::make_unique(); - Visit(node->requirements[i].get()); - } + for (size_t i = 0; i + 1 < child_count; ++i) { + current_node_ = parse_node.NthNamedChild(i + 1); + if (parse_node.PreviousSibling().GetValue() != "var") { + node->function_requirements.push_back(std::make_unique()); + Visit(node->function_requirements.back().get()); + } else { + node->method_requirements.push_back(std::make_unique()); + Visit(node->method_requirements.back().get()); + } } current_node_ = parse_node; diff --git a/src/find_symbols_visitor.cpp b/src/find_symbols_visitor.cpp index 7b9cf47..f58bca6 100644 --- a/src/find_symbols_visitor.cpp +++ b/src/find_symbols_visitor.cpp @@ -160,10 +160,17 @@ void FindSymbolsVisitor::Visit(TypeclassDefinitionStatement* node) { current_info_.reset(); } - info.requirements.reserve(node->requirements.size()); - for (size_t i = 0; i < node->requirements.size(); ++i) { - Visit(node->requirements[i].get()); - info.requirements[i] = std::move(std::any_cast(current_info_)); + info.function_requirements.reserve(node->function_requirements.size()); + for (size_t i = 0; i < node->function_requirements.size(); ++i) { + Visit(node->function_requirements[i].get()); + info.function_requirements[i] = std::move(std::any_cast(current_info_)); + current_info_.reset(); + } + + info.method_requirements.reserve(node->method_requirements.size()); + for (size_t i = 0; i < node->method_requirements.size(); ++i) { + Visit(node->method_requirements[i].get()); + info.method_requirements[i] = std::move(std::any_cast(current_info_)); current_info_.reset(); } diff --git a/src/print_visitor.cpp b/src/print_visitor.cpp index cc24647..46b8069 100644 --- a/src/print_visitor.cpp +++ b/src/print_visitor.cpp @@ -157,14 +157,22 @@ void PrintVisitor::Visit(AbstractTypeDefinitionStatement* node) { void PrintVisitor::Visit(TypeclassDefinitionStatement* node) { out_ << "[Typeclass] ("; Visit(node->definition.get()); - if (!node->requirements.empty()) { + if (!node->function_requirements.empty()) { out_ << ") : (\n"; } - for (auto& requirement : node->requirements) { + for (auto& requirement : node->function_requirements) { out_ << "& "; Visit(requirement.get()); out_ << "\n"; } + if (!node->method_requirements.empty()) { + out_ << ") : (\n"; + } + for (auto& requirement : node->method_requirements) { + out_ << "& var "; + Visit(requirement.get()); + out_ << "\n"; + } out_ << ")\n"; } diff --git a/src/typed_print_visitor.cpp b/src/typed_print_visitor.cpp index 04efd97..a3ff3e1 100644 --- a/src/typed_print_visitor.cpp +++ b/src/typed_print_visitor.cpp @@ -218,14 +218,22 @@ void TypedPrintVisitor::Visit(TypeclassDefinitionStatement* node) { out_ << "] ("; Visit(node->definition.get()); - if (!node->requirements.empty()) { + if (!node->function_requirements.empty()) { out_ << ") : (\n"; } - for (auto& requirement : node->requirements) { + for (auto& requirement : node->function_requirements) { out_ << "& "; Visit(requirement.get()); out_ << "\n"; } + if (!node->method_requirements.empty()) { + out_ << ") : (\n"; + } + for (auto& requirement : node->method_requirements) { + out_ << "& var "; + Visit(requirement.get()); + out_ << "\n"; + } out_ << ")\n"; } diff --git a/src/visitor.cpp b/src/visitor.cpp index 064ea2e..87ee5d6 100644 --- a/src/visitor.cpp +++ b/src/visitor.cpp @@ -357,8 +357,11 @@ void Visitor::Visit(AbstractTypeDefinitionStatement* node) { void Visitor::Visit(TypeclassDefinitionStatement* node) { Visit(node->definition.get()); - for (auto& requirement : node->requirements) { - Visit(requirement.get()); + for (auto& function_requirement : node->function_requirements) { + Visit(function_requirement.get()); + } + for (auto& method_requirement : node->method_requirements) { + Visit(method_requirement.get()); } } diff --git a/tests/test_code.lang b/tests/test_code.lang index 76419cc..798f106 100644 --- a/tests/test_code.lang +++ b/tests/test_code.lang @@ -5,7 +5,7 @@ basic Char basic Bool basic Unit -// bool functions +// decl not : Bool -> Bool def not : x = @@ -36,14 +36,14 @@ def ( || ) : x y = // Eq typeclass typeclass Eq = - & ( == ) : Eq -> Bool - & ( != ) : Eq -> Bool + & var ( == ) : Eq -> Bool + & var ( != ) : Eq -> Bool namespace const Eq { - def ( != ) : x = not: (self == x) + def var ( != ) : x = not: (self == x) } -// Ord typeclass +// struct Order = | EQ @@ -51,13 +51,13 @@ struct Order = | GT typeclass (Ord : #Eq) = - & compare: Ord -> Order - & ( < ) : Ord -> Bool - & ( >= ) : Ord -> Bool - & ( > ) : Ord -> Bool - & ( <= ) : Ord -> Bool - & min : Ord -> Ord - & max : Ord -> Ord + & var compare: Ord -> Order + & var ( < ) : Ord -> Bool + & var ( >= ) : Ord -> Bool + & var ( > ) : Ord -> Bool + & var ( <= ) : Ord -> Bool + & var min : Ord -> Ord + & var max : Ord -> Ord namespace var Ord { def compare : x = @@ -75,23 +75,55 @@ namespace var Ord { // typeclass Show = - & show : -> String + & var show : -> String typeclass Read = - & read : String -> Read + & var read : String -> Read typeclass Debug = & debug : -> String // +typeclass Default = + & default : -> Default +// -// Enum typeclass +typeclass Bounded = + & min_bound : -> Bounded + & max_bound : -> Bounded + & var is_max_bound : -> Bool + & var is_min_bound : -> Bool + +// typeclass Enum = - & succ : Enum -> (Optional Enum) - & pred : Enum -> (Optional Enum) + & var succ : -> (Optional Enum) + & var pred : -> (Optional Enum) + & toEnum : Int -> Enum + & var fromEnum : -> Int + +// + + + +// // bad +// typeclass Functor 'A = +// & fmap 'B ('F : (#Functor 'B)) : ('A -> 'B) -> Functor -> 'F + +// typeclass (Iterator : #Eq) = +// & next : -> Unit +// & prev : -> Unit +// +// typeclass Iterable ('Iter : #Iterable) = +// & begin : -> 'Iter +// & end : -> 'Iter + +// + +class Slice ('Elem : ) ('Structure : (#Iterable)) = + // diff --git a/tests/typeclasses.lang b/tests/typeclasses.lang index 34b419d..4b32064 100644 --- a/tests/typeclasses.lang +++ b/tests/typeclasses.lang @@ -1,14 +1,14 @@ -typeclass Copy = - & copy : Copy -> Copy +typeclass Default = + & default : -> Copy typeclass (Ord : #Eq) = - & is_less_then : Ord -> Bool + & var is_less_then : Ord -> Bool typeclass (D : #A #B #C) 'A 'B = - & do_something : -> (& 'A & 'B) + & var do_something : -> (& 'A & 'B) typeclass E 'A = - & do_something : -> 'A + & var do_something : -> 'A decl ( == ) ('A : #Ord) : 'A -> 'A -> Bool def ( == ) : a b = a.is_equal_to: b