This commit is contained in:
2026-01-19 23:10:09 +08:00
parent 1a66d45204
commit 2839c0daff
17 changed files with 2274 additions and 46 deletions

636
src/compiler.cpp Normal file
View File

@@ -0,0 +1,636 @@
#include "compiler.h"
#include <iostream>
#include <format>
namespace camellya {
Compiler::Compiler()
: current_chunk(nullptr), scope_depth(0), had_error(false) {
}
// Compiles a whole program into a fresh bytecode chunk.
//
// Resets the error state, compiles every top-level statement in order and
// appends OP_HALT.
//
// @param program  Program whose `statements` are compiled in sequence.
// @return The finished chunk, or nullptr if an error was recorded. Failures
//         derived from std::exception do not escape; their message is stored
//         in error_message and had_error is set.
std::shared_ptr<Chunk> Compiler::compile(const Program& program) {
    current_chunk = std::make_shared<Chunk>();
    had_error = false;
    error_message.clear();

    try {
        for (const auto& stmt : program.statements) {
            compile_stmt(*stmt);
        }
        // Emit halt instruction at the end
        emit_opcode(OpCode::OP_HALT);
    } catch (const std::exception& e) {
        // Record the failure directly. report_error() throws, so calling it
        // from inside this handler would send a new exception out of
        // compile() and the nullptr return below would never be reached.
        had_error = true;
        error_message = e.what();
        return nullptr;
    }

    // Function and method bodies catch their own CompileError and only set
    // the flag, so it must still be checked after a "clean" pass.
    if (had_error) {
        return nullptr;
    }
    return current_chunk;
}
void Compiler::compile_expr(const Expr& expr) {
if (auto* binary = dynamic_cast<const BinaryExpr*>(&expr)) {
compile_binary(*binary);
} else if (auto* unary = dynamic_cast<const UnaryExpr*>(&expr)) {
compile_unary(*unary);
} else if (auto* literal = dynamic_cast<const LiteralExpr*>(&expr)) {
compile_literal(*literal);
} else if (auto* variable = dynamic_cast<const VariableExpr*>(&expr)) {
compile_variable(*variable);
} else if (auto* assign = dynamic_cast<const AssignExpr*>(&expr)) {
compile_assign(*assign);
} else if (auto* call = dynamic_cast<const CallExpr*>(&expr)) {
compile_call(*call);
} else if (auto* get = dynamic_cast<const GetExpr*>(&expr)) {
compile_get(*get);
} else if (auto* set = dynamic_cast<const SetExpr*>(&expr)) {
compile_set(*set);
} else if (auto* index = dynamic_cast<const IndexExpr*>(&expr)) {
compile_index(*index);
} else if (auto* index_set = dynamic_cast<const IndexSetExpr*>(&expr)) {
compile_index_set(*index_set);
} else if (auto* list = dynamic_cast<const ListExpr*>(&expr)) {
compile_list(*list);
} else if (auto* map = dynamic_cast<const MapExpr*>(&expr)) {
compile_map(*map);
} else {
report_error("Unknown expression type");
}
}
void Compiler::compile_binary(const BinaryExpr& expr) {
// Special handling for logical operators (short-circuit evaluation)
if (expr.op == "and") {
compile_expr(*expr.left);
size_t end_jump = emit_jump(OpCode::OP_JUMP_IF_FALSE);
emit_opcode(OpCode::OP_POP);
compile_expr(*expr.right);
patch_jump(end_jump);
return;
}
if (expr.op == "or") {
compile_expr(*expr.left);
size_t else_jump = emit_jump(OpCode::OP_JUMP_IF_FALSE);
size_t end_jump = emit_jump(OpCode::OP_JUMP);
patch_jump(else_jump);
emit_opcode(OpCode::OP_POP);
compile_expr(*expr.right);
patch_jump(end_jump);
return;
}
// Regular binary operators
compile_expr(*expr.left);
compile_expr(*expr.right);
if (expr.op == "+") {
emit_opcode(OpCode::OP_ADD);
} else if (expr.op == "-") {
emit_opcode(OpCode::OP_SUBTRACT);
} else if (expr.op == "*") {
emit_opcode(OpCode::OP_MULTIPLY);
} else if (expr.op == "/") {
emit_opcode(OpCode::OP_DIVIDE);
} else if (expr.op == "%") {
emit_opcode(OpCode::OP_MODULO);
} else if (expr.op == "==") {
emit_opcode(OpCode::OP_EQUAL);
} else if (expr.op == "!=") {
emit_opcode(OpCode::OP_NOT_EQUAL);
} else if (expr.op == ">") {
emit_opcode(OpCode::OP_GREATER);
} else if (expr.op == ">=") {
emit_opcode(OpCode::OP_GREATER_EQUAL);
} else if (expr.op == "<") {
emit_opcode(OpCode::OP_LESS);
} else if (expr.op == "<=") {
emit_opcode(OpCode::OP_LESS_EQUAL);
} else {
report_error("Unknown binary operator: " + expr.op);
}
}
// Compiles a unary expression: the operand first, then OP_NEGATE for "-" or
// OP_NOT for "!". Any other operator is reported as an error.
void Compiler::compile_unary(const UnaryExpr& expr) {
    compile_expr(*expr.operand);

    if (expr.op == "-") {
        emit_opcode(OpCode::OP_NEGATE);
        return;
    }
    if (expr.op == "!") {
        emit_opcode(OpCode::OP_NOT);
        return;
    }
    report_error("Unknown unary operator: " + expr.op);
}
// Compiles a literal by visiting the variant held in expr.value:
//   double      -> NumberValue constant
//   std::string -> StringValue constant
//   bool        -> OP_TRUE / OP_FALSE
//   anything else -> OP_NIL
void Compiler::compile_literal(const LiteralExpr& expr) {
    const auto emit_for = [this](const auto& payload) {
        using Payload = std::decay_t<decltype(payload)>;
        if constexpr (std::is_same_v<Payload, bool>) {
            emit_opcode(payload ? OpCode::OP_TRUE : OpCode::OP_FALSE);
        } else if constexpr (std::is_same_v<Payload, double>) {
            emit_constant(std::make_shared<NumberValue>(payload));
        } else if constexpr (std::is_same_v<Payload, std::string>) {
            emit_constant(std::make_shared<StringValue>(payload));
        } else {
            emit_opcode(OpCode::OP_NIL);
        }
    };
    std::visit(emit_for, expr.value);
}
// Emits a read of the named variable: OP_GET_LOCAL with its slot when
// resolve_local() finds it, otherwise OP_GET_GLOBAL with the name stored as a
// constant.
void Compiler::compile_variable(const VariableExpr& expr) {
    const int slot = resolve_local(expr.name);
    if (slot == -1) {
        const uint8_t name_index = identifier_constant(expr.name);
        emit_bytes(static_cast<uint8_t>(OpCode::OP_GET_GLOBAL), name_index);
        return;
    }
    emit_bytes(static_cast<uint8_t>(OpCode::OP_GET_LOCAL), static_cast<uint8_t>(slot));
}
// Compiles an assignment: the value expression first, then OP_SET_LOCAL with
// the slot if the name resolves to a local, otherwise OP_SET_GLOBAL with the
// name stored as a constant.
void Compiler::compile_assign(const AssignExpr& expr) {
    compile_expr(*expr.value);

    const int slot = resolve_local(expr.name);
    if (slot == -1) {
        const uint8_t name_index = identifier_constant(expr.name);
        emit_bytes(static_cast<uint8_t>(OpCode::OP_SET_GLOBAL), name_index);
        return;
    }
    emit_bytes(static_cast<uint8_t>(OpCode::OP_SET_LOCAL), static_cast<uint8_t>(slot));
}
void Compiler::compile_call(const CallExpr& expr) {
compile_expr(*expr.callee);
for (const auto& arg : expr.arguments) {
compile_expr(*arg);
}
emit_bytes(static_cast<uint8_t>(OpCode::OP_CALL),
static_cast<uint8_t>(expr.arguments.size()));
}
// Compiles a property read: the object expression, then OP_GET_PROPERTY with
// the property name stored as a constant.
void Compiler::compile_get(const GetExpr& expr) {
    compile_expr(*expr.object);
    const uint8_t name_index = identifier_constant(expr.name);
    emit_bytes(static_cast<uint8_t>(OpCode::OP_GET_PROPERTY), name_index);
}
// Compiles a property write: the object, then the value, then
// OP_SET_PROPERTY with the property name stored as a constant.
void Compiler::compile_set(const SetExpr& expr) {
    compile_expr(*expr.object);
    compile_expr(*expr.value);
    const uint8_t name_index = identifier_constant(expr.name);
    emit_bytes(static_cast<uint8_t>(OpCode::OP_SET_PROPERTY), name_index);
}
void Compiler::compile_index(const IndexExpr& expr) {
compile_expr(*expr.object);
compile_expr(*expr.index);
emit_opcode(OpCode::OP_INDEX);
}
void Compiler::compile_index_set(const IndexSetExpr& expr) {
compile_expr(*expr.object);
compile_expr(*expr.index);
compile_expr(*expr.value);
emit_opcode(OpCode::OP_INDEX_SET);
}
void Compiler::compile_list(const ListExpr& expr) {
for (const auto& elem : expr.elements) {
compile_expr(*elem);
}
emit_bytes(static_cast<uint8_t>(OpCode::OP_BUILD_LIST),
static_cast<uint8_t>(expr.elements.size()));
}
void Compiler::compile_map(const MapExpr& expr) {
for (const auto& [key, value] : expr.pairs) {
compile_expr(*key);
compile_expr(*value);
}
emit_bytes(static_cast<uint8_t>(OpCode::OP_BUILD_MAP),
static_cast<uint8_t>(expr.pairs.size()));
}
void Compiler::compile_stmt(const Stmt& stmt) {
if (auto* expr_stmt = dynamic_cast<const ExprStmt*>(&stmt)) {
compile_expr_stmt(*expr_stmt);
} else if (auto* var_decl = dynamic_cast<const VarDecl*>(&stmt)) {
compile_var_decl(*var_decl);
} else if (auto* block = dynamic_cast<const BlockStmt*>(&stmt)) {
compile_block(*block);
} else if (auto* if_stmt = dynamic_cast<const IfStmt*>(&stmt)) {
compile_if(*if_stmt);
} else if (auto* while_stmt = dynamic_cast<const WhileStmt*>(&stmt)) {
compile_while(*while_stmt);
} else if (auto* for_stmt = dynamic_cast<const ForStmt*>(&stmt)) {
compile_for(*for_stmt);
} else if (auto* return_stmt = dynamic_cast<const ReturnStmt*>(&stmt)) {
compile_return(*return_stmt);
} else if (auto* break_stmt = dynamic_cast<const BreakStmt*>(&stmt)) {
compile_break(*break_stmt);
} else if (auto* continue_stmt = dynamic_cast<const ContinueStmt*>(&stmt)) {
compile_continue(*continue_stmt);
} else if (auto* func_decl = dynamic_cast<const FunctionDecl*>(&stmt)) {
compile_function_decl(*func_decl);
} else if (auto* class_decl = dynamic_cast<const ClassDecl*>(&stmt)) {
compile_class_decl(*class_decl);
} else {
report_error("Unknown statement type");
}
}
void Compiler::compile_expr_stmt(const ExprStmt& stmt) {
compile_expr(*stmt.expression);
emit_opcode(OpCode::OP_POP);
}
// Compiles a variable declaration.
//
// Initial value, in priority order:
//   1. the initializer expression, if present;
//   2. otherwise, if the declared type name is non-empty and is not one of
//      number/string/bool/list/map, a zero-argument call of the global with
//      that name (treated as a class);
//   3. otherwise nil.
// At scope depth 0 the variable becomes a global; at any deeper scope it is
// registered as a local.
void Compiler::compile_var_decl(const VarDecl& stmt) {
    // True when the declared type is neither empty nor a built-in type name.
    const auto names_class_type = [&stmt]() {
        if (stmt.type_name.empty()) {
            return false;
        }
        for (const char* builtin : {"number", "string", "bool", "list", "map"}) {
            if (stmt.type_name == builtin) {
                return false;
            }
        }
        return true;
    };

    if (stmt.initializer) {
        compile_expr(*stmt.initializer);
    } else if (names_class_type()) {
        // Look the class up by name and call it with 0 arguments.
        const uint8_t class_name_index = identifier_constant(stmt.type_name);
        emit_bytes(static_cast<uint8_t>(OpCode::OP_GET_GLOBAL), class_name_index);
        emit_bytes(static_cast<uint8_t>(OpCode::OP_CALL), 0);
    } else {
        emit_opcode(OpCode::OP_NIL);
    }

    if (scope_depth != 0) {
        add_local(stmt.name);
        return;
    }
    const uint8_t name_index = identifier_constant(stmt.name);
    emit_bytes(static_cast<uint8_t>(OpCode::OP_DEFINE_GLOBAL), name_index);
}
void Compiler::compile_block(const BlockStmt& stmt) {
begin_scope();
for (const auto& statement : stmt.statements) {
compile_stmt(*statement);
}
end_scope();
}
// Compiles an if / if-else statement.
//
// Layout: condition, OP_JUMP_IF_FALSE -> [false path], OP_POP, then-branch,
// OP_JUMP -> [end], [false path:] OP_POP, optional else-branch, [end].
// Each path carries its own OP_POP for the condition value.
void Compiler::compile_if(const IfStmt& stmt) {
    compile_expr(*stmt.condition);

    const size_t to_false_path = emit_jump(OpCode::OP_JUMP_IF_FALSE);
    emit_opcode(OpCode::OP_POP);  // condition was truthy
    compile_stmt(*stmt.then_branch);

    // With or without an else branch, the then-path must skip the false
    // path's OP_POP.
    const size_t to_end = emit_jump(OpCode::OP_JUMP);

    patch_jump(to_false_path);
    emit_opcode(OpCode::OP_POP);  // condition was falsey
    if (stmt.else_branch) {
        compile_stmt(*stmt.else_branch);
    }
    patch_jump(to_end);
}
void Compiler::compile_while(const WhileStmt& stmt) {
size_t loop_start = current_chunk->current_offset();
// Push loop info for break/continue
loops.push_back({loop_start, {}, scope_depth});
compile_expr(*stmt.condition);
size_t exit_jump = emit_jump(OpCode::OP_JUMP_IF_FALSE);
emit_opcode(OpCode::OP_POP);
compile_stmt(*stmt.body);
emit_loop(loop_start);
patch_jump(exit_jump);
emit_opcode(OpCode::OP_POP);
// Patch all break statements
for (size_t break_jump : loops.back().breaks) {
patch_jump(break_jump);
}
loops.pop_back();
}
// Compiles a `for` loop with optional initializer, condition and increment.
//
// Emitted layout, in order:
//   initializer | condition + exit jump | jump to body |
//   increment + loop back to condition | body + loop back to increment | exit
// The increment is emitted before the body in the bytecode, so the body is
// reached by jumping over it and returns to it with a backward loop.
void Compiler::compile_for(const ForStmt& stmt) {
// The loop gets its own scope; end_scope() at the bottom pops any locals the
// initializer declared.
begin_scope();
if (stmt.initializer) {
compile_stmt(*stmt.initializer);
}
// Offset of the condition; the increment loops back here.
size_t loop_start = current_chunk->current_offset();
// Push loop info for break/continue
loops.push_back({loop_start, {}, scope_depth});
// Only assigned and patched when a condition exists.
size_t exit_jump = 0;
if (stmt.condition) {
compile_expr(*stmt.condition);
exit_jump = emit_jump(OpCode::OP_JUMP_IF_FALSE);
emit_opcode(OpCode::OP_POP);
}
// Jump over increment for first iteration
size_t body_jump = emit_jump(OpCode::OP_JUMP);
size_t increment_start = current_chunk->current_offset();
if (stmt.increment) {
compile_expr(*stmt.increment);
// The increment's value is not used; discard it.
emit_opcode(OpCode::OP_POP);
}
emit_loop(loop_start);
// Update loop start to increment (for continue)
loops.back().start = increment_start;
patch_jump(body_jump);
compile_stmt(*stmt.body);
// After the body, go run the increment, which in turn loops to the condition.
emit_loop(increment_start);
if (stmt.condition) {
patch_jump(exit_jump);
// Exit path pops the condition, mirroring the OP_POP on the continue path above.
emit_opcode(OpCode::OP_POP);
}
// Patch all break statements
for (size_t break_jump : loops.back().breaks) {
patch_jump(break_jump);
}
loops.pop_back();
end_scope();
}
// Compiles a return statement: the value expression if there is one,
// otherwise OP_NIL, followed by OP_RETURN.
void Compiler::compile_return(const ReturnStmt& stmt) {
    if (!stmt.value) {
        emit_opcode(OpCode::OP_NIL);
    } else {
        compile_expr(*stmt.value);
    }
    emit_opcode(OpCode::OP_RETURN);
}
// Compiles `break`.
//
// Outside a loop this is reported as an error. Inside one, an OP_POP is
// emitted for every local declared deeper than the innermost loop's scope
// (the compiler's own locals list is left untouched), then a forward OP_JUMP
// is emitted and recorded on the loop so it can be patched when the loop ends.
void Compiler::compile_break(const BreakStmt& stmt) {
    if (loops.empty()) {
        report_error("Cannot use 'break' outside of a loop.");
        return;
    }

    // Walk locals from innermost outwards, stopping at the loop's own scope.
    size_t remaining = locals.size();
    while (remaining > 0 && locals[remaining - 1].depth > loops.back().scope_depth) {
        emit_opcode(OpCode::OP_POP);
        --remaining;
    }

    const size_t pending = emit_jump(OpCode::OP_JUMP);
    loops.back().breaks.push_back(pending);
}
// Compiles `continue`.
//
// Outside a loop this is reported as an error. Inside one, an OP_POP is
// emitted for every local declared deeper than the innermost loop's scope
// (the compiler's own locals list is left untouched), then a backward loop to
// the innermost loop's recorded start offset is emitted.
void Compiler::compile_continue(const ContinueStmt& stmt) {
    if (loops.empty()) {
        report_error("Cannot use 'continue' outside of a loop.");
        return;
    }

    // Walk locals from innermost outwards, stopping at the loop's own scope.
    size_t remaining = locals.size();
    while (remaining > 0 && locals[remaining - 1].depth > loops.back().scope_depth) {
        emit_opcode(OpCode::OP_POP);
        --remaining;
    }

    emit_loop(loops.back().start);
}
void Compiler::compile_function_decl(const FunctionDecl& stmt) {
// Save current state
auto prev_chunk = current_chunk;
auto prev_locals = std::move(locals);
auto prev_scope_depth = scope_depth;
auto prev_loops = std::move(loops);
// Setup new state for function
current_chunk = std::make_shared<Chunk>();
locals.clear();
loops.clear();
scope_depth = 0;
// Add an empty local for the function itself (or 'this') at slot 0
add_local("this");
// Add parameters as locals at depth 0
for (const auto& param : stmt.parameters) {
add_local(param.second);
}
// Compile body
try {
compile_stmt(*stmt.body);
} catch (const CompileError&) {
// Error already reported
had_error = true;
}
// Ensure function returns
emit_opcode(OpCode::OP_NIL);
emit_opcode(OpCode::OP_RETURN);
auto func_chunk = current_chunk;
// Restore state
current_chunk = prev_chunk;
locals = std::move(prev_locals);
scope_depth = prev_scope_depth;
loops = std::move(prev_loops);
auto func_decl = std::make_shared<FunctionDecl>(stmt);
auto func = std::make_shared<FunctionValue>(stmt.name, func_decl, func_chunk);
emit_constant(func);
if (scope_depth == 0) {
emit_bytes(static_cast<uint8_t>(OpCode::OP_DEFINE_GLOBAL),
identifier_constant(stmt.name));
} else {
add_local(stmt.name);
}
}
// Compiles a class declaration.
//
// A ClassValue is built at compile time: VarDecl members are registered as
// fields, FunctionDecl members are compiled to their own chunks and
// registered as methods. Members of any other type are ignored. The finished
// class is emitted as a constant and always defined as a global, regardless
// of the current scope depth.
void Compiler::compile_class_decl(const ClassDecl& stmt) {
// Create class value with fields and methods
auto klass = std::make_shared<ClassValue>(stmt.name);
// Add fields and methods to the class
for (const auto& member : stmt.members) {
if (auto* var_decl = dynamic_cast<VarDecl*>(member.get())) {
// Field declaration
klass->add_field(var_decl->name, var_decl->type_name);
} else if (auto* func_decl = dynamic_cast<FunctionDecl*>(member.get())) {
// Method declaration - compile to bytecode
// Save the enclosing compiler state; it is restored once the method body is done.
auto prev_chunk = current_chunk;
auto prev_locals = std::move(locals);
auto prev_scope_depth = scope_depth;
auto prev_loops = std::move(loops);
// Fresh chunk, locals, loop stack and scope depth for the method body.
current_chunk = std::make_shared<Chunk>();
locals.clear();
loops.clear();
scope_depth = 0;
// Slot 0 is named "this"; parameters follow at depth 0.
add_local("this");
for (const auto& param : func_decl->parameters) {
add_local(param.second);
}
try {
compile_stmt(*func_decl->body);
} catch (const CompileError&) {
// Swallow the error so the saved state below is still restored; only the flag is set.
had_error = true;
}
// Implicit return value: a method named "init" returns local slot 0, any other method returns nil.
if (func_decl->name == "init") {
emit_bytes(static_cast<uint8_t>(OpCode::OP_GET_LOCAL), 0);
} else {
emit_opcode(OpCode::OP_NIL);
}
emit_opcode(OpCode::OP_RETURN);
auto method_chunk = current_chunk;
// Restore the enclosing compiler state.
current_chunk = prev_chunk;
locals = std::move(prev_locals);
scope_depth = prev_scope_depth;
loops = std::move(prev_loops);
// The FunctionValue holds a copy of the declaration together with the compiled chunk.
auto func_decl_ptr = std::make_shared<FunctionDecl>(*func_decl);
auto func = std::make_shared<FunctionValue>(func_decl->name, func_decl_ptr, method_chunk);
klass->add_method(func_decl->name, func);
}
}
// Push the class as a constant and define it as a global
emit_constant(klass);
emit_bytes(static_cast<uint8_t>(OpCode::OP_DEFINE_GLOBAL),
identifier_constant(stmt.name));
}
// Helper methods
// Appends one raw byte to the current chunk. Source line information is not
// tracked yet, so every byte is written with line 0.
void Compiler::emit_byte(uint8_t byte) {
    constexpr int kUntrackedLine = 0;
    current_chunk->write(byte, kUntrackedLine);
}
// Appends a single opcode to the current chunk as its one-byte encoding.
void Compiler::emit_opcode(OpCode op) {
    const auto encoded = static_cast<uint8_t>(op);
    emit_byte(encoded);
}
// Appends two bytes to the current chunk, byte1 first.
void Compiler::emit_bytes(uint8_t byte1, uint8_t byte2) {
    for (const uint8_t next : {byte1, byte2}) {
        emit_byte(next);
    }
}
// Adds `value` to the current chunk's constants and emits OP_CONSTANT with
// the resulting one-byte index. The constant is registered before any byte
// is written.
void Compiler::emit_constant(ValuePtr value) {
    const uint8_t slot = make_constant(value);
    emit_bytes(static_cast<uint8_t>(OpCode::OP_CONSTANT), slot);
}
// Emits a jump instruction with a two-byte placeholder operand (0xff 0xff).
//
// @return The chunk offset minus 2 after the placeholder has been written;
//         this is the value later handed to patch_jump().
size_t Compiler::emit_jump(OpCode op) {
    constexpr uint8_t kPlaceholder = 0xff;

    emit_opcode(op);
    emit_bytes(kPlaceholder, kPlaceholder);
    return current_chunk->current_offset() - 2;
}
void Compiler::patch_jump(size_t offset) {
current_chunk->patch_jump(offset);
}
void Compiler::emit_loop(size_t loop_start) {
emit_opcode(OpCode::OP_LOOP);
size_t offset = current_chunk->current_offset() - loop_start + 2;
if (offset > UINT16_MAX) {
report_error("Loop body too large.");
return;
}
emit_byte((offset >> 8) & 0xff);
emit_byte(offset & 0xff);
}
// Enters a nested scope by raising the scope depth by one.
void Compiler::begin_scope() {
    ++scope_depth;
}
// Leaves the current scope: lowers the scope depth by one, then for every
// local declared deeper than the new depth emits OP_POP and drops it from the
// locals list.
void Compiler::end_scope() {
    --scope_depth;

    for (; !locals.empty() && locals.back().depth > scope_depth; locals.pop_back()) {
        emit_opcode(OpCode::OP_POP);
    }
}
// Registers `name` as a local at the current scope depth. Once UINT8_MAX
// locals are already registered, an error is reported instead.
void Compiler::add_local(const std::string& name) {
    if (locals.size() < UINT8_MAX) {
        locals.push_back({name, scope_depth, false});
        return;
    }
    report_error("Too many local variables in scope.");
}
// Looks `name` up among the locals, searching from the most recently added
// one backwards so the newest declaration of a name wins.
//
// @return The index of the matching local, or -1 when there is none.
int Compiler::resolve_local(const std::string& name) {
    for (size_t remaining = locals.size(); remaining > 0; --remaining) {
        const size_t slot = remaining - 1;
        if (locals[slot].name == name) {
            return static_cast<int>(slot);
        }
    }
    return -1;
}
// Adds `value` to the current chunk's constants.
//
// @return The constant's index as a byte. If the index is larger than
//         UINT8_MAX an error is reported (the constant has already been added
//         to the chunk at that point).
uint8_t Compiler::make_constant(ValuePtr value) {
    const size_t index = current_chunk->add_constant(value);
    if (index <= UINT8_MAX) {
        return static_cast<uint8_t>(index);
    }
    report_error("Too many constants in one chunk.");
    return 0;
}
uint8_t Compiler::identifier_constant(const std::string& name) {
return make_constant(std::make_shared<StringValue>(name));
}
// Records a compile error and aborts the current compilation path.
//
// Stores `message` in error_message, sets had_error, then throws
// CompileError(message). This function never returns normally.
void Compiler::report_error(const std::string& message) {
    error_message = message;
    had_error = true;
    throw CompileError(message);
}
} // namespace camellya