Skip to content

Commit

Permalink
[Constexpr] reenabled constant expression evaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
isuckatcs committed Jun 28, 2024
1 parent cb5581e commit 462667e
Show file tree
Hide file tree
Showing 8 changed files with 377 additions and 64 deletions.
90 changes: 50 additions & 40 deletions include/ast.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,41 @@
#ifndef A_COMPILER_AST_H
#define A_COMPILER_AST_H

#include <cassert>
#include <iostream>
#include <memory>
#include <vector>

#include "lexer.h"
#include "utils.h"

namespace {
std::string_view dumpOp(TokenKind op) {
if (op == TokenKind::Plus)
return "+";
if (op == TokenKind::Minus)
return "-";
if (op == TokenKind::Asterisk)
return "*";
if (op == TokenKind::Slash)
return "/";
if (op == TokenKind::EqualEqual)
return "==";
if (op == TokenKind::AmpAmp)
return "&&";
if (op == TokenKind::PipePipe)
return "||";
if (op == TokenKind::Lt)
return "<";
if (op == TokenKind::Gt)
return ">";
if (op == TokenKind::Excl)
return "!";

assert(false && "unexpected operator");
}
} // namespace

struct Decl : public Dumpable {
SourceLocation location;
std::string identifier;
Expand Down Expand Up @@ -168,22 +196,8 @@ struct BinaryOperator : public Expr {
: Expr(location), lhs(std::move(lhs)), rhs(std::move(rhs)), op(op) {}

void dump(size_t level = 0) const override {
std::cerr << indent(level) << "BinaryOperator: '";
if (op == TokenKind::Plus)
std::cerr << '+';
if (op == TokenKind::Minus)
std::cerr << '-';
if (op == TokenKind::Asterisk)
std::cerr << '*';
if (op == TokenKind::Slash)
std::cerr << '/';
if (op == TokenKind::EqualEqual)
std::cerr << '=' << '=';
if (op == TokenKind::AmpAmp)
std::cerr << '&' << '&';
if (op == TokenKind::PipePipe)
std::cerr << '|' << '|';
std::cerr << '\'' << '\n';
std::cerr << indent(level) << "BinaryOperator: '" << dumpOp(op) << '\''
<< '\n';

lhs->dump(level + 1);
rhs->dump(level + 1);
Expand All @@ -199,10 +213,8 @@ struct UnaryOperator : public Expr {
: Expr(location), rhs(std::move(rhs)), op(op) {}

void dump(size_t level = 0) const override {
std::cerr << indent(level) << "UnaryOperator: '";
if (op == TokenKind::Excl)
std::cerr << '!';
std::cerr << '\'' << '\n';
std::cerr << indent(level) << "UnaryOperator: '" << dumpOp(op) << '\''
<< '\n';

rhs->dump(level + 1);
}
Expand Down Expand Up @@ -444,6 +456,8 @@ struct ResolvedNumberLiteral : public ResolvedExpr {

void dump(size_t level = 0) const override {
std::cerr << indent(level) << "NumberLiteral: '" << value << "'\n";
if (auto val = getConstantValue())
std::cerr << indent(level) << "| value: " << *val << '\n';
}
};

Expand All @@ -456,6 +470,8 @@ struct ResolvedDeclRefExpr : public ResolvedExpr {
void dump(size_t level = 0) const override {
std::cerr << indent(level) << "ResolvedDeclRefExpr: @(" << decl << ") "
<< decl->identifier << "\n";
if (auto val = getConstantValue())
std::cerr << indent(level) << "| value: " << *val << '\n';
}
};

Expand All @@ -471,6 +487,8 @@ struct ResolvedCallExpr : public ResolvedExpr {
void dump(size_t level = 0) const override {
std::cerr << indent(level) << "ResolvedCallExpr: @(" << callee << ") "
<< callee->identifier << "\n";
if (auto val = getConstantValue())
std::cerr << indent(level) << "| value: " << *val << '\n';

for (auto &&arg : arguments)
arg->dump(level + 1);
Expand All @@ -486,6 +504,8 @@ struct ResolvedGroupingExpr : public ResolvedExpr {

void dump(size_t level = 0) const override {
std::cerr << indent(level) << "ResolvedGroupingExpr:\n";
if (auto val = getConstantValue())
std::cerr << indent(level) << "| value: " << *val << '\n';

expr->dump(level + 1);
}
Expand All @@ -503,22 +523,11 @@ struct ResolvedBinaryOperator : public ResolvedExpr {
rhs(std::move(rhs)), op(op) {}

void dump(size_t level = 0) const override {
std::cerr << indent(level) << "ResolvedBinaryOperator: '";
if (op == TokenKind::Plus)
std::cerr << '+';
if (op == TokenKind::Minus)
std::cerr << '-';
if (op == TokenKind::Asterisk)
std::cerr << '*';
if (op == TokenKind::Slash)
std::cerr << '/';
if (op == TokenKind::EqualEqual)
std::cerr << '=' << '=';
if (op == TokenKind::AmpAmp)
std::cerr << '&' << '&';
if (op == TokenKind::PipePipe)
std::cerr << '|' << '|';
std::cerr << '\'' << '\n';
std::cerr << indent(level) << "ResolvedBinaryOperator: '" << dumpOp(op)
<< '\'' << '\n';

if (auto val = getConstantValue())
std::cerr << indent(level) << "| value: " << *val << '\n';

lhs->dump(level + 1);
rhs->dump(level + 1);
Expand All @@ -534,10 +543,11 @@ struct ResolvedUnaryOperator : public ResolvedExpr {
: ResolvedExpr(location, rhs->type), rhs(std::move(rhs)), op(op) {}

void dump(size_t level = 0) const override {
std::cerr << indent(level) << "ResolvedUnaryOperator: '";
if (op == TokenKind::Excl)
std::cerr << '!';
std::cerr << '\'' << '\n';
std::cerr << indent(level) << "ResolvedUnaryOperator: '" << dumpOp(op)
<< '\'' << '\n';

if (auto val = getConstantValue())
std::cerr << indent(level) << "| value: " << *val << '\n';

rhs->dump(level + 1);
}
Expand Down
2 changes: 2 additions & 0 deletions include/constexpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
class ConstantExpressionEvaluator {
std::optional<double>
evaluateBinaryOperator(const ResolvedBinaryOperator &binop);
std::optional<double> evaluateUnaryOperator(const ResolvedUnaryOperator &op);
std::optional<double> evaluateDeclRefExpr(const ResolvedDeclRefExpr &dre);

public:
std::optional<double> evaluate(const ResolvedExpr &expr);
Expand Down
64 changes: 58 additions & 6 deletions src/constexpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,78 @@
#include <cassert>
#include <optional>

namespace {
bool toBool(double d) { return d != 0.0; }
} // namespace

std::optional<double> ConstantExpressionEvaluator::evaluateBinaryOperator(
const ResolvedBinaryOperator &binop) {
std::optional<double> lhs = evaluate(*binop.lhs);
if (!lhs)
return std::nullopt;

// If the LHS of || is true, we don't need to evaluate the RHS.
if (binop.op == TokenKind::PipePipe && toBool(*lhs))
return 1.0;

// If the LHS of && if false, we don't need to evaluate the RHS.
if (binop.op == TokenKind::AmpAmp && !toBool(*lhs))
return 0.0;

std::optional<double> rhs = evaluate(*binop.rhs);
if (!rhs)
return std::nullopt;

switch (binop.op) {
case TokenKind::Plus:
return *lhs + *rhs;
case TokenKind::Minus:
return *lhs - *rhs;
case TokenKind::Asterisk:
return *lhs * *rhs;
case TokenKind::Slash:
return *lhs / *rhs;
case TokenKind::Plus:
return *lhs + *rhs;
case TokenKind::Minus:
return *lhs - *rhs;
case TokenKind::Lt:
return *lhs < *rhs;
case TokenKind::Gt:
return *lhs > *rhs;
case TokenKind::EqualEqual:
return *lhs == *rhs;
case TokenKind::AmpAmp:
case TokenKind::PipePipe:
return toBool(*rhs); // The LHS is already handled.
default:
assert(false && "unexpected binary operator");
}
}

std::optional<double> ConstantExpressionEvaluator::evaluateUnaryOperator(
const ResolvedUnaryOperator &op) {
std::optional<double> rhs = evaluate(*op.rhs);
if (!rhs)
return std::nullopt;

if (op.op == TokenKind::Excl)
return !toBool(*rhs);

assert(false && "unexpected unary operator");
}

std::optional<double> ConstantExpressionEvaluator::evaluateDeclRefExpr(
const ResolvedDeclRefExpr &dre) {
// We only care about reference to immutable variables with an initializer.
const auto *rvd = dynamic_cast<const ResolvedVarDecl *>(dre.decl);
if (!rvd || rvd->isMutable || !rvd->initializer)
return std::nullopt;

return evaluate(*rvd->initializer);
}

std::optional<double>
ConstantExpressionEvaluator::evaluate(const ResolvedExpr &expr) {
// FIXME: reenable this
return std::nullopt;
// Don't evaluate the same expression multiple times.
if (std::optional<double> val = expr.getConstantValue())
return val;

if (const auto *numberLiteral =
dynamic_cast<const ResolvedNumberLiteral *>(&expr))
Expand All @@ -44,5 +88,13 @@ ConstantExpressionEvaluator::evaluate(const ResolvedExpr &expr) {
dynamic_cast<const ResolvedBinaryOperator *>(&expr))
return evaluateBinaryOperator(*binaryOperator);

if (const auto *unaryOperator =
dynamic_cast<const ResolvedUnaryOperator *>(&expr))
return evaluateUnaryOperator(*unaryOperator);

if (const auto *declRefExpr =
dynamic_cast<const ResolvedDeclRefExpr *>(&expr))
return evaluateDeclRefExpr(*declRefExpr);

return std::nullopt;
}
4 changes: 2 additions & 2 deletions src/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ std::unique_ptr<VarDecl> TheParser::parseVarDecl(bool isLet) {
varOrReturn(type, parseType());

if (nextToken.kind != TokenKind::Equal)
return std::make_unique<VarDecl>(location, identifier, *type, isLet);
return std::make_unique<VarDecl>(location, identifier, *type, !isLet);
eatNextToken(); // eat '='

varOrReturn(initializer, parseExpr());

return std::make_unique<VarDecl>(location, identifier, *type, isLet,
return std::make_unique<VarDecl>(location, identifier, *type, !isLet,
std::move(initializer));
}

Expand Down
40 changes: 24 additions & 16 deletions src/sema.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,26 +249,34 @@ Sema::resolveReturnStmt(const ReturnStmt &returnStmt) {
}

std::unique_ptr<ResolvedExpr> Sema::resolveExpr(const Expr &expr) {
if (const auto *numberLiteral = dynamic_cast<const NumberLiteral *>(&expr))
return std::make_unique<ResolvedNumberLiteral>(
numberLiteral->location, std::stod(numberLiteral->value));

if (const auto *declRefExpr = dynamic_cast<const DeclRefExpr *>(&expr))
return resolveDeclRefExpr(*declRefExpr);

if (const auto *callExpr = dynamic_cast<const CallExpr *>(&expr))
return resolveCallExpr(*callExpr);

if (const auto *groupingExpr = dynamic_cast<const GroupingExpr *>(&expr))
return resolveGroupingExpr(*groupingExpr);
std::unique_ptr<ResolvedExpr> resolvedExpr = nullptr;

if (const auto *binaryOperator = dynamic_cast<const BinaryOperator *>(&expr))
return resolveBinaryOperator(*binaryOperator);
if (const auto *numberLiteral = dynamic_cast<const NumberLiteral *>(&expr))
resolvedExpr = std::make_unique<ResolvedNumberLiteral>(
numberLiteral->location, std::stod(numberLiteral->value));
else if (const auto *declRefExpr = dynamic_cast<const DeclRefExpr *>(&expr))
resolvedExpr = resolveDeclRefExpr(*declRefExpr);
else if (const auto *callExpr = dynamic_cast<const CallExpr *>(&expr))
resolvedExpr = resolveCallExpr(*callExpr);
else if (const auto *groupingExpr = dynamic_cast<const GroupingExpr *>(&expr))
resolvedExpr = resolveGroupingExpr(*groupingExpr);
else if (const auto *binaryOperator =
dynamic_cast<const BinaryOperator *>(&expr))
resolvedExpr = resolveBinaryOperator(*binaryOperator);
else if (const auto *unaryOperator =
dynamic_cast<const UnaryOperator *>(&expr))
resolvedExpr = resolveUnaryOperator(*unaryOperator);
else
assert(false && "unexpected expression");

if (!resolvedExpr)
return nullptr;

if (const auto *unaryOperator = dynamic_cast<const UnaryOperator *>(&expr))
return resolveUnaryOperator(*unaryOperator);
if (std::optional<double> val = cee.evaluate(*resolvedExpr))
resolvedExpr->setConstantValue(val);

return nullptr;
return resolvedExpr;
}

std::unique_ptr<ResolvedBlock> Sema::resolveBlock(const Block &block) {
Expand Down
46 changes: 46 additions & 0 deletions test/constexpr/grouping.al
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// RUN: compiler %s -res-dump 2>&1 | filecheck %s
fn foo(): number {
let x: number = 2.1;
let y: number = 5.3;

return (10.0 * (x + 4.0)) && (!(y == x) || x < y);
}
// CHECK: ResolvedReturnStmt
// CHECK-NEXT: ResolvedBinaryOperator: '&&'
// CHECK-NEXT: | value: 1
// CHECK-NEXT: ResolvedGroupingExpr:
// CHECK-NEXT: | value: 61
// CHECK-NEXT: ResolvedBinaryOperator: '*'
// CHECK-NEXT: | value: 61
// CHECK-NEXT: NumberLiteral: '10'
// CHECK-NEXT: | value: 10
// CHECK-NEXT: ResolvedGroupingExpr:
// CHECK-NEXT: | value: 6.1
// CHECK-NEXT: ResolvedBinaryOperator: '+'
// CHECK-NEXT: | value: 6.1
// CHECK-NEXT: ResolvedDeclRefExpr: @({{.*}}) x
// CHECK-NEXT: | value: 2.1
// CHECK-NEXT: NumberLiteral: '4'
// CHECK-NEXT: | value: 4
// CHECK-NEXT: ResolvedGroupingExpr:
// CHECK-NEXT: | value: 1
// CHECK-NEXT: ResolvedBinaryOperator: '||'
// CHECK-NEXT: | value: 1
// CHECK-NEXT: ResolvedUnaryOperator: '!'
// CHECK-NEXT: | value: 1
// CHECK-NEXT: ResolvedGroupingExpr:
// CHECK-NEXT: | value: 0
// CHECK-NEXT: ResolvedBinaryOperator: '=='
// CHECK-NEXT: | value: 0
// CHECK-NEXT: ResolvedDeclRefExpr: @({{.*}}) y
// CHECK-NEXT: | value: 5.3
// CHECK-NEXT: ResolvedDeclRefExpr: @({{.*}}) x
// CHECK-NEXT: | value: 2.1
// CHECK-NEXT: ResolvedBinaryOperator: '<'
// CHECK-NEXT: | value: 1
// CHECK-NEXT: ResolvedDeclRefExpr: @({{.*}}) x
// CHECK-NEXT: | value: 2.1
// CHECK-NEXT: ResolvedDeclRefExpr: @({{.*}}) y
// CHECK-NEXT: | value: 5.3

fn main(): void {}
Loading

0 comments on commit 462667e

Please sign in to comment.