Alkahest

Alkahest is a high-performance computer algebra system with a Rust core and Python API.

What it is

A general-purpose symbolic math library designed around three axes:

Performance. The Rust kernel uses hash-consed directed acyclic graphs so structural equality is a pointer comparison and subexpression sharing is automatic. FLINT backs polynomial arithmetic. An LLVM JIT compiles symbolic expressions to native or GPU code at runtime. Common operations run orders of magnitude faster than SymPy.

Correctness. Every simplification and transformation produces a derivation log — the exact sequence of rewrite rules applied, with arguments and side conditions. A subset of operations can export Lean 4 proof terms verifiable by an independent checker.

Ergonomics. The Python API uses operator overloading for natural expression construction. Results are rich objects with .value, .steps, and .certificate attributes. Error messages carry structured codes, location information, and suggested remediations.

Design principles

Explicit representations. The type system distinguishes UniPoly (FLINT-backed univariate polynomial), MultiPoly (sparse multivariate), RationalFunction, and the generic Expr tree. Converting between them is an explicit call. There are no silent representation changes hiding performance cliffs.

Stateless by design. No global assumption contexts. No hidden caches that change behavior. All context (domains, simplification policy, precision) is passed explicitly or bundled into expression structure. This makes results deterministic and parallelism safe.

Composable transformations. trace, grad, jit, and certify operate on a shared traced representation and stack freely: jit(grad(f)) compiles a derivative, jit(grad(grad(f))) compiles a second derivative.

A small primitive set. Each primitive (sin, exp, add, mul, …) registers a full bundle: simplification rule, forward- and reverse-mode differentiation, MLIR lowering, Lean theorem tag, numerical evaluation. New operations are added by registering a primitive, not by adding code paths across the system.

Compared to alternatives

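| | SymPy | SageMath | Symbolics.jl | Alkahest |
|---|---|---|---|---|
| Performance | Slow | Moderate | Fast | Fast |
| GPU codegen | No | No | No | Yes |
| Lean proofs | No | No | No | Yes |
| Python API | Yes | Yes | No (Julia) | Yes |
| Open source | Yes | Yes | Yes | Yes |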

This guide

The guide covers the Rust-level design concepts. For the Python API reference see the Sphinx docs alongside this guide.

Getting started

Install

Alkahest is built with maturin.

pip install maturin
git clone https://github.com/alkahest/alkahest
cd alkahest
maturin develop --release

To enable optional features:

# LLVM JIT for native compiled evaluation
maturin develop --release --features jit

# E-graph simplification (egglog)
maturin develop --release --features egraph

# Parallel simplification (sharded ExprPool)
maturin develop --release --features parallel

# Gröbner basis solver
maturin develop --release --features groebner

# CUDA / NVPTX codegen (requires CUDA toolkit and LLVM with NVPTX target)
maturin develop --release --features cuda

# Full build (all optional features except cuda/rocm)
maturin develop --release --features "jit egraph parallel groebner"

First steps

Every computation starts with an ExprPool. It owns all expressions; you create symbols and integers from it.

import alkahest
from alkahest import ExprPool, diff, simplify, integrate, sin, exp, cos

pool = ExprPool()
x = pool.symbol("x")
y = pool.symbol("y")

Building expressions

Python operators build expression trees:

expr = x**2 + pool.integer(2) * x + pool.integer(1)
print(expr)  # x^2 + 2*x + 1

Math functions accept expressions:

f = sin(x**2) + exp(x * y)

Parsing expressions from strings

Use parse when the expression comes from user input or a config file:

from alkahest import parse

e = parse("x^2 + 2*x + 1", pool, {"x": x})
print(e)   # x^2 + 2*x + 1

Identifiers not in the symbols dict are auto-created as symbols in pool. Both ^ and ** denote exponentiation. See Parsing from strings for the full syntax reference.

Simplification

r = simplify(x + pool.integer(0))
print(r.value)  # x
print(r.steps)  # [RewriteStep(rule='add_zero', ...)]

Differentiation

dr = diff(sin(x**2), x)
print(dr.value)  # 2*x*cos(x^2)

Integration

r = integrate(exp(x), x)
print(r.value)   # exp(x)

r = integrate(sin(x), x)
print(r.value)   # -cos(x)

Polynomial arithmetic

from alkahest import UniPoly, RationalFunction

# Convert to FLINT-backed univariate polynomial
p = UniPoly.from_symbolic(x**3 + pool.integer(-1), x)
q = UniPoly.from_symbolic(x + pool.integer(-1), x)
print(p.gcd(q))          # x - 1
print(p // q)            # x^2 + x + 1

Compiled evaluation

from alkahest import compile_expr, eval_expr

# Scalar evaluation via a dict binding
result = eval_expr(x**2 + y, {x: 3.0, y: 1.0})
print(result)  # 10.0

# JIT-compiled callable
f = compile_expr(x**2 + pool.integer(1), [x])
print(f([3.0]))  # 10.0

Vectorized evaluation over NumPy arrays

import numpy as np
from alkahest import compile_expr, numpy_eval

f = compile_expr(sin(x) * exp(pool.integer(-1) * x), [x])
xs = np.linspace(0, 10, 1_000_000)
ys = numpy_eval(f, xs)  # vectorised; much faster than a Python loop

Context manager

with alkahest.context(pool=pool, simplify=True):
    z = alkahest.symbol("z")  # uses the active pool
    expr = z**2 + alkahest.sin(z)

Running the examples

The examples/ directory has runnable end-to-end scripts:

PYTHONPATH=python python examples/calculus.py
PYTHONPATH=python python examples/polynomials.py
PYTHONPATH=python python examples/jit_eval.py
PYTHONPATH=python python examples/ball_arithmetic.py
PYTHONPATH=python python examples/ode_modeling.py

Kernel design

The expression kernel is the foundation everything else builds on. It lives in alkahest-core/src/kernel/.

Hash-consed DAG

Every expression is represented as a directed acyclic graph stored in an ExprPool. Nodes are interned: before inserting a new node, the pool checks whether a structurally identical node already exists. If it does, the existing ExprId is returned instead of allocating a new node.

This gives three properties:

  1. Structural equality is a pointer comparison. id_a == id_b iff the expressions are structurally identical. No tree traversal required.
  2. Automatic subexpression sharing. If sin(x²) appears in ten different expressions, there is only one sin(x²) node in memory.
  3. Hash-based memoization is cheap. Caching the result of a transformation keyed by ExprId is O(1) and correct.
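
The three properties can be illustrated with a toy intern table in plain Python. This is only a sketch of the hash-consing idea: the class ToyPool and its method names are hypothetical and are not the Alkahest kernel or its API.

```python
class ToyPool:
    """Toy intern table: maps node data to a small integer id."""

    def __init__(self):
        self._ids = {}    # node data -> id (the intern table)
        self._nodes = []  # id -> node data (the arena)

    def intern(self, *data):
        if data in self._ids:            # hit: reuse the existing id
            return self._ids[data]
        self._nodes.append(data)         # miss: allocate a new node
        self._ids[data] = len(self._nodes) - 1
        return self._ids[data]

    def __len__(self):
        return len(self._nodes)


toy = ToyPool()
x = toy.intern("symbol", "x")
two = toy.intern("int", 2)

# Build sin(x^2) twice; children are referenced by id, not by value.
sin_a = toy.intern("call", "sin", toy.intern("pow", x, two))
size_after_first = len(toy)
sin_b = toy.intern("call", "sin", toy.intern("pow", x, two))

print(sin_a == sin_b)                 # True: equality is an integer comparison
print(len(toy) == size_after_first)   # True: the rebuild allocated nothing

# Memoization keyed by id: a plain dict lookup, no tree traversal.
cache = {sin_a: "cached result"}
print(cache[sin_b])                   # cached result
```

Because children are stored as ids, hashing a node costs time proportional to its own arity, not to the size of the subtree beneath it.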

ExprPool

ExprPool is the intern table. It owns all expressions in a session.

pool = ExprPool()
x = pool.symbol("x")       # intern a Symbol node
n = pool.integer(42)       # intern an Integer node

Multiple pools are independent. An ExprId from one pool must not be mixed into another — the pool validates this in debug builds.

Persistent pool (V1-14). A pool can be serialized to disk and reopened, preserving all ExprIds across sessions:

pool.save_to("session.alkp")
pool2 = ExprPool.load_from("session.alkp")

Sharded pool. With --features parallel, the intern table uses a sharded concurrent hashmap (DashMap), so multiple threads can insert expressions with little lock contention.

ExprData variants

Each interned node is one of:

| Variant | Description |
|---|---|
| Symbol(name, domain) | Named variable with a domain annotation |
| Integer(n) | Exact arbitrary-precision integer |
| Rational(p, q) | Exact rational number |
| Add(children) | N-ary addition |
| Mul(children) | N-ary multiplication |
| Pow(base, exp) | Exponentiation |
| Call(primitive, args) | Application of a registered primitive |
| Piecewise(cases) | Conditional expression |
| Predicate(kind, args) | Boolean condition (inequality, equality) |

Add and Mul are n-ary: a + b + c is one Add node with three children, not two nested Add nodes. Children are sorted at construction time so that commutativity is structural — a + b and b + a produce the same interned node.
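
A minimal sketch of that construction-time normalization, using plain tuples instead of interned nodes. The helper make_add is hypothetical, and sorting by repr is an arbitrary stand-in for whatever ordering the kernel actually uses.

```python
def make_add(*terms):
    """Build a canonical n-ary Add node from terms (leaves or nested nodes)."""
    flat = []
    for term in terms:
        if isinstance(term, tuple) and term[0] == "add":
            flat.extend(term[1:])   # splice nested Add children into this node
        else:
            flat.append(term)
    # A fixed child order makes commutativity a structural property.
    return ("add", *sorted(flat, key=repr))


print(make_add("b", "a"))                    # ('add', 'a', 'b')
print(make_add("a", "b") == make_add("b", "a"))   # True
print(make_add(make_add("a", "b"), "c"))     # ('add', 'a', 'b', 'c')
```

Combined with interning, two sums that differ only in grouping or operand order end up as the same node.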

Domains

Every symbol carries a domain as part of its structural identity:

x_real = pool.symbol("x", "real")
x_complex = pool.symbol("x", "complex")
# x_real and x_complex are distinct expressions — different ExprIds

The domain is not a global assumption; it is part of what the symbol is. Simplification rules can query a symbol’s domain to decide whether a rewrite is valid (e.g. sqrt(x²) → x requires x to be non-negative).

Available domains: real, positive, nonnegative, integer, complex. The default when no domain is specified is real.

ExprId and memory

ExprId is a 32-bit index into the pool’s internal arena. It is Copy, Send, and Sync. Cloning an ExprId is free. No reference counting is needed because the pool owns all nodes; expressions are not freed until the pool is dropped.

The kernel is designed with parallelism as a first-class property. All kernel types are Send + Sync. The simplification and differentiation passes can run concurrently on disjoint ExprIds from the same pool.

Interning cost model

Interning a new node requires:

  1. Hash the ExprData.
  2. Look up in the concurrent hash map.
  3. On miss: allocate the node in the arena and insert into the map.
  4. On hit: return the existing ExprId.

Step 4 (the common case in a running computation) is a single hash lookup plus a pointer load. The arena uses bump allocation, so step 3 is also fast.

The memory benchmark group in alkahest-core/benches/alkahest_bench.rs verifies that rebuilding an identical expression tree does not grow the pool.

Expression representations

Alkahest exposes multiple representation types rather than hiding everything behind a single Expr. This is a deliberate design decision: the representation is visible, conversions are explicit, and performance characteristics are predictable.

Choosing a representation

| If you need… | Use |
|---|---|
| General symbolic computation | Expr |
| Fast univariate polynomial arithmetic | UniPoly |
| Sparse multivariate polynomial algebra | MultiPoly |
| Rational functions with automatic cancellation | RationalFunction |
| Rigorous enclosures with error bounds | ArbBall |

Conversion to a specialized type is always an explicit opt-in:

expr = x**3 + pool.integer(-2) * x + pool.integer(1)
p = UniPoly.from_symbolic(expr, x)   # explicit conversion

If the expression cannot be represented in the target type (e.g. sin(x) as a polynomial), a ConversionError is raised with a remediation hint.

Expr

The generic symbolic expression. All other types convert to and from Expr. Built by operator overloading on the Python side:

expr = x**2 + pool.integer(3) * x * y - pool.integer(1)

Operations like diff, simplify, and integrate work on Expr and return DerivedResult objects wrapping an Expr.

UniPoly

Dense univariate polynomial backed by FLINT. Coefficients are exact integers or rationals stored in a FLINT polynomial object.

from alkahest import UniPoly

# x^3 - 2x + 1
p = UniPoly.from_symbolic(x**3 + pool.integer(-2) * x + pool.integer(1), x)

print(p.degree())        # 3
print(p.coefficients())  # [1, -2, 0, 1]  (constant first)
print(p.leading_coeff()) # 1

# Arithmetic — all FLINT-backed, exact
q = UniPoly.from_symbolic(x + pool.integer(-1), x)
print(p * q)             # x^4 - x^3 - 2x^2 + 3x - 1
print(p.gcd(q))          # x - 1
print(p // q)            # x^2 + x - 1
print(p % q)             # 0

# Powers
r = UniPoly.from_symbolic(x + pool.integer(1), x)
print(r ** 3)            # x^3 + 3x^2 + 3x + 1

UniPoly is the right choice when you are doing heavy univariate polynomial arithmetic (GCD chains, resultants, factorization) because FLINT applies highly optimized algorithms with exact arithmetic.

MultiPoly

Sparse multivariate polynomial over ℤ (integers). Terms are stored as a map from exponent vectors to coefficients.

from alkahest import MultiPoly

# x^2*y + x*y^2 - 1
expr = x**2 * y + x * y**2 + pool.integer(-1)
mp = MultiPoly.from_symbolic(expr, [x, y])

print(mp.total_degree())     # 3
print(mp.integer_content())  # 1

# Arithmetic
mp2 = MultiPoly.from_symbolic(x * y, [x, y])
print(mp + mp2)              # x^2*y + x*y^2 + x*y - 1
print(mp * mp2)              # x^3*y^2 + x^2*y^3 - x*y

Variable order matters for the exponent-vector key. Pass variables in a consistent order when constructing MultiPoly objects that will be combined.
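
To see why the order matters, here is a toy version of the exponent-vector representation using a plain dict. It is an illustration only, assuming variable order (x, y); poly_mul is a hypothetical helper, not part of the MultiPoly API.

```python
from collections import defaultdict


def poly_mul(p, q):
    """Multiply two sparse polynomials stored as {exponent_vector: coeff}."""
    out = defaultdict(int)
    for exps_p, coeff_p in p.items():
        for exps_q, coeff_q in q.items():
            exps = tuple(i + j for i, j in zip(exps_p, exps_q))
            out[exps] += coeff_p * coeff_q
    return {e: c for e, c in out.items() if c != 0}   # drop cancelled terms


# Variable order (x, y): the key (2, 1) means x^2 * y^1.
mp = {(2, 1): 1, (1, 2): 1, (0, 0): -1}   # x^2*y + x*y^2 - 1
mp2 = {(1, 1): 1}                          # x*y

product = poly_mul(mp, mp2)
print(product)   # {(3, 2): 1, (2, 3): 1, (1, 1): -1}

# The same monomial x^2*y has a different key under order (y, x),
# so mixing orders would silently combine the wrong terms.
print((2, 1) == (1, 2))   # False
```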

RationalFunction

Quotient of two MultiPoly objects, automatically reduced by their GCD.

from alkahest import RationalFunction

# (x^2 - 1) / (x - 1) → normalized to x + 1
numer = x**2 + pool.integer(-1)
denom = x + pool.integer(-1)
rf = RationalFunction.from_symbolic(numer, denom, [x])
print(rf)   # x + 1

# Arithmetic preserves the rational form
rf_x = RationalFunction.from_symbolic(x, pool.integer(1), [x])
rf_inv = RationalFunction.from_symbolic(pool.integer(1), x, [x])
print(rf_x + rf_inv)   # (x^2 + 1) / x

GCD normalization runs at construction, so every RationalFunction is in lowest terms.

ArbBall

A real interval [midpoint ± radius] backed by FLINT’s Arb library. Arithmetic on ArbBall values produces guaranteed enclosures of the true result.

from alkahest import ArbBall, interval_eval, sin

# ArbBall(midpoint, radius, precision_bits=53)
a = ArbBall(2.0, 0.5)    # [1.5, 2.5]
b = ArbBall(3.0, 0.0)    # exactly 3

print(a + b)   # [4.5, 5.5]
print(a * b)   # [4.5, 7.5]

# Evaluate a symbolic expression rigorously
pool = ExprPool()
x = pool.symbol("x")
result = interval_eval(sin(x), {x: ArbBall(1.0, 1e-10)})
print(result.lo, result.hi)   # tight enclosure of sin(1)

The output ball is guaranteed to contain the true value for any input in the input balls. This is useful for:

  • Certified numerical evaluation
  • Proving bounds on symbolic expressions
  • Verification workflows alongside Lean certificate export

See Ball arithmetic for more detail.
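
The midpoint-radius rules behind the a + b and a * b results above can be sketched in a few lines. ToyBall is a hypothetical class for illustration; unlike Arb, it ignores floating-point rounding, so it is not rigorous for inputs that are not exactly representable.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyBall:
    mid: float
    rad: float

    def __add__(self, other):
        # Radii add: the worst-case errors accumulate.
        return ToyBall(self.mid + other.mid, self.rad + other.rad)

    def __mul__(self, other):
        # |(m1 + e1)(m2 + e2) - m1*m2| <= |m1|*r2 + |m2|*r1 + r1*r2
        rad = (abs(self.mid) * other.rad
               + abs(other.mid) * self.rad
               + self.rad * other.rad)
        return ToyBall(self.mid * other.mid, rad)

    @property
    def lo(self):
        return self.mid - self.rad

    @property
    def hi(self):
        return self.mid + self.rad


a = ToyBall(2.0, 0.5)   # [1.5, 2.5]
b = ToyBall(3.0, 0.0)   # exactly 3

total, product = a + b, a * b
print(total.lo, total.hi)       # 4.5 5.5
print(product.lo, product.hi)   # 4.5 7.5
```

A real ball-arithmetic library additionally rounds the radius outward after every operation, which is what makes the enclosure a guarantee rather than an estimate.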

Converting back to Expr

All specialized types can be converted back to a generic Expr for further symbolic manipulation:

p = UniPoly.from_symbolic(x**2 + pool.integer(1), x)
expr_again = p.to_symbolic(pool)
dr = diff(expr_again, x)

Parsing expressions from strings

alkahest.parse converts a human-readable math string into an Expr node using a Pratt (top-down operator precedence) recursive-descent parser.

import alkahest
from alkahest import ExprPool, parse, diff, simplify

pool = ExprPool()
x = pool.symbol("x")

e = parse("x^2 + 2*x + 1", pool, {"x": x})
print(e)                    # x^2 + 2*x + 1

dr = diff(e, x)
print(dr.value)             # 2*x + 2

Syntax

| Form | Meaning |
|---|---|
| 42, 3.14, 1.5e-3 | Integer or float literal |
| x, alpha, x_1 | Symbol (created in pool on first use) |
| a + b, a - b | Addition / subtraction |
| a * b, a / b | Multiplication / division |
| a ^ b, a ** b | Exponentiation (right-associative) |
| -a, +a | Unary negation / identity |
| (expr) | Grouping |
| sin(x), atan2(y, x) | Function call (one or two arguments) |

Whitespace (spaces, tabs, newlines) is ignored everywhere.

Operator precedence

From lowest to highest:

| Level | Operators |
|---|---|
| 10 | + - (infix) |
| 20 | * / |
| 25 | Unary - + |
| 30 | ^ ** (right-associative) |

So -x^2 parses as -(x^2), not (-x)^2, and x^2^3 parses as x^(2^3) = x^8.
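
The way these levels produce that behavior can be shown with a small Pratt parser. The sketch below is not Alkahest's parser: it handles only numbers, operators, and parentheses, and it evaluates to a float instead of building an Expr. The binding powers are taken from the table above; everything else (names, tokenizer) is an assumption made for illustration.

```python
import re

TOKEN = re.compile(r"\d+(?:\.\d+)?|\*\*|[-+*/^()]")

# Binding powers mirror the precedence table: higher binds tighter.
BINARY = {
    "+": (10, "left"), "-": (10, "left"),
    "*": (20, "left"), "/": (20, "left"),
    "^": (30, "right"), "**": (30, "right"),
}
UNARY_POWER = 25
APPLY = {
    "+": lambda a, b: a + b, "-": lambda a, b: a - b,
    "*": lambda a, b: a * b, "/": lambda a, b: a / b,
    "^": lambda a, b: a ** b, "**": lambda a, b: a ** b,
}


def tokenize(text):
    tokens, pos = [], 0
    while pos < len(text):
        if text[pos].isspace():
            pos += 1
            continue
        m = TOKEN.match(text, pos)
        if m is None:
            raise SyntaxError(f"unexpected character {text[pos]!r} at offset {pos}")
        tokens.append(m.group())
        pos = m.end()
    return tokens


def parse_expr(tokens, min_power=0):
    if not tokens:
        raise SyntaxError("unexpected end of input")
    tok = tokens.pop(0)
    if tok in ("+", "-"):                      # prefix operator
        operand = parse_expr(tokens, UNARY_POWER)
        left = -operand if tok == "-" else operand
    elif tok == "(":
        left = parse_expr(tokens, 0)
        if not tokens or tokens.pop(0) != ")":
            raise SyntaxError("expected ')'")
    else:
        left = float(tok)
    while tokens and tokens[0] in BINARY:
        power, assoc = BINARY[tokens[0]]
        if power < min_power:
            break
        op = tokens.pop(0)
        # Right-associative operators let the same level recurse on the right.
        right = parse_expr(tokens, power if assoc == "right" else power + 1)
        left = APPLY[op](left, right)
    return left


def evaluate(text):
    tokens = tokenize(text)
    value = parse_expr(tokens)
    if tokens:
        raise SyntaxError(f"unexpected token {tokens[0]!r}")
    return value


print(evaluate("-3^2"))     # -9.0, i.e. -(3^2)
print(evaluate("(-3)^2"))   # 9.0
print(evaluate("2^2^3"))    # 256.0, i.e. 2^(2^3)
```

Unary minus sits at level 25, below ^ at 30, so the operand of the minus sign absorbs the whole power before negation is applied.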

Supported functions

abs, acos, asin, atan, atan2, ceil, cos, cosh, erf, erfc, exp, floor, gamma, log, round, sign, sin, sinh, sqrt, tan, tanh

The symbols map

By default, every new identifier is interned as a fresh pool.symbol(name). Pass a pre-built symbols dict to bind identifiers to existing Expr objects, or to collect the symbols that were created:

# Pre-bind x to an existing symbol
x = pool.symbol("x")
e = parse("sin(x)^2 + cos(x)^2", pool, {"x": x})

# Collect auto-created symbols after parsing
sym_map: dict = {}
e = parse("a*x^2 + b*x + c", pool, sym_map)
print(sym_map.keys())   # dict_keys(['a', 'x', 'b', 'c'])

Identifiers not in the map are created and then added to the map, so the same string name always resolves to the same Expr within a single parse call.

Error handling

parse raises ParseError (code E-PARSE-001) on any lexical or syntax error. The exception’s .span attribute gives the (start, end) byte range of the offending token, and .remediation provides a hint:

from alkahest import ParseError

try:
    parse("sin(x) @ 2", pool, {"x": x})
except ParseError as e:
    print(e)           # unexpected character '@' at offset 7
    print(e.span)      # (7, 8)

try:
    parse("zeta(x)", pool, {"x": x})
except ParseError as e:
    print(e.remediation)  # known functions: abs, acos, asin, ...

Round-trip with pretty-printing

parse is the inverse of str() for expressions built from the operators and functions listed above:

from alkahest import latex, unicode_str

e = parse("sin(x)^2 + cos(x)^2", pool, {"x": x})
print(latex(e))        # \sin\!\left(x\right)^{2} + \cos\!\left(x\right)^{2}
print(unicode_str(e))  # sin(x)² + cos(x)²

Simplification

Alkahest provides two complementary simplification engines that operate on the same expression pool.

Rule-based simplification

simplify applies a fixed set of algebraic rewrite rules until no more apply (fixpoint). It is fast, predictable, and always terminates.

from alkahest import simplify

r = simplify(x + pool.integer(0))   # → x
r = simplify(x * pool.integer(1))   # → x
r = simplify(pool.integer(2) * pool.integer(3))  # → 6  (constant folding)

The default rule set covers:

  • Identity and absorbing elements (x + 0 → x, x * 1 → x, x * 0 → 0)
  • Constant folding (integer and rational arithmetic)
  • Basic polynomial simplification (x + x → 2*x, x² * x → x³)
  • Commutativity and associativity (normalized at construction)

Domain-specific rule sets

from alkahest import simplify_trig, simplify_log_exp, simplify_expanded

# Pythagorean identity and double-angle formulas
r = simplify_trig(sin(x)**2 + cos(x)**2)  # → 1

# Logarithm and exponential cancellation (branch-cut safe)
r = simplify_log_exp(exp(log(x)))   # → x  (with positive domain side condition)

# Expand products and collect like terms
r = simplify_expanded((x + pool.integer(1))**3)

Customizing the rule set

from alkahest import simplify_with, make_rule

# Add a custom rule: sin²(x) → 1 - cos²(x)
my_rule = make_rule("sin_sq_to_cos", lhs=sin(x)**2, rhs=pool.integer(1) - cos(x)**2)
r = simplify_with(expr, rules=[my_rule])

Parallel simplification

from alkahest import simplify_par

# Simplify a list of expressions concurrently (requires --features parallel)
exprs = [x**i for i in range(100)]
results = simplify_par(exprs)

E-graph simplification

simplify_egraph uses equality saturation via egglog to explore many equivalent forms simultaneously before committing to the best one via a cost function.

from alkahest import simplify_egraph

# The e-graph can discover non-obvious equivalences
r = simplify_egraph(x * x - pool.integer(1))  # may factor or simplify

E-graph saturation is more powerful than rule-based simplification for some inputs, but it is slower and its running time is hard to predict for complex expressions. See E-graph saturation for configuration options.

Choosing between the two

| Criterion | simplify | simplify_egraph |
|---|---|---|
| Speed | Fast, predictable | Slower, variable |
| Completeness | Fixed rule set | Equality saturation |
| Termination | Always | Configurable limits |
| Side conditions | Respected | Respected |
| Best for | Hot paths, cleanup | Difficult equalities |

For most workflows: use simplify (or a domain-specific variant) first. Reach for simplify_egraph when you need the system to discover a non-obvious equivalence.

Collect and normalize

Two utility passes that sit between the two engines:

from alkahest import collect_like_terms, poly_normal

# 2*x + 3*x → 5*x
r = collect_like_terms(pool.integer(2) * x + pool.integer(3) * x)

# Normalize to canonical polynomial form over given variables
r = poly_normal(x**2 + pool.integer(2) * x * y + y**2, [x, y])

Rule engine

The rule engine underlies both simplify and the e-graph backend. Rules are the atomic units of algebraic knowledge.

Anatomy of a rule

A RewriteRule has:

  • A name — stable string identifier (used in derivation logs and Lean certificate output)
  • A LHS pattern — an expression template with pattern variables
  • A RHS template — the replacement
  • Optional side conditions — predicates that must hold for the rule to fire

Pattern syntax

Patterns are ordinary expressions (not regexes) in which some nodes act as wildcards. From the Python side, pattern variables are Expr objects whose names start with ?:

from alkahest import make_rule, match_pattern

pool = ExprPool()
x = pool.symbol("x")

# Pattern variable — matches any subexpression
pv = pool.symbol("?a")

# Rule: ?a + 0 → ?a
add_zero = make_rule("add_zero", lhs=pv + pool.integer(0), rhs=pv)

Pattern variables capture any subexpression and must bind consistently: if ?a appears twice in the LHS it must match the same expression in both positions.
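
The consistent-binding requirement can be sketched with a small syntactic matcher over tuple-encoded expressions. This is a toy, not the Alkahest matcher: the function match is hypothetical, and it does no associative-commutative matching.

```python
def match(pattern, expr, subst=None):
    """Return a substitution dict if expr matches pattern, else None."""
    subst = dict(subst or {})
    if isinstance(pattern, str) and pattern.startswith("?"):
        if pattern in subst:
            # Already bound: the new occurrence must agree with the old one.
            return subst if subst[pattern] == expr else None
        subst[pattern] = expr
        return subst
    if isinstance(pattern, tuple) and isinstance(expr, tuple):
        if len(pattern) != len(expr):
            return None
        for sub_pattern, sub_expr in zip(pattern, expr):
            subst = match(sub_pattern, sub_expr, subst)
            if subst is None:
                return None
        return subst
    return subst if pattern == expr else None


# ?a + 0 matches sin(x) + 0 and binds ?a to sin(x).
print(match(("add", "?a", 0), ("add", ("sin", "x"), 0)))   # {'?a': ('sin', 'x')}

# ?a * ?a needs the same expression in both positions.
print(match(("mul", "?a", "?a"), ("mul", "x", "x")))   # {'?a': 'x'}
print(match(("mul", "?a", "?a"), ("mul", "x", "y")))   # None
```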

Matching

match_pattern applies a pattern to an expression and returns all match substitutions:

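# Pattern: sin(?a)^2 + cos(?a)^2, reusing the pattern variable pv defined above
pattern = sin(pv)**2 + cos(pv)**2
matches = match_pattern(sin(x)**2 + cos(x)**2, pattern)
for subst in matches:
    print(subst)  # dict mapping pattern variable → matched expr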

The matcher is associative-commutative (AC): a + b matches b + a, and a + b + c matches any ordering.

Built-in rule sets

The rule sets loaded by simplify and the domain-specific simplifiers are:

| Function | Rules |
|---|---|
| simplify | Arithmetic identities, constant folding, polynomial normalization |
| simplify_trig | Pythagorean identity, double-angle and half-angle formulas |
| simplify_log_exp | Log/exp cancellation (branch-cut safe subset) |
| simplify_expanded | Distributive expansion, like-term collection |

Defining custom rules

from alkahest import make_rule, simplify_with

pool = ExprPool()
x = pool.symbol("x")
a = pool.symbol("?a")
b = pool.symbol("?b")

# Self-cancellation rewrite: a - a → 0
self_cancel = make_rule(
    "self_cancel",
    lhs=a + pool.integer(-1) * a,
    rhs=pool.integer(0),
)

# Apply the custom rule alongside the default set
r = simplify_with(expr, rules=[self_cancel])

Custom rules are recorded in derivation logs with the name you provide. If you tag the rule with a Lean theorem name (via the Rust PrimitiveRegistry API), the corresponding step can be exported as a Lean proof term.

Rule execution model

simplify applies rules in a fixpoint loop:

  1. For each node in the expression (post-order traversal):
    • Try each rule in the rule set.
    • If a rule matches and its side conditions are satisfied, apply it, emit a RewriteStep, and restart the loop for the modified subtree.
  2. Repeat until no rule fires in a full pass.

This is an inner-outer loop strategy rather than exhaustive bottom-up application. It is fast but not complete — some sequences of rewrites require rules to be applied in a specific order. The e-graph engine removes this ordering dependency.
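
A simplified sketch of such a fixpoint loop over tuple-encoded expressions follows. It is an illustration under assumptions, not the engine itself: rules are plain Python functions, there are no side conditions, and after a rule fires the sketch waits for the next full pass instead of restarting on the modified subtree.

```python
def add_zero(node):
    if isinstance(node, tuple) and node[0] == "add" and len(node) == 3 and node[2] == 0:
        return node[1]


def mul_one(node):
    if isinstance(node, tuple) and node[0] == "mul" and len(node) == 3 and node[2] == 1:
        return node[1]


def const_fold(node):
    if (isinstance(node, tuple) and node[0] in ("add", "mul") and len(node) == 3
            and all(isinstance(arg, int) for arg in node[1:])):
        return node[1] + node[2] if node[0] == "add" else node[1] * node[2]


RULES = [("add_zero", add_zero), ("mul_one", mul_one), ("const_fold", const_fold)]


def rewrite_pass(node, steps):
    """One post-order pass: rewrite children first, then try rules on the node."""
    if isinstance(node, tuple):
        op, *args = node
        node = (op, *[rewrite_pass(arg, steps) for arg in args])
    for name, rule in RULES:
        result = rule(node)
        if result is not None:
            steps.append((name, node, result))   # toy analogue of a RewriteStep
            return result
    return node


def toy_simplify(expr, max_passes=100):
    steps = []
    for _ in range(max_passes):
        new_expr = rewrite_pass(expr, steps)
        if new_expr == expr:      # fixpoint: a full pass changed nothing
            break
        expr = new_expr
    return expr, steps


value, steps = toy_simplify(("add", ("mul", "x", ("add", 1, 0)), 0))
print(value)                            # x
print([name for name, _, _ in steps])   # ['add_zero', 'mul_one', 'add_zero']
```

The rule order in RULES decides which rule wins when several match the same node, which is the ordering dependency mentioned above.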

Side conditions

A rule can carry a side condition checked against the matched substitution:

# sqrt(x^2) → x  only when x is non-negative
sqrt_sq = make_rule(
    "sqrt_sq_nonneg",
    lhs=sqrt(a**2),
    rhs=a,
    condition="nonnegative",   # checked against the domain of ?a
)

Side conditions that reference symbol domains are sound: sqrt_sq_nonneg will only fire when ?a is bound to a symbol with domain positive or nonnegative. They propagate into the derivation log as SideCondition entries and into Lean output as assumptions.

E-graph saturation

The e-graph backend exposes a fundamentally different approach to simplification: rather than applying rules one at a time in a fixed order, it builds a structure that represents many equivalent expressions simultaneously, then extracts the best one.

What is an e-graph?

An e-graph partitions expressions into equivalence classes (e-classes). When a rewrite rule fires, it does not replace the LHS — it adds the RHS to the same e-class as the LHS. At the end of saturation, an extraction step picks the cheapest representative from each e-class according to a cost function.

This eliminates the phase-ordering problem: rules can fire in any order without risk of committing to a suboptimal form. The e-graph remembers all explored forms and chooses among them at the end.
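
The "keep every form, choose at the end" idea can be demonstrated without a real e-graph. The sketch below is a brute-force stand-in: it stores whole terms in a set instead of sharing subterms in e-classes, so it is far less efficient than egglog and is not how Alkahest implements saturation. The rules, the shl operator, and all names are hypothetical; side conditions such as a nonzero divisor are omitted.

```python
def is_op(term, op):
    return isinstance(term, tuple) and term[0] == op


def shift(t):      # x * 2 -> x << 1
    return ("shl", t[1], 1) if is_op(t, "mul") and t[2] == 2 else None


def reassoc(t):    # (x * y) / z -> x * (y / z)
    if is_op(t, "div") and is_op(t[1], "mul"):
        return ("mul", t[1][1], ("div", t[1][2], t[2]))


def div_self(t):   # n / n -> 1 (nonzero n assumed, not checked)
    return 1 if is_op(t, "div") and t[1] == t[2] else None


def mul_one(t):    # x * 1 -> x
    return t[1] if is_op(t, "mul") and t[2] == 1 else None


RULES = [shift, reassoc, div_self, mul_one]


def rewrites(term):
    """Yield every term reachable by one rule application at one position."""
    for rule in RULES:
        out = rule(term)
        if out is not None:
            yield out
    if isinstance(term, tuple):
        op, *args = term
        for i, arg in enumerate(args):
            for new_arg in rewrites(arg):
                yield (op, *args[:i], new_arg, *args[i + 1:])


def size(term):
    return 1 + sum(size(a) for a in term[1:]) if isinstance(term, tuple) else 1


def saturate(term, iter_limit=10, node_limit=1000):
    seen, frontier = {term}, {term}
    for _ in range(iter_limit):
        new = {r for t in frontier for r in rewrites(t)} - seen
        if not new or len(seen) + len(new) > node_limit:
            break
        seen |= new
        frontier = new
    # Extraction: pick the cheapest equivalent form found.
    return min(seen, key=lambda t: (size(t), repr(t)))


print(saturate(("div", ("mul", "a", 2), 2)))   # a
```

A greedy rewriter that applied shift first would be stuck at (a << 1) / 2, because reassoc no longer matches. Keeping both forms avoids that dead end.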

Using the e-graph

from alkahest import simplify_egraph, simplify_egraph_with

# Default configuration
r = simplify_egraph(expr)

# With explicit config
from alkahest import EgraphConfig  # (Rust API; Python config dict in future)
r = simplify_egraph_with(expr, {"node_limit": 10_000, "iter_limit": 20})

Cost functions

The extraction step minimizes a cost function over e-class representatives. Four built-in cost functions:

| Name | Behavior |
|---|---|
| SizeCost | Prefers the expression with the fewest AST nodes |
| DepthCost | Prefers the shallowest expression tree |
| OpCost | Assigns per-operation costs; penalizes expensive ops |
| StabilityCost | Penalizes patterns that cause catastrophic cancellation |

StabilityCost is aware of numerical stability issues: it penalizes subtractive cancellation patterns and prefers numerically stable rearrangements.

Configuration

The e-graph runs until saturation (no new e-class merges) or until a limit is hit:

  • node_limit — maximum number of e-nodes. Once reached, saturation stops and extraction runs on the current state.
  • iter_limit — maximum number of saturation rounds.

For large or complex expressions, saturation can be expensive. The rule-based simplify is often sufficient and should be preferred on hot paths.

Rule sets in the e-graph

The e-graph uses the same RewriteRule objects as the rule-based engine. By default it loads the arithmetic rules. Domain-specific rules (trig, log/exp) are kept separate to avoid e-class explosions on expressions that do not involve those operations.

Upcoming (v1.1): The default e-graph rule set will include trig identities (sin²+cos²→1) and safe log/exp cancellation out of the box, configurable via SimplifyConfig { include_trig_rules, include_log_exp_rules }.

When e-graphs help

The e-graph is especially powerful when:

  • Multiple non-obvious rewrites must be combined in a specific order that is hard to predict.
  • The “right” form is not syntactically similar to the input (e.g. factoring followed by cancellation).
  • You want the globally cheapest form under a custom cost function, not just any simplified form.

It is less useful when:

  • The expression is already in near-canonical form and only identity cleanup is needed.
  • You need predictable performance on a hot path.
  • The expression is large and associative-commutative, where the e-graph can grow combinatorially.

AC matching in the e-graph

The egglog backend handles associativity and commutativity structurally: Add and Mul children are sorted at pool-insertion time, so there is a single canonical ordering. The e-graph does not need to enumerate permutations.

This is more efficient than classical AC-completion but requires that the canonical ordering is established at construction, which the kernel enforces.

Calculus

Alkahest supports symbolic differentiation and integration with full derivation logging.

Differentiation

diff(expr, var) computes the symbolic derivative of expr with respect to var.

from alkahest import diff, sin, cos, exp, log

pool = ExprPool()
x = pool.symbol("x")

# Polynomial
dr = diff(x**3 + pool.integer(2) * x, x)
print(dr.value)   # 3*x^2 + 2

# Chain rule
dr = diff(sin(x**2), x)
print(dr.value)   # 2*x*cos(x^2)

# Product rule
dr = diff(x * exp(x), x)
print(dr.value)   # exp(x) + x*exp(x)

# Logarithm
dr = diff(log(x**2 + pool.integer(1)), x)
print(dr.value)   # 2*x / (x^2 + 1)

Registered primitives

Every primitive in the registry has a differentiation rule. The 23 currently registered primitives are:

sin, cos, tan, asin, acos, atan, atan2, sinh, cosh, tanh, exp, log, sqrt, abs, sign, erf, erfc, gamma, floor, ceil, round, min, max

Derivation log

The DerivedResult returned by diff records every rule application:

dr = diff(sin(x**2), x)
for step in dr.steps:
    print(f"  {step['rule']:25s}  {step['before']}  →  {step['after']}")

Forward-mode automatic differentiation

diff_forward computes the derivative using forward-mode AD (dual numbers). It produces the same result as diff but through a different computational path:

from alkahest import diff, diff_forward

sym = diff(x**3, x)
fwd = diff_forward(x**3, x)
# fwd.value == sym.value

Forward mode is useful for checking that the symbolic rules agree with dual-number evaluation.
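
For readers unfamiliar with dual numbers, here is a minimal self-contained sketch of the technique. It is a numeric toy, not diff_forward: the Dual class and the derivative helper are hypothetical, and only addition, multiplication, and non-negative integer powers are supported.

```python
class Dual:
    """A number val + eps*e with e^2 = 0; eps carries the derivative."""

    def __init__(self, val, eps=0.0):
        self.val, self.eps = val, eps

    @staticmethod
    def lift(value):
        return value if isinstance(value, Dual) else Dual(value)

    def __add__(self, other):
        o = Dual.lift(other)
        return Dual(self.val + o.val, self.eps + o.eps)

    __radd__ = __add__

    def __mul__(self, other):
        o = Dual.lift(other)
        # Product rule falls out of (a + a'e)(b + b'e) with e^2 = 0.
        return Dual(self.val * o.val, self.val * o.eps + self.eps * o.val)

    __rmul__ = __mul__

    def __pow__(self, n):
        result = Dual(1.0)
        for _ in range(n):        # repeated multiplication, n >= 0
            result = result * self
        return result


def derivative(f, x):
    """Evaluate f'(x) by seeding the eps component with 1."""
    return f(Dual(x, 1.0)).eps


print(derivative(lambda x: x**3, 2.0))         # 12.0, matches 3*x^2 at x = 2
print(derivative(lambda x: x**3 + 2 * x, 2.0)) # 14.0, matches 3*x^2 + 2
```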

Symbolic gradient

symbolic_grad differentiates with respect to multiple variables:

from alkahest import symbolic_grad

pool = ExprPool()
x = pool.symbol("x")
y = pool.symbol("y")

expr = x**2 * y + sin(x * y)
grads = symbolic_grad(expr, [x, y])
# grads[0] = ∂/∂x = 2*x*y + y*cos(x*y)
# grads[1] = ∂/∂y = x^2 + x*cos(x*y)

For the traced-function gradient (composable with jit), see Transformations.

Integration

integrate(expr, var) computes the symbolic antiderivative of expr with respect to var.

from alkahest import integrate, sin, cos, exp

# Polynomials
r = integrate(x**3, x)
print(r.value)    # x^4/4

# Known functions
r = integrate(sin(x), x)
print(r.value)    # -cos(x)

r = integrate(exp(x), x)
print(r.value)    # exp(x)

r = integrate(x**pool.integer(-1), x)
print(r.value)    # log(x)

Integration rules

The integration engine applies rules from a table of known forms (Risch subset):

  • Power rule: ∫ xⁿ dx = xⁿ⁺¹/(n+1) for integer n ≠ -1
  • Logarithm: ∫ 1/x dx = log(x)
  • Exponential tower: ∫ exp(a*x + b) dx, ∫ x * exp(x) dx
  • Linear substitution: ∫ f(a*x + b) dx
  • Trigonometric: ∫ sin(x) dx, ∫ cos(x) dx
  • Standard table entries for erf, inverse trig, etc.

If no rule applies, integrate raises an IntegrationError with a remediation hint indicating what class of integrand would be needed (e.g. “algebraic extension required — see v1.1 algebraic Risch”).
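
As a rough illustration of a rule table with an explicit failure path, the sketch below integrates a single monomial c*x^n using only the power and logarithm rules listed above. The function name, the tuple encoding of results, and the use of ValueError in place of IntegrationError are all assumptions made for this example.

```python
from fractions import Fraction


def integrate_monomial(coeff, n):
    """Antiderivative of coeff * x^n for integer coeff and integer n."""
    if not isinstance(n, int):
        # No rule in this toy table covers the integrand.
        raise ValueError("no rule applies: only integer exponents are supported here")
    if n == -1:
        return ("log", Fraction(coeff))               # coeff * log(x)
    return ("pow", Fraction(coeff, n + 1), n + 1)     # coeff/(n+1) * x^(n+1)


print(integrate_monomial(1, 3))    # ('pow', Fraction(1, 4), 4), i.e. x^4/4
print(integrate_monomial(1, -1))   # ('log', Fraction(1, 1)), i.e. log(x)

try:
    integrate_monomial(1, 0.5)
except ValueError as err:
    print(err)
```

Using exact fractions keeps the coefficient 1/4 exact instead of turning it into 0.25.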

Upcoming (v1.1): Algebraic-function Risch (Trager’s algorithm) will handle integrands involving sqrt(P(x)) and other algebraic extensions.

Verification

A common pattern is to verify an antiderivative by differentiating it back:

antideriv = integrate(expr, x).value
check = simplify(diff(antideriv, x).value)
# check.value should equal expr

Higher derivatives

Chain calls to diff:

d2 = diff(diff(sin(x), x).value, x)
print(d2.value)   # -sin(x)

The derivation log of the outer diff does not include the inner steps. If you need the full trace, keep both results (say dr1 for the inner call and dr2 for the outer one) and concatenate dr1.steps + dr2.steps.

Transformations

Transformations are higher-order operations that take a function and return a new function. They compose freely and operate on traced symbolic representations.

Tracing

trace symbolically executes a Python function by replacing its inputs with symbolic variables and recording the computation as an expression DAG.

import alkahest
from alkahest import ExprPool

pool = ExprPool()

@alkahest.trace(pool)
def f(x, y):
    return x**2 + alkahest.sin(y)

print(f.expr)    # x^2 + sin(y)
print(f.symbols) # [x, y]

The decorator takes the pool as an argument. Variable names are inferred from the function signature.
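
The mechanism can be sketched in plain Python: call the function once with placeholder objects whose operators record what was done to them. This is a toy under assumptions, not alkahest.trace: Node and toy_trace are hypothetical names, and the result is a nested tuple, not an interned Expr.

```python
import inspect


class Node:
    """Placeholder whose operators build an expression tree."""

    def __init__(self, op, *args):
        self.op, self.args = op, args

    def __add__(self, other):
        return Node("add", self, other)

    def __radd__(self, other):
        return Node("add", other, self)

    def __mul__(self, other):
        return Node("mul", self, other)

    def __rmul__(self, other):
        return Node("mul", other, self)

    def __pow__(self, exponent):
        return Node("pow", self, exponent)

    def to_tuple(self):
        return (self.op, *[a.to_tuple() if isinstance(a, Node) else a
                           for a in self.args])


def toy_trace(fn):
    # Symbol names come from the function signature.
    names = list(inspect.signature(fn).parameters)
    symbols = [Node("symbol", name) for name in names]
    return fn(*symbols).to_tuple(), names


def f(x, y):
    return x**2 + 3 * y


expr, names = toy_trace(f)
print(names)   # ['x', 'y']
print(expr)    # ('add', ('pow', ('symbol', 'x'), 2), ('mul', 3, ('symbol', 'y')))
```

Python control flow that depends on the value of an input cannot be captured this way, since the placeholders have no numeric value at trace time.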

Numeric evaluation

TracedFn objects are callable with numeric values:

print(f(3.0, 0.0))   # 9.0

import numpy as np
xs = np.linspace(0, 1, 1000)
ys = np.zeros(1000)
result = f(xs, ys)   # vectorised automatically

Gradient

grad differentiates a traced function symbolically with respect to all (or a subset of) its inputs:

df = alkahest.grad(f)
# df(x, y) returns [∂f/∂x, ∂f/∂y] = [2*x, cos(y)]

grads = df(1.0, 0.0)   # [2.0, 1.0]

Differentiate with respect to a subset:

df_x = alkahest.grad(f, wrt=[f.symbols[0]])  # ∂f/∂x only

JIT compilation

jit wraps a traced function in the LLVM JIT backend. The first call triggers compilation; subsequent calls run the compiled code directly.

fast_f = alkahest.jit(f)
print(fast_f(3.0, 0.0))  # 9.0, via LLVM-compiled code

Vectorised evaluation is automatic when array inputs are detected:

xs = np.linspace(0, 10, 1_000_000)
ys = np.zeros_like(xs)
result = fast_f(xs, ys)   # zero-copy batch path

Composing transformations

Transformations stack:

# Compiled gradient
fast_df = alkahest.jit(alkahest.grad(f))
grads = fast_df(xs, ys)   # compiled, vectorised gradient

# Second derivative: grad of grad
d2f = alkahest.grad(alkahest.grad(f))

Note that grad returns a GradTracedFn, not a TracedFn. jit can be applied to GradTracedFn when it wraps a single scalar output. For multi-output cases, compile each gradient expression individually with compile_expr.

trace_fn

Functional (non-decorator) version of trace:

from alkahest import trace_fn

fn = trace_fn(lambda x, y: x * alkahest.exp(y), pool)

PyTrees

Transformations work over nested data structures (lists, dicts, tuples, dataclasses). The Python layer flattens and unflattens them automatically:

from alkahest import flatten_exprs, unflatten_exprs, map_exprs, simplify

x, y = pool.symbol("x"), pool.symbol("y")

# A system of equations as a list
eqs = [x**2 + y - pool.integer(1), x - y**2 + pool.integer(1)]

# Map simplification over the list
simplified = map_exprs(simplify, eqs)

# Flatten to a list of ExprIds and the structure descriptor
flat, treedef = flatten_exprs(eqs)
restored = unflatten_exprs(flat, treedef)

This follows the JAX pytree pattern. The Rust kernel sees only flat sequences; structure reconstruction is a Python-layer concern.
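The flatten/unflatten round trip can be sketched in a few lines of plain Python. This is an illustration of the pytree idea only, not the implementation behind flatten_exprs; the functions flatten and unflatten and the tuple-based structure descriptor are assumptions made for the example.

```python
def flatten(tree):
    """Return (leaves, treedef) for nested lists, tuples, and dicts."""
    if isinstance(tree, (list, tuple)):
        leaves, children = [], []
        for item in tree:
            sub_leaves, sub_def = flatten(item)
            leaves.extend(sub_leaves)
            children.append(sub_def)
        return leaves, (type(tree), children)
    if isinstance(tree, dict):
        keys = sorted(tree)  # fixed key order makes the leaf order deterministic
        leaves, children = [], []
        for key in keys:
            sub_leaves, sub_def = flatten(tree[key])
            leaves.extend(sub_leaves)
            children.append(sub_def)
        return leaves, (dict, keys, children)
    return [tree], None  # anything else is treated as a leaf


def unflatten(leaves, treedef):
    """Rebuild the nested structure from a flat leaf list."""
    it = iter(leaves)

    def build(node):
        if node is None:
            return next(it)
        if node[0] is dict:
            _, keys, children = node
            return {k: build(c) for k, c in zip(keys, children)}
        kind, children = node
        return kind(build(c) for c in children)

    return build(treedef)


tree = {"eqs": ["a", "b"], "bounds": ("lo", "hi")}
flat, treedef = flatten(tree)
print(flat)                                  # ['lo', 'hi', 'a', 'b']
print(unflatten(flat, treedef) == tree)      # True
```

Mapping a function over the tree is then a flatten, a plain list comprehension over the leaves, and an unflatten.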

Context manager

alkahest.context sets a default pool and configuration for a block:

with alkahest.context(pool=pool, simplify=True):
    z = alkahest.symbol("z")          # uses the active pool
    expr = z**2 + alkahest.sin(z)     # auto-simplified

Inside the context, alkahest.symbol(name) creates a symbol in the active pool without passing the pool explicitly. This is a convenience wrapper — the pool is still explicit at the structural level.

Code generation

Alkahest can compile symbolic expressions to fast native or GPU code. Compiled code bypasses Python entirely during evaluation.

The compilation pipeline

Expressions lower through multiple IR levels:

ExprPool (hash-consed DAG)
    ↓  e-graph extraction + canonicalization
Canonical expression form
    ↓  alkahest MLIR dialect
High-level MLIR (math-aware ops: horner, poly_eval, interval_eval)
    ↓  lowering passes
Standard MLIR (arith, math, linalg, gpu)
    ↓
LLVM IR / PTX / StableHLO (depending on target)
    ↓
Native machine code / GPU kernel / XLA

The custom alkahest MLIR dialect is where math-aware optimizations happen: Horner’s method for polynomials, fused multiply-add emission, numerically stable rearrangements via StabilityCost.

compile_expr

compile_expr produces a callable from a symbolic expression and a list of input variables:

from alkahest import ExprPool, compile_expr, sin, cos

pool = ExprPool()
x = pool.symbol("x")
y = pool.symbol("y")

f = compile_expr(x**2 + sin(y), [x, y])
print(f([3.0, 0.0]))   # 9.0

The callable takes a list of floats (one per variable) and returns a float. For batch evaluation see numpy_eval below.

Without --features jit, a fast Rust tree-walking interpreter is used instead of LLVM. The API is identical.

eval_expr

For one-off evaluation without compiling:

from alkahest import eval_expr

result = eval_expr(x**2 + sin(y), {x: 3.0, y: 0.0})
print(result)   # 9.0

eval_expr is slower than a compiled function for repeated evaluation but has no compilation overhead.

numpy_eval

numpy_eval vectorises a compiled function over NumPy arrays via the batch path:

import numpy as np
from alkahest import numpy_eval

f = compile_expr(sin(x) * cos(x), [x])
xs = np.linspace(0, 2 * np.pi, 1_000_000)
ys = numpy_eval(f, xs)   # vectorised, zero-copy

numpy_eval also accepts PyTorch CPU tensors and JAX arrays via DLPack.

Horner-form emission

horner rewrites a polynomial expression into Horner’s form, which needs only one multiplication and one addition per coefficient and usually accumulates less rounding error than evaluating each power separately:

from alkahest import horner

# x^3 + 2x^2 + 3x + 4 → x*(x*(x + 2) + 3) + 4
h = horner(x**3 + pool.integer(2)*x**2 + pool.integer(3)*x + pool.integer(4), x)
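The evaluation scheme that Horner's form encodes can be written as a short loop. The sketch below is plain Python and independent of Alkahest; horner_eval is a hypothetical helper, and coefficients are assumed to be listed from the highest degree down.

```python
def horner_eval(coeffs, x):
    """Evaluate a polynomial given coefficients from highest to lowest degree."""
    acc = 0
    for c in coeffs:
        acc = acc * x + c   # one multiply and one add per coefficient
    return acc


# x^3 + 2x^2 + 3x + 4 at x = 2
print(horner_eval([1, 2, 3, 4], 2))   # 26
```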

emit_c emits a C function string for embedding in other projects:

from alkahest import emit_c

c_code = emit_c(expr, [x, y], fn_name="f")
# → "double f(double x, double y) { return ...; }"

MLIR dialect

The alkahest-mlir crate exposes the custom MLIR dialect. The dialect ops are:

Op                          Description
alkahest.sym                Symbolic variable reference
alkahest.const              Constant value
alkahest.add, alkahest.mul  Arithmetic
alkahest.pow                Exponentiation
alkahest.horner             Horner polynomial evaluation
alkahest.poly_eval          Generic polynomial evaluation
alkahest.series_taylor      Taylor series evaluation
alkahest.interval_eval      Ball arithmetic evaluation
alkahest.rational_fn        Rational function evaluation

Three lowering targets are available:

  • ArithMath: lowers to arith + math MLIR dialects; uses math.fma for Horner chains
  • StableHlo: lowers to StableHLO ops for XLA/JAX integration
  • Llvm: lowers to llvm dialect for LLVM IR / PTX emission

from alkahest import to_stablehlo

# Emit textual MLIR in the StableHLO dialect
mlir_text = to_stablehlo(expr, [x, y], fn_name="my_fn")
print(mlir_text)  # valid input to mlir-opt / XLA

GPU codegen (NVPTX)

With --features cuda and an LLVM installation with NVPTX support:

from alkahest import compile_cuda

f_gpu = compile_cuda(expr, [x, y])
result = f_gpu.call_batch(inputs)   # runs on the first CUDA device

The GPU compiler:

  1. Lowers the expression through inkwell to NVPTX LLVM IR for sm_86 (Ampere)
  2. Links libdevice.10.bc for transcendental functions (__nv_sin, etc.)
  3. Emits PTX via LLVM’s target machine
  4. Loads the PTX via the CUDA driver (cudarc)

The benchmark nvptx/nvptx_polynomial_1M reports a 16.2× speedup over the CPU JIT for a 1M-point polynomial evaluation on an RTX 3090.

Upcoming (v1.1): AMD ROCm / amdgcn target (hardware-blocked pending RDNA3 availability).

Caching

Compilation results are cached keyed by the canonical hash of the expression DAG. Compiling the same expression twice returns the cached result. The persistent ExprPool (V1-14) extends this cache across sessions.

Small expressions below a complexity threshold skip LLVM entirely and run through the Rust interpreter, which has lower overhead for trivial expressions.
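The caching behaviour described here amounts to memoising the compile step on a structural key. The following sketch shows that pattern in plain Python under simplifying assumptions: nested tuples stand in for the expression DAG, Python's tuple hashing stands in for the canonical hash, and CompileCache and toy_compile are hypothetical names, not Alkahest APIs.

```python
class CompileCache:
    """Memoise an expensive compile step by a structural key."""

    def __init__(self, compile_fn):
        self._compile_fn = compile_fn
        self._cache = {}
        self.misses = 0

    def get(self, expr):
        key = expr  # tuples hash by structure, standing in for a canonical DAG hash
        if key not in self._cache:
            self.misses += 1
            self._cache[key] = self._compile_fn(expr)
        return self._cache[key]


def toy_compile(expr):
    """Hypothetical 'compiler': turns ('add', 'x', 1)-style tuples into a closure."""
    op, lhs, rhs = expr

    def run(env):
        a = env[lhs] if isinstance(lhs, str) else lhs
        b = env[rhs] if isinstance(rhs, str) else rhs
        return a + b if op == "add" else a * b

    return run


cache = CompileCache(toy_compile)
f1 = cache.get(("add", "x", 1))
f2 = cache.get(("add", "x", 1))   # structurally equal, so no recompilation
print(f1 is f2, cache.misses)     # True 1
print(f1({"x": 41}))              # 42
```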

Ball arithmetic

Ball arithmetic provides rigorous enclosures: every operation produces an interval guaranteed to contain the true result. Alkahest uses FLINT’s Arb library as the backend.

ArbBall

An ArbBall represents the real interval [midpoint − radius, midpoint + radius]:

from alkahest import ArbBall

a = ArbBall(2.0, 0.5)    # [1.5, 2.5]
b = ArbBall(3.0, 0.0)    # exactly 3.0

print(a.mid)   # 2.0
print(a.rad)   # 0.5
print(a.lo)    # 1.5
print(a.hi)    # 2.5

An ArbBall can also carry a precision (in bits) for the midpoint:

a = ArbBall(2.0, 1e-30, prec=128)  # 128-bit midpoint

Ball arithmetic operations

All arithmetic on ArbBall values produces a guaranteed enclosure. The radius grows to account for rounding and operation error:

a = ArbBall(2.0, 0.1)
b = ArbBall(3.0, 0.1)

print(a + b)    # [4.8, 5.2]  — radius grows by sum of input radii
print(a * b)    # guaranteed enclosure of [1.9, 2.1] * [2.9, 3.1]
print(a ** 2)   # [3.61, 4.41]  (squares the interval [1.9, 2.1])
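The endpoint arithmetic behind these enclosures can be checked by hand with exact rationals. The sketch below is not ArbBall: Interval is a hypothetical class that uses fractions.Fraction, so there is no rounding to account for and the radius never needs inflating. It only illustrates how the bounds combine.

```python
from fractions import Fraction


class Interval:
    """Closed interval [lo, hi] with exact rational endpoints (no rounding)."""

    def __init__(self, lo, hi):
        self.lo, self.hi = Fraction(lo), Fraction(hi)

    @classmethod
    def ball(cls, mid, rad):
        mid, rad = Fraction(mid), Fraction(rad)
        return cls(mid - rad, mid + rad)

    def __add__(self, other):
        return Interval(self.lo + other.lo, self.hi + other.hi)

    def __mul__(self, other):
        # The extremes of a product are always attained at endpoint pairs.
        corners = [a * b for a in (self.lo, self.hi) for b in (other.lo, other.hi)]
        return Interval(min(corners), max(corners))

    def __repr__(self):
        return f"[{float(self.lo)}, {float(self.hi)}]"


a = Interval.ball("2.0", "0.1")   # [1.9, 2.1]
b = Interval.ball("3.0", "0.1")   # [2.9, 3.1]
print(a + b)   # [4.8, 5.2]
print(a * b)   # [5.51, 6.51]
print(a * a)   # [3.61, 4.41]
```

For an interval that straddles zero, a * a overestimates the true range of the square, which is one reason a dedicated power operation is useful.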

interval_eval

interval_eval evaluates a symbolic expression with ArbBall inputs:

from alkahest import ExprPool, ArbBall, interval_eval, sin, exp

pool = ExprPool()
x = pool.symbol("x")

# sin(1 ± 1e-10) — guaranteed enclosure
result = interval_eval(sin(x), {x: ArbBall(1.0, 1e-10)})
print(result.lo, result.hi)

# Multivariate
y = pool.symbol("y")
expr = sin(x) * exp(y)
result = interval_eval(expr, {
    x: ArbBall(1.0, 0.01),
    y: ArbBall(0.0, 0.01),
})

interval_eval guarantees that the output ball contains the true value for any input in the given input balls, accounting for all rounding in the intermediate computation.

AcbBall

Complex ball arithmetic for expressions over ℂ:

from alkahest import AcbBall

z = AcbBall(1.0, 0.0, 1.0, 0.0)  # 1 + i, exact

Use cases

Certified numerical evaluation. Compute a value and prove it lies within a tight bound without symbolic proof:

# Prove sin(1) ∈ [0.841, 0.842]
r = interval_eval(sin(x), {x: ArbBall(1.0, 0.0)})
assert r.lo > 0.841 and r.hi < 0.842

Numerical verification of symbolic results. After deriving a symbolic simplification, verify it numerically with rigorous bounds:

from alkahest import cos

# Verify sin²(x) + cos²(x) = 1 at x = 1
lhs = sin(x)**pool.integer(2) + cos(x)**pool.integer(2)
r = interval_eval(lhs, {x: ArbBall(1.0, 0.0)})
assert 1.0 in r  # ball contains 1

Sensitivity analysis. Pass an input ball representing parameter uncertainty and observe how the output uncertainty grows:

# x = 1 ± 0.1 (10% uncertainty)
r = interval_eval(x**pool.integer(3), {x: ArbBall(1.0, 0.1)})
print(r)  # output uncertainty

Relationship to Lean certificates

Ball arithmetic and Lean certificate export are complementary:

  • Ball arithmetic gives numerical certainty within floating-point computation.
  • Lean certificates give symbolic/logical certainty for the rewrite steps applied.

Combining them: certify(interval(differentiate(f))) gives a derivative, evaluated with rigorous interval bounds, with a machine-checkable proof of the symbolic differentiation step.

ODE and DAE modeling

Alkahest provides symbolic infrastructure for ordinary differential equations (ODEs) and differential-algebraic equations (DAEs), including structural analysis and automatic index reduction.

ODE

ODE represents an ordinary differential equation system. Build one from symbolic expressions:

from alkahest import ExprPool, ODE, lower_to_first_order, sin

pool = ExprPool()
t = pool.symbol("t")
x = pool.symbol("x")
v = pool.symbol("v")

# Simple harmonic oscillator: x'' + ω²x = 0
# Represented as a first-order system: [x' = v, v' = -ω²x]
omega = pool.integer(1)
ode = ODE(
    state=[x, v],
    derivatives=[v, pool.integer(-1) * omega**pool.integer(2) * x],
    independent=t,
)

Lowering to first order

Higher-order ODEs are automatically lowered to first-order form:

from alkahest import lower_to_first_order

first_order_system = lower_to_first_order(higher_order_ode)
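The transformation itself is the standard one: introduce one state per derivative below the highest order. The sketch below shows it numerically for a scalar equation in plain Python. lower_scalar_ode is a hypothetical helper written for this example and works on Python callables, whereas lower_to_first_order operates on symbolic ODE objects.

```python
def lower_scalar_ode(f, order):
    """Turn y^(n) = f(t, y, y', ..., y^(n-1)) into a first-order right-hand side.

    The state vector is u = [y, y', ..., y^(n-1)], so u[i]' = u[i+1] for
    i < n-1 and u[n-1]' = f(t, *u).
    """
    def rhs(t, u):
        assert len(u) == order
        return list(u[1:]) + [f(t, *u)]

    return rhs


# Harmonic oscillator x'' = -omega^2 * x with omega = 1
rhs = lower_scalar_ode(lambda t, x, v: -x, order=2)
print(rhs(0.0, [1.0, 0.0]))   # [0.0, -1.0]
```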

DAE

DAE represents a differential-algebraic system where some equations are algebraic constraints rather than differential equations.

from alkahest import DAE, pantelides

pool = ExprPool()
t = pool.symbol("t")
x = pool.symbol("x")   # differential variable
y = pool.symbol("y")   # algebraic variable (constrained)
lam = pool.symbol("lam")  # Lagrange multiplier

# Pendulum: differential equations + constraint
dae = DAE(
    equations=[...],  # system of equations
    variables=[x, y, lam],
    independent=t,
)

Pantelides algorithm

The Pantelides algorithm performs structural index reduction on DAEs. It identifies which equations need to be differentiated to make the system structurally regular:

from alkahest import pantelides

reduced = pantelides(dae)
print(reduced.index)          # structural index of the reduced system
print(reduced.differentiated)  # which equations were differentiated

Index reduction converts a high-index DAE (index > 1) into an index-1 system that ODE solvers can handle. The result includes the differentiated equations as symbolic expressions.

Sensitivity analysis

Sensitivity analysis computes how solutions depend on parameters:

from alkahest import sensitivity_system, adjoint_system

pool = ExprPool()
p = pool.symbol("p")   # parameter

# Forward sensitivity: generates ∂x/∂p equations alongside the ODE
sens = sensitivity_system(ode, [p])

# Adjoint method: more efficient for many parameters, one output
adj = adjoint_system(ode, output_expr, [p])
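For reference, the forward sensitivity equations follow from differentiating the ODE with respect to the parameter. This is the standard textbook derivation rather than a description of Alkahest's internals; the symbols f, s, and x_0 are notation introduced here.

```latex
\dot{x} = f(t, x, p), \qquad s = \frac{\partial x}{\partial p}
\quad\Longrightarrow\quad
\dot{s} = \frac{\partial f}{\partial x}\, s + \frac{\partial f}{\partial p},
\qquad s(t_0) = \frac{\partial x_0}{\partial p}.
```

With n states and m parameters this adds n·m equations to the system, which is why the adjoint formulation tends to be preferred when m is large and only one scalar output is of interest.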

Acausal modeling

Acausal component modeling lets you describe physical systems by their component equations without manually choosing which direction information flows:

from alkahest import AcausalSystem, Port, resistor

# Build a simple RC circuit symbolically
pool = ExprPool()
circuit = AcausalSystem()

R = resistor(pool, resistance=pool.symbol("R"))
# Connect components via ports
circuit.add(R)
circuit.connect(R.port_pos, ...)

# Extract the DAE from the connected system
dae = circuit.to_dae()

Built-in components (resistor, and others registered via the component API) generate their constitutive equations automatically. The system then assembles them into a DAE that Pantelides can reduce.

Hybrid systems

HybridODE adds event handling to an ODE: at a crossing event, the state is reset and integration resumes with a new ODE:

from alkahest import HybridODE, Event

# Bouncing ball: velocity reverses at floor contact
bounce_event = Event(
    condition=x,             # fires when x = 0
    reset={v: pool.integer(-1) * v},   # reverse velocity
)
hybrid = HybridODE(ode=base_ode, events=[bounce_event])
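The event-and-reset semantics can be illustrated with a small fixed-step simulation. This sketch is independent of HybridODE: the dynamics (free fall with g = 9.81), the explicit Euler integrator, and the helper simulate_bounce are all assumptions chosen for the example, and a production solver would locate the crossing time more precisely.

```python
def simulate_bounce(x0=1.0, v0=0.0, g=9.81, dt=1e-3, steps=2000):
    """Explicit Euler with a zero-crossing event on x and a reset v -> -v."""
    x, v = x0, v0
    bounces, lowest = 0, x0
    for _ in range(steps):
        x_new = x + v * dt
        v_new = v - g * dt
        if x > 0.0 and x_new <= 0.0:   # event: x crosses zero from above
            v_new = -v_new             # reset: reverse the velocity
            bounces += 1
        x, v = x_new, v_new
        lowest = min(lowest, x)
    return bounces, lowest


bounces, lowest = simulate_bounce()
print(bounces, lowest)
```

The crossing test compares the state before and after a step, so the event fires once per contact instead of on every step spent near the floor.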

Polynomial system solving

Alkahest solves systems of polynomial equations symbolically using Gröbner bases.

solve

solve finds the solutions of a system of polynomial equations in a list of variables. It requires the groebner feature.

from alkahest import ExprPool, solve, sqrt

pool = ExprPool()
x = pool.symbol("x")
y = pool.symbol("y")

# Linear system
solutions = solve([x + y - pool.integer(1), x - y], [x, y])
# → [{x: 1/2, y: 1/2}]

# Circle intersected with a line: irrational solutions
solutions = solve(
    [x**2 + y**2 - pool.integer(1), y - x],
    [x, y]
)
# → [{x: sqrt(2)/2, y: sqrt(2)/2}, {x: -sqrt(2)/2, y: -sqrt(2)/2}]

Solutions are symbolic: irrational roots are returned as Expr trees (e.g. sqrt(2)/2) rather than floats. Quadratic elimination produces exact symbolic answers.

Solution types

The return value is a list of dicts mapping Expr variable → Expr solution:

for sol in solutions:
    for var, val in sol.items():
        print(f"{var} = {val}")
        # Evaluate numerically if needed
        from alkahest import eval_expr
        numeric = eval_expr(val, {})

solve returns an empty list for inconsistent systems and a GroebnerBasis handle for parametric families (infinite solution sets).

Upcoming (v1.1): solve will return dict[Expr, Expr] (fully symbolic) by default. A numeric=True keyword argument will convert to float as in the current behavior.

GroebnerBasis

A GroebnerBasis can be constructed directly for ideal-theoretic operations:

from alkahest import GroebnerBasis, GbPoly

# Compute a Gröbner basis under GrLex order
polys = [x**2 + y**2 - pool.integer(1), x - y]
gb = GroebnerBasis.compute(polys, [x, y])

# Check ideal membership
print(gb.contains(x - pool.rational(1, 2)))  # False

# Reduce a polynomial modulo the ideal
reduced = gb.reduce(x**3 + y**3)

Upcoming (v1.1): The Python GroebnerBasis.compute() binding will be added (the Rust implementation is already shipped).

Monomial orders

Supported orders: Lex (lexicographic), GrLex (graded lexicographic), GRevLex (graded reverse lexicographic). GRevLex is generally fastest for basis computation; Lex is required for elimination.

Parallel F4

With --features "groebner parallel", Gröbner basis computation uses Rayon for parallel S-polynomial reduction via the F4 algorithm.

GPU-accelerated Macaulay matrix (groebner-cuda)

With --features "groebner-cuda", the mod-p row reduction of the Macaulay matrix is offloaded to CUDA. Multi-prime CRT lifts reconstruct rational coefficients. Falls back to pure-Rust when no CUDA device is present.

Elimination ideals

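GroebnerBasis.eliminate computes the elimination ideal by keeping only the basis elements that do not involve the specified variables. This relies on a Lex basis in which the eliminated variables are ordered first: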

# Eliminate y to get a univariate ideal in x
x_ideal = gb.eliminate([y])

This is the algebraic geometry operation underlying implicitization of parametric curves and surfaces.
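As a worked instance, take the circle-and-line ideal used earlier and eliminate y by hand. The computation assumes a Lex order with y > x and is written out here for illustration; it is not output captured from Alkahest.

```latex
I = \langle x^2 + y^2 - 1,\; x - y \rangle, \qquad
G_{\mathrm{lex},\, y > x} = \{\, y - x,\; 2x^2 - 1 \,\}, \qquad
I \cap \mathbb{Q}[x] = \langle 2x^2 - 1 \rangle .
```

Reducing x^2 + y^2 - 1 modulo y - x replaces y by x and leaves 2x^2 - 1, whose roots x = ±√2/2 match the solutions returned by solve for the same system.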

Performance

On the solve_circle_line benchmark (2-variable quadratic system), Alkahest is approximately 40× faster than SymPy due to the FLINT-backed polynomial arithmetic and the compiled F4 core.

Upcoming (v2.0): F5 / signature-based Gröbner basis, real root isolation, primary decomposition, and other advanced algorithms.

Interoperability

Alkahest integrates with the Python numerical ecosystem at well-defined boundaries.

NumPy

Batch evaluation

numpy_eval vectorises a compiled function over NumPy arrays with zero unnecessary copies:

import numpy as np
from alkahest import ExprPool, compile_expr, numpy_eval, sin

pool = ExprPool()
x = pool.symbol("x")

f = compile_expr(sin(x) ** 2 + x, [x])
xs = np.linspace(0, 2 * np.pi, 1_000_000)
ys = numpy_eval(f, xs)   # returns a NumPy array, shape (1_000_000,)

Inputs are converted to f64 arrays via DLPack or __array__. The call is vectorised through call_batch_raw in Rust — no Python loop.

Array protocol

CompiledFn objects implement __array__ for direct NumPy coercion:

result = np.asarray(f([1.0]))  # scalar result as a 0-d array

PyTorch

PyTorch CPU tensors are accepted wherever NumPy arrays are (via __dlpack__):

import torch
xs = torch.linspace(0, 1, 10_000)
ys = numpy_eval(f, xs)   # returns a NumPy array

For GPU tensors, use the compile_cuda path (requires --features cuda), which accepts device pointers via call_device_ptrs.

JAX

numpy_eval with JAX arrays

JAX arrays implement __dlpack__ and are accepted by numpy_eval:

import jax.numpy as jnp
xs = jnp.linspace(0, 1, 10_000)
ys = numpy_eval(f, xs)

JAX primitive source (to_jax)

to_jax registers a symbolic expression as a JAX primitive, making it callable inside JAX computations including jax.jit, jax.grad, and jax.vmap:

from alkahest import to_jax, ExprPool, sin

pool = ExprPool()
x = pool.symbol("x")

jax_fn = to_jax(sin(x) ** 2, [x])

import jax
import jax.numpy as jnp

# Use inside jax.jit / jax.grad
jit_fn = jax.jit(jax_fn)
grad_fn = jax.grad(lambda x: jax_fn(x).sum())

The primitive registers:

  • A concrete def_impl that calls the Rust evaluator
  • An abstract evaluation rule for shape/dtype propagation
  • A JVP (forward-mode) rule derived from the symbolic gradient
  • A vmap batching rule

StableHLO / XLA

to_stablehlo emits textual MLIR in the StableHLO dialect, which XLA and JAX’s XLA backend can compile:

from alkahest import to_stablehlo

mlir_text = to_stablehlo(expr, [x, y], fn_name="my_kernel")
# Pass to xla_client.compile() or save to .mlir file

SymPy interop

Alkahest does not import SymPy at runtime. The integration is one-way for validation: the test oracle in tests/test_oracle.py uses SymPy as a ground truth reference. The recommended pattern for mixed workflows is to convert to/from string representation.

DLPack

All DLPack-compatible arrays (NumPy, PyTorch, JAX, CuPy) are accepted at the numpy_eval and call_device_ptrs boundaries. The DLPack conversion is zero-copy for CPU arrays with matching dtypes.

Exporting C code

emit_c generates a standalone C function for embedding in other projects:

from alkahest import emit_c

c_code = emit_c(
    sin(x) * exp(pool.integer(-1) * x),
    [x],
    var_name="x",
    fn_name="damped_sin",
)
print(c_code)
# double damped_sin(double x) { return sin(x) * exp(-x); }

The emitted code uses only standard <math.h> functions and has no Alkahest dependency.

Derivation logs

Every transformation in Alkahest returns a DerivedResult that records the exact sequence of rewrite steps applied. This log is the foundation for both human inspection and Lean proof export.

DerivedResult

DerivedResult is the return type of diff, simplify, integrate, and all top-level operations:

from alkahest import ExprPool, diff, sin

pool = ExprPool()
x = pool.symbol("x")

dr = diff(sin(x**2), x)

Attributes

Attribute     Type         Description
.value        Expr         The result expression
.steps        list[dict]   Ordered list of rewrite steps
.certificate  str | None   Lean 4 proof term, if exported
.assumptions  list         Side conditions that were verified
.warnings     list[str]    Non-fatal issues (e.g. branch cut warning)

Rewrite steps

Each step in .steps is a dict with:

Key             Value
rule            Rule name (string)
before          Expression before the rewrite
after           Expression after the rewrite
subst           Variable substitution, if any
side_condition  Side condition that was checked

for step in dr.steps:
    print(f"  {step['rule']:25s}  {step['before']}  →  {step['after']}")
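To show how a log with these keys might come about, here is a toy differentiator that appends one step per rule application. It is a plain-Python sketch, not Alkahest code: expressions are nested tuples, diff_logged is a hypothetical function, and of the rule names only diff_sin appears in this documentation (diff_var, diff_const, and diff_pow are made up for the example).

```python
def diff_logged(expr, var, steps):
    """Differentiate a tiny tuple-encoded expression, logging each rule applied.

    Expressions: a variable name (str), a number, ("pow", base, n) with a
    numeric exponent n, or ("sin", arg).
    """
    if expr == var:
        result, rule = 1, "diff_var"
    elif isinstance(expr, (int, float, str)):
        result, rule = 0, "diff_const"   # numbers and other symbols are constants
    elif expr[0] == "pow":
        _, base, n = expr
        inner = diff_logged(base, var, steps)
        result, rule = ("mul", ("mul", n, ("pow", base, n - 1)), inner), "diff_pow"
    elif expr[0] == "sin":
        inner = diff_logged(expr[1], var, steps)
        result, rule = ("mul", ("cos", expr[1]), inner), "diff_sin"
    else:
        raise ValueError(f"unsupported expression: {expr!r}")
    steps.append({"rule": rule, "before": expr, "after": result})
    return result


steps = []
value = diff_logged(("sin", ("pow", "x", 2)), "x", steps)
for step in steps:
    print(f"{step['rule']:12s} {step['before']} -> {step['after']}")
print([s["rule"] for s in steps])   # ['diff_var', 'diff_pow', 'diff_sin']
```

Because each step is appended after its sub-derivatives return, the log in this sketch lists inner rewrites before the outer ones that depend on them.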

Side conditions

A side condition is a predicate that must hold for a rewrite to be sound:

  • Positive(x): x must be positive (e.g. for sqrt(x²) → x)
  • NonZero(x): x must be non-zero (e.g. for x/x → 1)
  • Integer(n): n must be an integer (e.g. for some power rules)
  • BranchCut(f, x): records that f may have a branch cut at x

Side conditions propagate into the derivation log as SideCondition entries. When a side condition is not provable from the symbol’s domain, the step is still recorded but flagged. Lean export only produces a verifiable proof for steps where all side conditions are proved.

Inspecting a derivation

dr = diff(sin(x**2), x)

print(f"Result: {dr.value}")
print(f"Steps ({len(dr.steps)}):")
for step in dr.steps[:5]:
    rule = step['rule']
    before = step['before']
    after = step['after']
    print(f"  [{rule}]: {before} → {after}")
    if step.get('side_condition'):
        print(f"    side_condition: {step['side_condition']}")

DerivationLog overhead

Logging is always on and is cheap — a Vec<RewriteStep> appended to during traversal. The benchmark group log_overhead in alkahest-core/benches/alkahest_bench.rs measures logging cost separately from computation.

For production workloads where you only need .value, the steps list is still populated but you can ignore it. There is no way to disable logging in the current API (disabling it would compromise the Lean certificate pipeline).

Combining logs

When you chain operations, the logs are separate:

simplified = simplify(expr)
derived = diff(simplified.value, x)

# Full derivation: simplify steps first, then diff steps
all_steps = simplified.steps + derived.steps

For operations like integrate that internally call simplify, the log includes the simplification sub-steps interleaved with the integration steps.

Lean certificates

Alkahest can export machine-checkable Lean 4 proofs for a subset of computations.

Three levels of evidence

Derivation logs — always on, always cheap. Records every rewrite rule applied, with rule name and arguments. Human-readable; machine-parseable; forms the basis for Lean export.

Lean certificate export — for computations expressible as sequences of rewrites tagged with Lean theorem names. The library emits a .lean file containing a proof term. Lean checks it independently.

Algorithmic certificates — for operations where rewrite sequences do not work (polynomial factoring, integration by the Risch algorithm). The library emits a verifiable witness instead. For factoring, the witness is the claimed factorization, which Lean verifies by multiplying out.

Theorem mapping

Every primitive in the registry is tagged with a Lean 4 / Mathlib theorem name:

Primitive rule  Mathlib theorem
diff_sin        Real.hasDerivAt_sin
diff_exp        Real.hasDerivAt_exp
diff_log        Real.hasDerivAt_log
diff_chain      HasDerivAt.comp
diff_add        HasDerivAt.add
diff_mul        HasDerivAt.mul
add_zero        add_zero
mul_one         mul_one

The full mapping lives in alkahest-core/src/lean/.

Exporting a certificate

from alkahest import ExprPool, diff, sin

pool = ExprPool()
x = pool.symbol("x", "real")

dr = diff(sin(x**2), x)

# The certificate is in dr.certificate when Lean export is enabled
if dr.certificate:
    with open("proof.lean", "w") as f:
        f.write(dr.certificate)

The emitted .lean file imports Mathlib and contains a proof term that Lean can verify:

import Mathlib.Analysis.SpecialFunctions.Trigonometric.Deriv

-- Alkahest certificate: d/dx sin(x²) = 2*x*cos(x²)
theorem alkahest_diff_sin_sq (x : ℝ) :
    HasDerivAt (fun x => Real.sin (x ^ 2)) (2 * x * Real.cos (x ^ 2)) x := by
  have h1 : HasDerivAt (fun x => x ^ 2) (2 * x) x := ...
  exact (Real.hasDerivAt_sin _).comp x h1

Lean CI

The CI pipeline (.github/workflows/lean.yml) runs on every change to lean/-tagged files:

  1. Generates proof files via tests/lean_corpus.py
  2. Compiles them with the lean compiler (with Mathlib cached)
  3. Fails the build if any proof does not typecheck

Coverage

Lean export works for computations that decompose into the tagged primitive set. Operations that currently produce certificates:

  • Polynomial differentiation (all degrees)
  • Trigonometric differentiation (sin, cos, tan, and chain compositions)
  • Exponential and logarithm differentiation
  • Basic arithmetic rewrites (add_zero, mul_one, mul_comm, etc.)

Operations that use algorithmic certificates (witness-based):

  • Polynomial factoring — the claimed factorization is verified by ring_nf in Lean
  • Polynomial GCD — verified by showing gcd divides both inputs and is a linear combination

Operations without Lean certificates (Lean theorem not yet mapped, or algorithm not expressible as rewrites):

  • Integration by the Risch algorithm
  • E-graph extraction steps

Upcoming (v2.0): Deeper Mathlib coverage including limits (Filter.Tendsto), series (HasSum), and real algebraic geometry (Polynomial.roots).

Side conditions in proofs

Side conditions (domain constraints, branch cut restrictions) are propagated into the emitted Lean proof as hypotheses. A certificate that depends on x > 0 will include hx : 0 < x as an assumption in the proof term. This makes the trust boundary explicit: the certificate is only verifiable when the side conditions hold.
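As an illustration of a side condition surfacing as a hypothesis, the rewrite sqrt(x²) → x can be stated in Lean 4 with Mathlib as follows. This is a hand-written sketch of the general shape, not a certificate emitted by Alkahest; the blanket import Mathlib is used only to avoid guessing a specific module path.

```lean
import Mathlib

-- The side condition Positive(x) appears as the hypothesis `hx`.
-- Without it the statement is false (take x = -1).
example (x : ℝ) (hx : 0 < x) : Real.sqrt (x ^ 2) = x :=
  Real.sqrt_sq hx.le
```

Anyone reusing such a statement has to supply a proof of 0 < x, which is where the trust boundary becomes visible.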

Error handling

Alkahest uses a structured exception hierarchy. Every error carries a stable diagnostic code, a human-readable message, an optional source span, and an optional remediation hint.

Exception hierarchy

AlkahestError (base)
├── ConversionError   (E-POLY-*)   — expression → polynomial/rational conversion
├── DomainError       (E-DOMAIN-*) — mathematical side conditions violated
├── DiffError         (E-DIFF-*)   — differentiation failed
├── IntegrationError  (E-INT-*)    — integration failed
├── MatrixError       (E-MAT-*)    — linear algebra errors
├── OdeError          (E-ODE-*)    — ODE construction or lowering
├── DaeError          (E-DAE-*)    — DAE structural analysis
├── SolverError       (E-SOLVE-*)  — polynomial system solving
├── JitError          (E-JIT-*)    — LLVM/JIT codegen
├── CudaError         (E-CUDA-*)   — CUDA kernel launch or driver
└── PoolError         (E-POOL-*)   — ExprPool misuse

Error attributes

Every exception instance exposes:

Attribute     Type                     Description
.code         str                      Stable diagnostic code, e.g. "E-POLY-001"
.message      str                      Human-readable description
.remediation  str | None               What the user should try
.span         tuple[int, int] | None   Character offset range in source expression

import alkahest
from alkahest import ExprPool, UniPoly, ConversionError

pool = ExprPool()
x = pool.symbol("x")

try:
    # sin(x) cannot be represented as a polynomial
    p = UniPoly.from_symbolic(alkahest.sin(x), x)
except ConversionError as e:
    print(e.code)          # E-POLY-001
    print(e.message)       # "expression contains non-polynomial term: sin(x)"
    print(e.remediation)   # "Use Expr directly, or expand sin(x) as a series first"

Common errors and remediations

ConversionError (E-POLY-*)

Raised when an expression cannot be converted to a polynomial or rational function.

Code        Cause                                     Remediation
E-POLY-001  Non-polynomial term (e.g. sin)            Use Expr directly; or expand as series
E-POLY-002  Non-integer exponent                      Algebraic extension not yet supported
E-POLY-003  Symbolic exponent (variable in exponent)  Use Expr.pow, not UniPoly

DomainError (E-DOMAIN-*)

Raised when a mathematical side condition is violated.

Code          Cause                    Remediation
E-DOMAIN-001  Division by zero         Check denominator before dividing
E-DOMAIN-002  log(0) or log(negative)  Ensure argument is positive; use complex domain if needed
E-DOMAIN-003  sqrt(negative)           Use AcbBall or declare complex domain

IntegrationError (E-INT-*)

Code       Cause                                          Remediation
E-INT-001  No integration rule matches                    Result may not have an elementary antiderivative
E-INT-002  Algebraic extension required                   Planned for v1.1 (algebraic Risch)
E-INT-003  Risch gave up (transcendental tower too deep)  Try numerical integration

SolverError (E-SOLVE-*)

Code         Cause                                Remediation
E-SOLVE-001  System is inconsistent               No solutions exist
E-SOLVE-002  High-degree univariate factor (> 2)  Symbolic solution not supported; use numerical solve
E-SOLVE-003  Gröbner basis did not terminate      Increase node/iteration limits

Catching errors by code

For programmatic error handling:

try:
    result = alkahest.integrate(expr, x)
except alkahest.AlkahestError as e:
    if e.code.startswith("E-INT-"):
        print(f"Integration failed: {e.remediation}")
    else:
        raise
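The two matching styles, by exception class and by code prefix, can be tried without Alkahest installed. The classes below are local stand-ins that mimic the documented attributes (.code, .message, .remediation, .span); they are assumptions for this sketch and are not imported from the library.

```python
class AlkahestError(Exception):
    """Local stand-in for the documented base class (not the real one)."""

    def __init__(self, code, message, remediation=None, span=None):
        super().__init__(f"[{code}] {message}")
        self.code = code
        self.message = message
        self.remediation = remediation
        self.span = span


class IntegrationError(AlkahestError):
    pass


class DomainError(AlkahestError):
    pass


def handle(exc):
    """Dispatch on the exception class first, then fall back to the code."""
    try:
        raise exc
    except IntegrationError as e:
        return f"integration: {e.remediation}"
    except AlkahestError as e:
        subsystem = e.code.split("-")[1]   # "E-DOMAIN-001" -> "DOMAIN"
        return f"other ({subsystem}): {e.message}"


print(handle(IntegrationError("E-INT-003", "Risch gave up", "Try numerical integration")))
# integration: Try numerical integration
print(handle(DomainError("E-DOMAIN-001", "Division by zero")))
# other (DOMAIN): Division by zero
```

Catching the subclass is usually the clearer choice; matching on the code string is useful when one handler needs to cover several subsystems or log the code verbatim.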

Error taxonomy

Every error is classified on two independent axes: subsystem (determines the code prefix and exception class) and cause (informs the remediation hint).

Subsystem axis

Prefix      Class                  Scope
E-POLY-*    ConversionError        Expression → polynomial/rational-function conversion
E-DOMAIN-*  DomainError            Side-condition violations (div-by-zero, log of 0, sqrt of negative)
E-DIFF-*    DiffError              Forward/reverse differentiation, unknown derivatives
E-INT-*     IntegrationError       Symbolic integration (Risch, heuristic, table)
E-MAT-*     MatrixError            Linear algebra (shape, singular, non-invertible)
E-ODE-*     OdeError               ODE construction, lowering, event handling
E-DAE-*     DaeError               DAE structural analysis (Pantelides, index reduction)
E-SOLVE-*   SolverError            Polynomial system solving, Gröbner basis
E-JIT-*     JitError               LLVM/Cranelift codegen and linking
E-CUDA-*    CudaError              NVPTX compile, kernel launch, driver/runtime failures
E-POOL-*    PoolError              ExprPool misuse (closed, cross-pool, persisted-handle mismatch)
E-PARSE-*   ParseError (reserved)  Parser integration; owns span() by default
E-IO-*      IoError (reserved)     Checkpoint/serde paths (PoolPersistError)

Cause axis

  1. User-input — the expression or argument is outside the supported fragment. Always has a remediation; carries a span once parsing lands.
  2. Domain — input is syntactically fine but violates a mathematical side condition. Remediation is “substitute a different value,” not “reformulate.”
  3. Unsupported — the operation is not implemented for this case. Must name the missing capability so users can file a feature request.
  4. Resource/environment — CUDA device absent, out-of-memory, JIT target mismatch, pool closed. Typically no span; remediation references the environment, not the expression.
  5. Internal invariant — a bug. Should never reach users in release; in debug it carries a backtrace. Use E-INTERNAL-001.

Adding a new error code

  1. Does it fit an existing subsystem? Add a variant and a code one higher than the current max for that prefix.
  2. Does it name a new subsystem? Add a prefix, a class, and an entry in REGISTRY in the same PR. Do not reuse prefixes across unrelated subsystems.
  3. Write the remediation before the message — if you cannot say what the user should do, the taxonomy is telling you this is an internal bug, not a user error.

Users match on subsystem (the exception class); triagers filter on cause (the code suffix and remediation text).

Stability policy

Alkahest follows semantic versioning starting at 1.0.

Stable surface

The stable surface is the API Alkahest commits to maintaining without breaking changes across a major version:

  • Rust: everything re-exported from alkahest_core::stable
  • Python: every name in alkahest.__all__ at release time

Breaking changes to the stable surface require a major-version bump (e.g. 1.x → 2.0).

Experimental surface

  • Rust: alkahest_core::experimental::*, plus anything not in stable
  • Python: alkahest.experimental.*, plus anything re-exported from the native module but not in __all__

Experimental APIs may change in any minor release. Pin a specific point release if you depend on them.

Deprecation policy

Stable symbols slated for removal are first kept as #[deprecated] shims for one full major cycle:

  1. Symbol is deprecated in 1.x with #[deprecated(since = "1.x", note = "use Y instead")]
  2. Symbol is removed in 2.0

Python deprecations emit DeprecationWarning from the point of deprecation.
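On the Python side, a deprecation shim typically looks like the following. This is a generic sketch using the standard warnings module; old_name and new_name are placeholder functions, not Alkahest symbols.

```python
import warnings


def new_name(x):
    return x * 2


def old_name(x):
    """Deprecated shim: warns, then forwards to the replacement."""
    warnings.warn(
        "old_name is deprecated; use new_name instead",
        DeprecationWarning,
        stacklevel=2,   # attribute the warning to the caller, not the shim
    )
    return new_name(x)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = old_name(21)

print(result)                        # 42
print(caught[0].category.__name__)   # DeprecationWarning
```

Python hides DeprecationWarning by default outside of __main__ and test runners, so running a test suite with -W error::DeprecationWarning is a practical way to find uses of deprecated names before a major upgrade.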

Enforcement

  • cargo semver-checks — runs on every PR via .github/workflows/alkahest-semver-check.yml. Fails the PR if any stable Rust API breaks.
  • scripts/check_api_freeze.py — guards against removals from alkahest.__all__ within a major cycle.
  • CHANGELOG.md — Keep-a-Changelog format; every release documents additions, deprecations, and (in major bumps) removals.

Error codes

Diagnostic error codes (e.g. E-POLY-001) are also stable. A code introduced in 1.x will not be renumbered or removed until 2.0. New codes are added by incrementing within the existing prefix.

Diagnostic codes and their stability

Error codes are part of the stable surface from the version they first appear. See Error handling for the current code table.