diff --git a/byterun/src/compiler.cpp b/byterun/src/compiler.cpp index d9cc2d5a3..122218d0f 100644 --- a/byterun/src/compiler.cpp +++ b/byterun/src/compiler.cpp @@ -42,14 +42,52 @@ std::vector transform(std::vector v, const std::function &f) { return result; } +template void insert(std::vector &x, std::vector &&y) { + x.insert(x.end(), std::move_iterator(y.begin()), std::move_iterator(y.end())); +} + +template std::vector reverse(std::vector &&x) { + std::reverse(x.begin(), x.end()); + return std::move(x); +} + template std::vector concat(std::vector &&x) { return std::move(x); } -template -std::vector concat(std::vector &&x, std::vector &&y, Args &&...args) { - x.insert(x.end(), std::move_iterator(y.begin()), std::move_iterator(y.end())); - return concat(std::move(x), std::forward(args)...); +// --> declarations +template + requires std::is_same_v +std::vector concat(std::vector &&x, U &&y, Args &&...args); +template + requires std::is_same_v> +std::vector concat(std::vector &&x, U &&y, Args &&...args); +template + requires std::is_same_v> +std::vector concat(std::vector &&x, U &&y, Args &&...args); +// <-- + +template + requires std::is_same_v> +std::vector concat(std::vector &&x, U &&y, Args &&...args) { + insert(x, std::move(y)); + return concat(std::move(x), std::forward(args)...); +} + +template + requires std::is_same_v +std::vector concat(std::vector &&x, U &&y, Args &&...args) { + x.push_back(std::move(y)); + return concat(std::move(x), std::forward(args)...); +} + +template + requires std::is_same_v> +std::vector concat(std::vector &&x, U &&y, Args &&...args) { + if (y) { + x.push_back(std::move(*y)); + } + return concat(std::move(x), std::forward(args)...); } // template @@ -167,6 +205,7 @@ const auto r12 = Register::from_number(12); const auto r13 = Register::from_number(13); const auto r14 = Register::from_number(14); const auto r15 = Register::from_number(15); + const std::vector argument_registers = {rdi, rsi, rdx, rcx, r8, r9}; const std::vector extra_caller_saved_registers = {r10, r11, r12, @@ -276,6 +315,8 @@ using M = Opnd::M; using R = Opnd::R; using S = Opnd::S; +/* Value that could be used to fill unused stack locations. + Garbage is not allowed as it will affect GC. */ struct ArgumentLocation { struct Register { Opnd opnd; @@ -466,13 +507,13 @@ struct ValT { std::string s; }; struct Local { - int n; + size_t n; }; struct Arg { - int n; + size_t n; }; struct Access { - int n; + size_t n; }; struct Fun { std::string s; @@ -550,6 +591,10 @@ template struct AbstractSymbolicStack { const W &operator*() const { return val; } const W &operator->() const { return val; } + + template bool is() const { + return std::holds_alternative(val); + } }; using Stack = SymbolicLocation::Stack; using Register = SymbolicLocation::Register; @@ -665,8 +710,9 @@ struct SymbolicStack { using R = AbSS::StackState::R; using E = AbSS::StackState::E; - using Stack = AbSS::SymbolicLocation::Stack; - using Register = AbSS::SymbolicLocation::Register; + using SymbolicLocation = AbSS::SymbolicLocation; + using Stack = SymbolicLocation::Stack; + using Register = SymbolicLocation::Register; // type t @@ -694,7 +740,7 @@ struct SymbolicStack { }; } - static Opnd opnd_from_loc(const T &v, const AbSS::SymbolicLocation &loc) { + static Opnd opnd_from_loc(const T &v, const SymbolicLocation &loc) { return std::visit( utils::multifunc{ [](const Register &x) -> Opnd { return {Opnd::R{x.r}}; }, @@ -900,7 +946,8 @@ public: /* is rdx register in use */ bool rdx_in_use() const { return nargs > 2; } - std::vector arguments_locations(size_t n) { + std::pair, size_t> + arguments_locations(size_t n) { // TODO // if n < argument_registers_size then // ( Array.to_list (Array.sub argument_registers 0 n) @@ -1421,7 +1468,7 @@ std::vector compile_binop(Env &env, Opr op) { stack. As we do not have control where does the C compiler locate them in the moment of GC, we have to explicitly locate them on the stack. And to the runtime function we are passing a reference to their location. */ -const std::vector safepoint_functions = { +const std::unordered_set safepoint_functions = { utils::labeled("s__Infix_58"), utils::labeled("substring"), utils::labeled("clone"), utils::labeled_builtin("string"), utils::labeled("stringcat"), utils::labeled("string"), @@ -1434,16 +1481,261 @@ const std::vector safepoint_functions = { /* Lsprintf, or Bsprintf is an extra dirty hack that probably works */ }; -const std::vector> vararg_functions = { +const std::unordered_map vararg_functions = { {utils::labeled("printf"), 1}, {utils::labeled("fprintf"), 2}, {utils::labeled("sprintf"), 1}, {utils::labeled("failure"), 1}, }; +namespace utils::call_compilation::tail { + +// NOTE: all comands in result are in inversed order +void push_args_rec_inv(Env &env, std::vector &acc, + size_t n) { + if (n == 0) { + return; + } + const auto x = env.pop(); + utils::insert(acc, utils::reverse(mov(x, env.loc(ValT::Arg{n - 1})))); + push_args_rec_inv(env, acc, n - 1); +} +std::vector push_args(Env &env, size_t n) { + std::vector acc; + push_args_rec_inv(env, acc, n); + std::reverse(acc.begin(), acc.end()); + return acc; +} + +} // namespace utils::call_compilation::tail +std::vector compile_tail_call(Env &env, + const std::optional &fname, + size_t nargs) { + using namespace utils::call_compilation::tail; + std::vector pushs = push_args(env, nargs); + + std::optional setup_closure; + if (!fname) { + const auto closure = env.pop(); + setup_closure = Mov{closure, r15}; + } + + Instr add_argc_counter = Mov{L{static_cast(nargs)}, r11}; + + Instr jump = fname ? Instr{Jmp{*fname}} : Instr{JmpI{r15}}; + + env.allocate(); + return utils::concat(std::move(pushs), Instr{Mov{rbp, rsp}}, Instr{Pop{rbp}}, + std::move(setup_closure), std::move(add_argc_counter), + std::move(jump)); +} + +namespace utils::call_compilation { + +std::vector pop_arguments(Env &env, size_t n) { + std::vector result; + result.reserve(n); + for (size_t i = 0; i < n; ++i) { + const auto x = env.pop(); + result.push_back(x); + } + std::reverse(result.begin(), result.end()); + return result; +}; + +namespace common { + +std::pair> setup_arguments(Env &env, + size_t nargs) { + const auto move_arguments = + [](std::vector &&args, + std::vector &&arg_locs) { + using Register = SymbolicStack::Register; + using Stack = SymbolicStack::Stack; + + assert(args.size() == arg_locs.size()); + + std::vector result; + result.reserve(args.size()); + // NOTE: direction should be (fold left) + for (size_t i = 0; i < args.size(); ++i) { + result.push_back( + arg_locs[i].is() + ? Instr{Mov{args[i], std::get(*arg_locs[i]).r}} + : /*Stack*/ Push{args[i]}); + } + std::reverse(result.begin(), result.end()); + return result; + }; + auto args = pop_arguments(env, nargs); + auto [arg_locs, stack_slots] = env.arguments_locations(args.size()); + auto setup_args_code = move_arguments(std::move(args), std::move(arg_locs)); + return {stack_slots, std::move(setup_args_code)}; +} + +std::optional setup_closure(Env &env, + const std::optional &fname) { + if (!fname) { + return {}; + } + const auto closure = env.pop(); + return Mov{closure, r15}; +} + +Instr call(const std::optional &fname) { + return fname ? Instr{Call{*fname}} : Instr{CallI{r15}}; +} + +Instr add_argc_counter(const std::optional &fname, size_t nargs) { + const auto it = + fname ? vararg_functions.find(*fname) : vararg_functions.end(); + size_t argc = it == vararg_functions.end() ? 0 : it->second; + return Mov{L{static_cast(nargs - argc)}, r11}; +} + +} // namespace common + +std::pair, std::vector> +protect_registers(Env &env) { + std::vector pushr; + std::vector popr; + if (env.has_closure) { + pushr.push_back(Push{r15}); + popr.push_back(Pop{r15}); + } + + pushr = utils::concat( + std::move(pushr), + utils::transform(env.live_registers(), + [](const auto &r) { return Push{r}; })); + popr = utils::concat( + std::move(popr), + utils::transform( + env.live_registers(), [](const auto &r) -> Instr { return Pop{r}; })); + + return {pushr, popr}; +} + +std::pair, std::optional> +align_stack(size_t saved_registers, size_t stack_arguments) { + const bool aligned = (saved_registers + stack_arguments) % 2 == 0; + if (aligned && stack_arguments == 0) { + return {{}, {}}; + } + if (aligned) { + return {{}, + {Binop{Opr::ADD, L{static_cast(word_size * stack_arguments)}, + rsp}}}; + } + + return {Push{filler}, + {Binop{Opr::ADD, + L{static_cast(word_size * (1 + stack_arguments))}, rsp}}}; +} + +Instr move_result(Env &env) { + const auto y = env.allocate(); + return Mov{rax, y}; +} + +} // namespace utils::call_compilation +std::vector compile_common_call(Env &env, + const std::optional &fname, + size_t nargs) { + using namespace utils::call_compilation::common; + using namespace utils::call_compilation; + + auto add_argc_counter_code = add_argc_counter(fname, nargs); + + auto [stack_slots, setup_args_code] = setup_arguments(env, nargs); + auto [push_registers, pop_registers] = protect_registers(env); + auto [align_prologue, align_epilogue] = + align_stack(push_registers.size(), stack_slots); + auto setup_closure_code = setup_closure(env, fname); + auto call_code = call(fname); + auto move_result_code = move_result(env); + + return utils::concat( + std::move(push_registers), std::move(align_prologue), + std::move(setup_args_code), std::move(setup_closure_code), + std::move(add_argc_counter_code), std::move(call_code), + std::move(align_epilogue), utils::reverse(std::move(pop_registers)), + std::move(move_result_code)); +} + +namespace utils::call_compilation::safepoint { + +std::pair> +setup_arguments(Env &env, const std::optional &fname, + size_t nargs) { + auto args = pop_arguments(env, nargs); + auto [arg_locs, stack_slots] = env.arguments_locations(args.size()); + auto setup_args_code = + utils::transform(utils::reverse(std::move(args)), + [](const auto &arg) { return Push{arg}; }); + setup_args_code.push_back(Mov{rsp, rdi}); + if (*fname == utils::labeled_builtin("closure")) { + setup_args_code.push_back(Mov{L{box(nargs - 1)}, rsi}); + } else if (*fname == utils::labeled_builtin("sexp") || + *fname == utils::labeled_builtin("array")) { + setup_args_code.push_back(Mov{L{box(nargs)}, rsi}); + } + return {nargs, std::move(setup_args_code)}; +} + +Instr call(const std::optional &fname) { return Call{*fname}; } + +} // namespace utils::call_compilation::safepoint +std::vector +compile_safepoint_call(Env &env, + const std::optional &fname, size_t nargs) { + using namespace utils::call_compilation::safepoint; + using namespace utils::call_compilation; + + auto [stack_slots, setup_args_code] = setup_arguments(env, fname, nargs); + auto [push_registers, pop_registers] = protect_registers(env); + auto [align_prologue, align_epilogue] = + align_stack(push_registers.size(), stack_slots); + auto call_code = call(fname); + auto move_result_code = move_result(env); + + return utils::concat(std::move(push_registers), std::move(align_prologue), + std::move(setup_args_code), std::move(call_code), + std::move(align_epilogue), + utils::reverse(std::move(pop_registers)), + std::move(move_result_code)); +} + std::vector compile_call(Env &env, - std::optional fname, - size_t nargs, bool tail) {} + std::optional fname_in, + size_t nargs, bool tail) { + std::optional fname; + if (fname_in) { + fname = (*fname_in)[0] == '.' ? utils::labeled_builtin(fname->substr(1)) + : std::string{*fname_in}; + } + + bool safepoint_call = false; + bool allowed_function = true; + if (fname) { + safepoint_call = (safepoint_functions.count(*fname) != 0); + const bool is_vararg = (vararg_functions.count(*fname) != 0); + const bool is_internal = ((*fname)[0] == 'B'); + allowed_function = not is_internal && not is_vararg; + } + + const bool same_arguments_count = env.nargs == nargs; + const bool tail_call_optimization_applicable = + tail && allowed_function && same_arguments_count; + + if (safepoint_call) { + return compile_safepoint_call(env, fname, nargs); + } + if (tail_call_optimization_applicable) { + return compile_tail_call(env, fname, nargs); + } + return compile_common_call(env, fname, nargs); +} enum class Patt { BOXED, @@ -1860,4 +2152,10 @@ std::vector compile(cmd, Env &env, std::vector compile(cmd, Env &env, const std::vector &imports, - const std::vector &code) {} + const std::vector &code) { + std::vector result; + for (const auto &instr : code) { + result = + utils::concat(std::move(result), compile(cmd, env, imports, instr)); + } +}