Alkahest
Alkahest is a high-performance computer algebra system with a Rust core and Python API.
What it is
A general-purpose symbolic math library designed around three axes:
Performance. The Rust kernel uses hash-consed directed acyclic graphs so structural equality is a pointer comparison and subexpression sharing is automatic. FLINT backs polynomial arithmetic. An LLVM JIT compiles symbolic expressions to native or GPU code at runtime. Common operations run orders of magnitude faster than SymPy.
Correctness. Every simplification and transformation produces a derivation log — the exact sequence of rewrite rules applied, with arguments and side conditions. A subset of operations can export Lean 4 proof terms verifiable by an independent checker.
Ergonomics. The Python API uses operator overloading for natural expression construction. Results are rich objects with .value, .steps, and .certificate attributes. Error messages carry structured codes, location information, and suggested remediations.
Design principles
Explicit representations. The type system distinguishes UniPoly (FLINT-backed univariate polynomial), MultiPoly (sparse multivariate), RationalFunction, and the generic Expr tree. Converting between them is an explicit call. There are no silent representation changes hiding performance cliffs.
Stateless by design. No global assumption contexts. No hidden caches that change behavior. All context (domains, simplification policy, precision) is passed explicitly or bundled into expression structure. This makes results deterministic and parallelism safe.
Composable transformations. trace, grad, jit, and certify operate on a shared traced representation and stack freely: jit(grad(f)) compiles a derivative, jit(grad(grad(f))) compiles a second derivative.
A small primitive set. Each primitive (sin, exp, add, mul, …) registers a full bundle: simplification rule, forward- and reverse-mode differentiation, MLIR lowering, Lean theorem tag, numerical evaluation. New operations are added by registering a primitive, not by adding code paths across the system.
Compared to alternatives
| | SymPy | SageMath | Symbolics.jl | Alkahest |
|---|---|---|---|---|
| Performance | Slow | Moderate | Fast | Fast |
| GPU codegen | No | No | No | Yes |
| Lean proofs | No | No | No | Yes |
| Python API | Yes | Yes | No (Julia) | Yes |
| Open source | Yes | Yes | Yes | Yes |
This guide
The guide covers the Rust-level design concepts. For the Python API reference see the Sphinx docs alongside this guide.
Getting started
Install
Alkahest is built with maturin.
pip install maturin
git clone https://github.com/alkahest/alkahest
cd alkahest
maturin develop --release
To enable optional features:
# LLVM JIT for native compiled evaluation
maturin develop --release --features jit
# E-graph simplification (egglog)
maturin develop --release --features egraph
# Parallel simplification (sharded ExprPool)
maturin develop --release --features parallel
# Gröbner basis solver
maturin develop --release --features groebner
# CUDA / NVPTX codegen (requires CUDA toolkit and LLVM with NVPTX target)
maturin develop --release --features cuda
# Full build (all optional features except cuda/rocm)
maturin develop --release --features "jit egraph parallel groebner"
First steps
Every computation starts with an ExprPool. It owns all expressions; you create symbols and integers from it.
import alkahest
from alkahest import ExprPool, diff, simplify, integrate, sin, exp, cos
pool = ExprPool()
x = pool.symbol("x")
y = pool.symbol("y")
Building expressions
Python operators build expression trees:
expr = x**2 + pool.integer(2) * x + pool.integer(1)
print(expr) # x^2 + 2*x + 1
Math functions accept expressions:
f = sin(x**2) + exp(x * y)
Parsing expressions from strings
Use parse when the expression comes from user input or a config file:
from alkahest import parse
e = parse("x^2 + 2*x + 1", pool, {"x": x})
print(e) # x^2 + 2*x + 1
Identifiers not in the symbols dict are auto-created as symbols in pool.
Both ^ and ** denote exponentiation. See Parsing expressions from strings for the full syntax reference.
Simplification
r = simplify(x + pool.integer(0))
print(r.value) # x
print(r.steps) # [RewriteStep(rule='add_zero', ...)]
Differentiation
dr = diff(sin(x**2), x)
print(dr.value) # 2*x*cos(x^2)
Integration
r = integrate(exp(x), x)
print(r.value) # exp(x)
r = integrate(sin(x), x)
print(r.value) # -cos(x)
Polynomial arithmetic
from alkahest import UniPoly, RationalFunction
# Convert to FLINT-backed univariate polynomial
p = UniPoly.from_symbolic(x**3 + pool.integer(-1), x)
q = UniPoly.from_symbolic(x + pool.integer(-1), x)
print(p.gcd(q)) # x - 1
print(p // q) # x^2 + x + 1
Compiled evaluation
from alkahest import compile_expr, eval_expr
# Scalar evaluation via a dict binding
result = eval_expr(x**2 + y, {x: 3.0, y: 1.0})
print(result) # 10.0
# JIT-compiled callable
f = compile_expr(x**2 + pool.integer(1), [x])
print(f([3.0])) # 10.0
Vectorized evaluation over NumPy arrays
import numpy as np
from alkahest import compile_expr, numpy_eval
f = compile_expr(sin(x) * exp(pool.integer(-1) * x), [x])
xs = np.linspace(0, 10, 1_000_000)
ys = numpy_eval(f, xs) # vectorised; much faster than a Python loop
Context manager
with alkahest.context(pool=pool, simplify=True):
    z = alkahest.symbol("z") # uses the active pool
    expr = z**2 + alkahest.sin(z)
Running the examples
The examples/ directory has runnable end-to-end scripts:
PYTHONPATH=python python examples/calculus.py
PYTHONPATH=python python examples/polynomials.py
PYTHONPATH=python python examples/jit_eval.py
PYTHONPATH=python python examples/ball_arithmetic.py
PYTHONPATH=python python examples/ode_modeling.py
Kernel design
The expression kernel is the foundation everything else builds on. It lives in alkahest-core/src/kernel/.
Hash-consed DAG
Every expression is represented as a directed acyclic graph stored in an ExprPool. Nodes are interned: before inserting a new node, the pool checks whether a structurally identical node already exists. If it does, the existing ExprId is returned instead of allocating a new node.
This gives three properties:
- Structural equality is a pointer comparison. id_a == id_b iff the expressions are structurally identical. No tree traversal required.
- Automatic subexpression sharing. If sin(x²) appears in ten different expressions, there is only one sin(x²) node in memory.
- Hash-based memoization is cheap. Caching the result of a transformation keyed by ExprId is O(1) and correct.
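The interning idea can be sketched in a few lines of plain Python. This is a toy model, not the Rust ExprPool: the class name ToyPool, its fields, and the tuple encoding of nodes are all assumptions made for illustration.

```python
from dataclasses import dataclass, field


@dataclass
class ToyPool:
    """Illustrative intern table (hypothetical; not the Alkahest ExprPool)."""
    nodes: list = field(default_factory=list)   # arena: id -> node
    index: dict = field(default_factory=dict)   # node -> id

    def intern(self, op, *children):
        node = (op, children)        # children are ids, or a payload for leaves
        hit = self.index.get(node)
        if hit is not None:          # a structurally identical node exists
            return hit
        self.nodes.append(node)      # miss: allocate and record the new node
        self.index[node] = len(self.nodes) - 1
        return self.index[node]


pool = ToyPool()
x = pool.intern("sym", "x")
two = pool.intern("int", 2)
a = pool.intern("sin", pool.intern("pow", x, two))
size_before = len(pool.nodes)
b = pool.intern("sin", pool.intern("pow", x, two))   # rebuild the same tree

print(a == b)                          # True: equality is an id comparison
print(len(pool.nodes) == size_before)  # True: rebuilding did not grow the pool
```

Because every node is looked up before it is stored, two structurally identical trees always end at the same id, which is what makes the id usable as a memoization key.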
ExprPool
ExprPool is the intern table. It owns all expressions in a session.
pool = ExprPool()
x = pool.symbol("x") # intern a Symbol node
n = pool.integer(42) # intern an Integer node
Multiple pools are independent. An ExprId from one pool must not be mixed into another — the pool validates this in debug builds.
Persistent pool (V1-14). A pool can be serialized to disk and reopened, preserving all ExprIds across sessions:
pool.save_to("session.alkp")
pool2 = ExprPool.load_from("session.alkp")
Sharded pool. With --features parallel, the intern table uses a sharded concurrent hashmap (DashMap), so multiple threads can insert expressions with little lock contention.
ExprData variants
Each interned node is one of:
| Variant | Description |
|---|---|
| Symbol(name, domain) | Named variable with a domain annotation |
| Integer(n) | Exact arbitrary-precision integer |
| Rational(p, q) | Exact rational number |
| Add(children) | N-ary addition |
| Mul(children) | N-ary multiplication |
| Pow(base, exp) | Exponentiation |
| Call(primitive, args) | Application of a registered primitive |
| Piecewise(cases) | Conditional expression |
| Predicate(kind, args) | Boolean condition (inequality, equality) |
Add and Mul are n-ary: a + b + c is one Add node with three children, not two nested Add nodes. Children are sorted at construction time so that commutativity is structural — a + b and b + a produce the same interned node.
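A rough sketch of why flattening and sorting make commutativity structural is shown here with tuples standing in for nodes. The helper make_add and the use of repr as the sort key are assumptions for illustration; the text does not say which ordering the kernel actually uses.

```python
def make_add(*terms):
    """Build a flattened, sorted n-ary Add node from toy terms (hypothetical helper)."""
    flat = []
    for t in terms:
        if isinstance(t, tuple) and t[0] == "add":
            flat.extend(t[1])        # splice the children of a nested Add
        else:
            flat.append(t)
    # Any total order works; repr is only a convenient stand-in here.
    return ("add", tuple(sorted(flat, key=repr)))


ab = make_add("a", "b")
ba = make_add("b", "a")
abc = make_add(make_add("a", "b"), "c")

print(ab == ba)   # True: both orders yield the same node
print(abc)        # ('add', ('a', 'b', 'c')): one node, three children
```

Once the constructor normalizes order and nesting, an intern table sees a + b and b + a as the same key, so no separate commutativity rule is needed.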
Domains
Every symbol carries a domain as part of its structural identity:
x_real = pool.symbol("x", "real")
x_complex = pool.symbol("x", "complex")
# x_real and x_complex are distinct expressions — different ExprIds
The domain is not a global assumption; it is part of what the symbol is. Simplification rules can query a symbol’s domain to decide whether a rewrite is valid (e.g. sqrt(x²) → x requires x to be non-negative).
Available domains: real, positive, nonnegative, integer, complex. The default when no domain is specified is real.
ExprId and memory
ExprId is a 32-bit index into the pool’s internal arena. It is Copy, Send, and Sync. Cloning an ExprId is free. No reference counting is needed because the pool owns all nodes; expressions are not freed until the pool is dropped.
The kernel is designed with parallelism as a first-class property. All kernel types are Send + Sync. The simplification and differentiation passes can run concurrently on disjoint ExprIds from the same pool.
Interning cost model
Interning a new node requires:
1. Hash the ExprData.
2. Look up in the concurrent hash map.
3. On miss: allocate the node in the arena and insert into the map.
4. On hit: return the existing ExprId.
Step 4 (the common case in a running computation) is a single hash lookup plus a pointer load. The arena uses bump allocation, so step 3 is also fast.
The memory benchmark group in alkahest-core/benches/alkahest_bench.rs verifies that rebuilding an identical expression tree does not grow the pool.
Expression representations
Alkahest exposes multiple representation types rather than hiding everything behind a single Expr. This is a deliberate design decision: the representation is visible, conversions are explicit, and performance characteristics are predictable.
Choosing a representation
| If you need… | Use |
|---|---|
| General symbolic computation | Expr |
| Fast univariate polynomial arithmetic | UniPoly |
| Sparse multivariate polynomial algebra | MultiPoly |
| Rational functions with automatic cancellation | RationalFunction |
| Rigorous enclosures with error bounds | ArbBall |
Conversion to a specialized type is always an explicit opt-in:
expr = x**3 + pool.integer(-2) * x + pool.integer(1)
p = UniPoly.from_symbolic(expr, x) # explicit conversion
If the expression cannot be represented in the target type (e.g. sin(x) as a polynomial), a ConversionError is raised with a remediation hint.
Expr
The generic symbolic expression. All other types convert to and from Expr. Built by operator overloading on the Python side:
expr = x**2 + pool.integer(3) * x * y - pool.integer(1)
Operations like diff, simplify, and integrate work on Expr and return DerivedResult objects wrapping an Expr.
UniPoly
Dense univariate polynomial backed by FLINT. Coefficients are exact integers or rationals stored in a FLINT polynomial object.
from alkahest import UniPoly
# x^3 - 2x + 1
p = UniPoly.from_symbolic(x**3 + pool.integer(-2) * x + pool.integer(1), x)
print(p.degree()) # 3
print(p.coefficients()) # [1, -2, 0, 1] (constant first)
print(p.leading_coeff()) # 1
# Arithmetic — all FLINT-backed, exact
q = UniPoly.from_symbolic(x + pool.integer(-1), x)
print(p * q) # x^4 - x^3 - 2x^2 + 3x - 1
print(p.gcd(q)) # x - 1
print(p // q) # x^2 + x - 1
print(p % q) # 0
# Powers
r = UniPoly.from_symbolic(x + pool.integer(1), x)
print(r ** 3) # x^3 + 3x^2 + 3x + 1
UniPoly is the right choice when you are doing heavy univariate polynomial arithmetic (GCD chains, resultants, factorization) because FLINT applies highly optimized algorithms with exact arithmetic.
MultiPoly
Sparse multivariate polynomial over ℤ (integers). Terms are stored as a map from exponent vectors to coefficients.
from alkahest import MultiPoly
# x^2*y + x*y^2 - 1
expr = x**2 * y + x * y**2 + pool.integer(-1)
mp = MultiPoly.from_symbolic(expr, [x, y])
print(mp.total_degree()) # 3
print(mp.integer_content()) # 1
# Arithmetic
mp2 = MultiPoly.from_symbolic(x * y, [x, y])
print(mp + mp2) # x^2*y + x*y^2 + x*y - 1
print(mp * mp2) # x^3*y^2 + x^2*y^3 - x*y
Variable order matters for the exponent-vector key. Pass variables in a consistent order when constructing MultiPoly objects that will be combined.
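To make the exponent-vector representation concrete, here is a small sketch that stores a polynomial as a plain dict from exponent tuples to integer coefficients. The functions poly_add and poly_mul are hypothetical helpers, not part of the MultiPoly API.

```python
from collections import defaultdict


def poly_add(p, q):
    """Add two sparse polynomials keyed by exponent tuples."""
    out = defaultdict(int)
    for poly in (p, q):
        for exps, coeff in poly.items():
            out[exps] += coeff
    return {e: c for e, c in out.items() if c != 0}


def poly_mul(p, q):
    """Multiply two sparse polynomials: exponents add, coefficients multiply."""
    out = defaultdict(int)
    for e1, c1 in p.items():
        for e2, c2 in q.items():
            key = tuple(a + b for a, b in zip(e1, e2))
            out[key] += c1 * c2
    return {e: c for e, c in out.items() if c != 0}


# Variable order (x, y): the key (2, 1) means x^2 * y.
mp = {(2, 1): 1, (1, 2): 1, (0, 0): -1}   # x^2*y + x*y^2 - 1
mp2 = {(1, 1): 1}                          # x*y

print(poly_mul(mp, mp2))   # {(3, 2): 1, (2, 3): 1, (1, 1): -1}
```

The key (2, 1) would mean y^2 * x under the order (y, x), which is why two polynomials built with different variable orders cannot be combined term by term.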
RationalFunction
Quotient of two MultiPoly objects, automatically reduced by their GCD.
from alkahest import RationalFunction
# (x^2 - 1) / (x - 1) → normalized to x + 1
numer = x**2 + pool.integer(-1)
denom = x + pool.integer(-1)
rf = RationalFunction.from_symbolic(numer, denom, [x])
print(rf) # x + 1
# Arithmetic preserves the rational form
rf_x = RationalFunction.from_symbolic(x, pool.integer(1), [x])
rf_inv = RationalFunction.from_symbolic(pool.integer(1), x, [x])
print(rf_x + rf_inv) # (x^2 + 1) / x
GCD normalization runs at construction, so every RationalFunction is in lowest terms.
ArbBall
A real interval [midpoint ± radius] backed by FLINT’s Arb library. Arithmetic on ArbBall values produces guaranteed enclosures of the true result.
from alkahest import ArbBall, interval_eval, sin
# ArbBall(midpoint, radius, precision_bits=53)
a = ArbBall(2.0, 0.5) # [1.5, 2.5]
b = ArbBall(3.0, 0.0) # exactly 3
print(a + b) # [4.5, 5.5]
print(a * b) # [4.5, 7.5]
# Evaluate a symbolic expression rigorously
pool = ExprPool()
x = pool.symbol("x")
result = interval_eval(sin(x), {x: ArbBall(1.0, 1e-10)})
print(result.lo, result.hi) # tight enclosure of sin(1)
The output ball is guaranteed to contain the true value for any input in the input balls. This is useful for:
- Certified numerical evaluation
- Proving bounds on symbolic expressions
- Verification workflows alongside Lean certificate export
See Ball arithmetic for more detail.
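The midpoint-radius arithmetic behind the a + b and a * b results above can be sketched as follows. ToyBall is a hypothetical class for illustration; unlike Arb it does not round the radius outward, so it is only rigorous when the floating-point operations happen to be exact, as they are for these inputs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyBall:
    """Toy midpoint-radius interval (hypothetical; not the ArbBall type)."""
    mid: float
    rad: float

    def __add__(self, other):
        # Radii add: the worst-case errors accumulate.
        return ToyBall(self.mid + other.mid, self.rad + other.rad)

    def __mul__(self, other):
        # |xy - m1*m2| <= |m1|*r2 + |m2|*r1 + r1*r2 for x, y in the balls.
        rad = (abs(self.mid) * other.rad
               + abs(other.mid) * self.rad
               + self.rad * other.rad)
        return ToyBall(self.mid * other.mid, rad)

    def bounds(self):
        return (self.mid - self.rad, self.mid + self.rad)


a = ToyBall(2.0, 0.5)   # [1.5, 2.5]
b = ToyBall(3.0, 0.0)   # exactly 3

print((a + b).bounds())   # (4.5, 5.5)
print((a * b).bounds())   # (4.5, 7.5)
```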
Converting back to Expr
All specialized types can be converted back to a generic Expr for further symbolic manipulation:
p = UniPoly.from_symbolic(x**2 + pool.integer(1), x)
expr_again = p.to_symbolic(pool)
dr = diff(expr_again, x)
Parsing expressions from strings
alkahest.parse converts a human-readable math string into an Expr node
using a Pratt (top-down operator precedence) recursive-descent parser.
import alkahest
from alkahest import ExprPool, parse, diff, simplify
pool = ExprPool()
x = pool.symbol("x")
e = parse("x^2 + 2*x + 1", pool, {"x": x})
print(e) # x^2 + 2*x + 1
dr = diff(e, x)
print(dr.value) # 2*x + 2
Syntax
| Form | Meaning |
|---|---|
| 42, 3.14, 1.5e-3 | Integer or float literal |
| x, alpha, x_1 | Symbol (created in pool on first use) |
| a + b, a - b | Addition / subtraction |
| a * b, a / b | Multiplication / division |
| a ^ b, a ** b | Exponentiation (right-associative) |
| -a, +a | Unary negation / identity |
| (expr) | Grouping |
| sin(x), atan2(y, x) | Function call (one or two arguments) |
Whitespace (spaces, tabs, newlines) is ignored everywhere.
Operator precedence
From lowest to highest:
| Level | Operators |
|---|---|
| 10 | + - (infix) |
| 20 | * / |
| 25 | Unary - + |
| 30 | ^ ** (right-associative) |
So -x^2 parses as -(x^2), not (-x)^2, and x^2^3 parses as
x^(2^3) = x^8.
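The precedence table can be exercised with a minimal Pratt parser. This sketch is an assumption about how such a parser might be organized, not the Alkahest implementation: toy_parse, the tuple output format, and the tokenizer are hypothetical, and it supports only integers, identifiers, parentheses, and the operators in the table, with no error recovery.

```python
import re

TOKEN = re.compile(r"\s*(\*\*|[-+*/^()]|\d+|[A-Za-z_]\w*)")

# Binding powers copied from the precedence table; "**" is an alias for "^".
INFIX = {"+": 10, "-": 10, "*": 20, "/": 20, "^": 30, "**": 30}
RIGHT_ASSOC = {"^", "**"}
UNARY_BP = 25


def tokenize(src):
    src, pos, out = src.rstrip(), 0, []
    while pos < len(src):
        m = TOKEN.match(src, pos)
        if not m:
            raise SyntaxError(f"unexpected character {src[pos]!r} at offset {pos}")
        out.append(m.group(1))
        pos = m.end()
    return out


def toy_parse(src):
    tokens, pos = tokenize(src), 0

    def peek():
        return tokens[pos] if pos < len(tokens) else None

    def advance():
        nonlocal pos
        pos += 1
        return tokens[pos - 1]

    def expr(min_bp):
        tok = advance()
        if tok in ("+", "-"):                  # prefix operator
            operand = expr(UNARY_BP)
            left = ("neg", operand) if tok == "-" else operand
        elif tok == "(":
            left = expr(0)
            advance()                          # consume ")"
        else:
            left = int(tok) if tok.isdigit() else tok
        # Keep absorbing infix operators that bind tighter than min_bp.
        while peek() in INFIX and INFIX[peek()] > min_bp:
            op = advance()
            bp = INFIX[op]
            # Right-associative ops let an equal-power operator nest on the right.
            right = expr(bp - 1 if op in RIGHT_ASSOC else bp)
            left = ("^" if op == "**" else op, left, right)
        return left

    return expr(0)


print(toy_parse("-x^2"))    # ('neg', ('^', 'x', 2))
print(toy_parse("x^2^3"))   # ('^', 'x', ('^', 2, 3))
```

Unary minus parses its operand at level 25, so the level-30 ^ is absorbed into the operand first, which gives -(x^2).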
Supported functions
abs, acos, asin, atan, atan2, ceil, cos, cosh, erf,
erfc, exp, floor, gamma, log, round, sign, sin, sinh,
sqrt, tan, tanh
The symbols map
By default, every new identifier is interned as a fresh pool.symbol(name).
Pass a pre-built symbols dict to bind identifiers to existing Expr
objects, or to collect the symbols that were created:
# Pre-bind x to an existing symbol
x = pool.symbol("x")
e = parse("sin(x)^2 + cos(x)^2", pool, {"x": x})
# Collect auto-created symbols after parsing
sym_map: dict = {}
e = parse("a*x^2 + b*x + c", pool, sym_map)
print(sym_map.keys()) # dict_keys(['a', 'x', 'b', 'c'])
Identifiers not in the map are created and then added to the map, so the
same string name always resolves to the same Expr within a single parse
call.
Error handling
parse raises ParseError (code E-PARSE-001) on any lexical or syntax
error. The exception’s .span attribute gives the (start, end) byte range
of the offending token, and .remediation provides a hint:
from alkahest import ParseError
try:
    parse("sin(x) @ 2", pool, {"x": x})
except ParseError as e:
    print(e) # unexpected character '@' at offset 7
    print(e.span) # (7, 8)
try:
    parse("zeta(x)", pool, {"x": x})
except ParseError as e:
    print(e.remediation) # known functions: abs, acos, asin, ...
Round-trip with pretty-printing
parse is the inverse of str() for expressions built from the operators
and functions listed above:
from alkahest import latex, unicode_str
e = parse("sin(x)^2 + cos(x)^2", pool, {"x": x})
print(latex(e)) # \sin\!\left(x\right)^{2} + \cos\!\left(x\right)^{2}
print(unicode_str(e)) # sin(x)² + cos(x)²
Simplification
Alkahest provides two complementary simplification engines that operate on the same expression pool.
Rule-based simplification
simplify applies a fixed set of algebraic rewrite rules until no more apply (fixpoint). It is fast, predictable, and always terminates.
from alkahest import simplify
r = simplify(x + pool.integer(0)) # → x
r = simplify(x * pool.integer(1)) # → x
r = simplify(pool.integer(2) * pool.integer(3)) # → 6 (constant folding)
The default rule set covers:
- Identity and absorbing elements (x + 0 → x, x * 1 → x, x * 0 → 0)
- Constant folding (integer and rational arithmetic)
- Basic polynomial simplification (x + x → 2*x, x² * x → x³)
- Commutativity and associativity (normalized at construction)
Domain-specific rule sets
from alkahest import simplify_trig, simplify_log_exp, simplify_expanded
# Pythagorean identity and double-angle formulas
r = simplify_trig(sin(x)**2 + cos(x)**2) # → 1
# Logarithm and exponential cancellation (branch-cut safe)
r = simplify_log_exp(exp(log(x))) # → x (with positive domain side condition)
# Expand products and collect like terms
r = simplify_expanded((x + pool.integer(1))**3)
Customizing the rule set
from alkahest import simplify_with, make_rule
# Add a custom rule: sin²(x) → 1 - cos²(x)
my_rule = make_rule("sin_sq_to_cos", lhs=sin(x)**2, rhs=pool.integer(1) - cos(x)**2)
r = simplify_with(expr, rules=[my_rule])
Parallel simplification
from alkahest import simplify_par
# Simplify a list of expressions concurrently (requires --features parallel)
exprs = [x**i for i in range(100)]
results = simplify_par(exprs)
E-graph simplification
simplify_egraph uses equality saturation via egglog to explore many equivalent forms simultaneously before committing to the best one via a cost function.
from alkahest import simplify_egraph
# The e-graph can discover non-obvious equivalences
r = simplify_egraph(x * x - pool.integer(1)) # may factor or simplify
E-graph saturation is more powerful than rule-based simplification for some inputs, but it is slower and its running time is harder to predict for complex expressions. See E-graph saturation for configuration options.
Choosing between the two
| Criterion | simplify | simplify_egraph |
|---|---|---|
| Speed | Fast, predictable | Slower, variable |
| Completeness | Fixed rule set | Equality saturation |
| Termination | Always | Configurable limits |
| Side conditions | Respected | Respected |
| Best for | Hot paths, cleanup | Difficult equalities |
For most workflows: use simplify (or a domain-specific variant) first. Reach for simplify_egraph when you need the system to discover a non-obvious equivalence.
Collect and normalize
Two utility passes that sit between the two engines:
from alkahest import collect_like_terms, poly_normal
# 2*x + 3*x → 5*x
r = collect_like_terms(pool.integer(2) * x + pool.integer(3) * x)
# Normalize to canonical polynomial form over given variables
r = poly_normal(x**2 + pool.integer(2) * x * y + y**2, [x, y])
Rule engine
The rule engine underlies both simplify and the e-graph backend. Rules are the atomic units of algebraic knowledge.
Anatomy of a rule
A RewriteRule has:
- A name — stable string identifier (used in derivation logs and Lean certificate output)
- A LHS pattern — an expression template with pattern variables
- A RHS template — the replacement
- Optional side conditions — predicates that must hold for the rule to fire
Pattern syntax
Patterns are ordinary expressions in which some nodes act as wildcards. From the Python side, pattern variables are Expr objects whose names start with ?:
from alkahest import make_rule, match_pattern
pool = ExprPool()
x = pool.symbol("x")
# Pattern variable — matches any subexpression
pv = pool.symbol("?a")
# Rule: ?a + 0 → ?a
add_zero = make_rule("add_zero", lhs=pv + pool.integer(0), rhs=pv)
Pattern variables capture any subexpression and must bind consistently: if ?a appears twice in the LHS it must match the same expression in both positions.
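The consistent-binding requirement can be illustrated with a purely syntactic matcher over tuple-encoded expressions. The function toy_match and the tuple encoding are assumptions for this sketch; it does not perform the AC matching that the real matcher provides.

```python
def toy_match(pattern, expr, subst=None):
    """Return a substitution dict if pattern matches expr, else None (toy sketch)."""
    subst = dict(subst or {})
    if isinstance(pattern, str) and pattern.startswith("?"):
        if pattern in subst:
            # A repeated variable must match the same expression again.
            return subst if subst[pattern] == expr else None
        subst[pattern] = expr
        return subst
    if isinstance(pattern, tuple):
        if not isinstance(expr, tuple) or len(pattern) != len(expr):
            return None
        for p, e in zip(pattern, expr):
            subst = toy_match(p, e, subst)
            if subst is None:
                return None
        return subst
    return subst if pattern == expr else None


print(toy_match(("add", "?a", 0), ("add", ("sin", "x"), 0)))  # {'?a': ('sin', 'x')}
print(toy_match(("mul", "?a", "?a"), ("mul", "x", "x")))      # {'?a': 'x'}
print(toy_match(("mul", "?a", "?a"), ("mul", "x", "y")))      # None
```

A successful match on a pattern with no variables returns an empty dict, so callers should test the result with `is not None` rather than for truthiness.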
Matching
match_pattern applies a pattern to an expression and returns all match substitutions:
pattern = sin(pv)**2 + cos(pv)**2 # example pattern; pv is the ?a variable defined above
matches = match_pattern(sin(x)**2 + cos(x)**2, pattern)
for subst in matches:
    print(subst) # dict mapping pattern variable → matched expr
The matcher is associative-commutative (AC): a + b matches b + a, and a + b + c matches any ordering.
Built-in rule sets
The rule sets loaded by simplify and the domain-specific simplifiers are:
| Function | Rules |
|---|---|
| simplify | Arithmetic identities, constant folding, polynomial normalization |
| simplify_trig | Pythagorean identity, double-angle and half-angle formulas |
| simplify_log_exp | Log/exp cancellation (branch-cut safe subset) |
| simplify_expanded | Distributive expansion, like-term collection |
Defining custom rules
from alkahest import make_rule, simplify_with
pool = ExprPool()
x = pool.symbol("x")
a = pool.symbol("?a")
b = pool.symbol("?b")
# Self-cancellation rewrite: a - a → 0
self_cancel = make_rule(
    "self_cancel",
    lhs=a + pool.integer(-1) * a,
    rhs=pool.integer(0),
)
# Apply the custom rule alongside the default set
expr = x + pool.integer(-1) * x # example input: x - x
r = simplify_with(expr, rules=[self_cancel])
Custom rules are recorded in derivation logs with the name you provide. If you tag the rule with a Lean theorem name (via the Rust PrimitiveRegistry API), the corresponding step can be exported as a Lean proof term.
Rule execution model
simplify applies rules in a fixpoint loop:
- For each node in the expression (post-order traversal):
  - Try each rule in the rule set.
  - If a rule matches and its side conditions are satisfied, apply it, emit a RewriteStep, and restart the loop for the modified subtree.
- Repeat until no rule fires in a full pass.
This is an inner-outer loop strategy rather than exhaustive bottom-up application. It is fast but not complete — some sequences of rewrites require rules to be applied in a specific order. The e-graph engine removes this ordering dependency.
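The fixpoint loop can be approximated with a short sketch over tuple-encoded expressions. The names toy_simplify, rewrite_once, and the three rule functions are hypothetical. The rules inspect only the right operand, and the sketch repeats whole passes instead of restarting on the modified subtree, so it simplifies the strategy described above.

```python
def add_zero(e):
    if isinstance(e, tuple) and e[0] == "add" and e[2] == 0:
        return e[1]


def mul_one(e):
    if isinstance(e, tuple) and e[0] == "mul" and e[2] == 1:
        return e[1]


def mul_zero(e):
    if isinstance(e, tuple) and e[0] == "mul" and e[2] == 0:
        return 0


RULES = [add_zero, mul_one, mul_zero]


def rewrite_once(e, steps):
    """One post-order pass: rewrite the children first, then this node."""
    if isinstance(e, tuple):
        e = (e[0],) + tuple(rewrite_once(c, steps) for c in e[1:])
    for rule in RULES:
        out = rule(e)
        if out is not None:
            steps.append((rule.__name__, e, out))   # toy derivation log entry
            return out
    return e


def toy_simplify(e):
    steps = []
    while True:
        new = rewrite_once(e, steps)
        if new == e:            # a full pass changed nothing: fixpoint reached
            return new, steps
        e = new


value, steps = toy_simplify(("add", ("mul", "x", 1), ("mul", "y", 0)))
print(value)                     # x
print([s[0] for s in steps])     # ['mul_one', 'mul_zero', 'add_zero']
```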
Side conditions
A rule can carry a side condition checked against the matched substitution:
# sqrt(x^2) → x only when x is non-negative
sqrt_sq = make_rule(
    "sqrt_sq_nonneg",
    lhs=sqrt(a**2),
    rhs=a,
    condition="nonnegative", # checked against the domain of ?a
)
Side conditions that reference symbol domains are sound: sqrt_sq_nonneg will only fire when ?a is bound to a symbol with domain positive or nonnegative. They propagate into the derivation log as SideCondition entries and into Lean output as assumptions.
E-graph saturation
The e-graph backend exposes a fundamentally different approach to simplification: rather than applying rules one at a time in a fixed order, it builds a structure that represents many equivalent expressions simultaneously, then extracts the best one.
What is an e-graph?
An e-graph partitions expressions into equivalence classes (e-classes). When a rewrite rule fires, it does not replace the LHS — it adds the RHS to the same e-class as the LHS. At the end of saturation, an extraction step picks the cheapest representative from each e-class according to a cost function.
This eliminates the phase-ordering problem: rules can fire in any order without risk of committing to a suboptimal form. The e-graph remembers all explored forms and chooses among them at the end.
Using the e-graph
from alkahest import simplify_egraph, simplify_egraph_with
# Default configuration
r = simplify_egraph(expr)
# With explicit config
from alkahest import EgraphConfig # (Rust API; Python config dict in future)
r = simplify_egraph_with(expr, {"node_limit": 10_000, "iter_limit": 20})
Cost functions
The extraction step minimizes a cost function over e-class representatives. Four built-in cost functions:
| Name | Behavior |
|---|---|
| SizeCost | Prefers the expression with the fewest AST nodes |
| DepthCost | Prefers the shallowest expression tree |
| OpCost | Assigns per-operation costs; penalizes expensive ops |
| StabilityCost | Penalizes patterns that cause catastrophic cancellation |
StabilityCost is aware of numerical stability issues: it penalizes subtractive cancellation patterns and prefers numerically stable rearrangements.
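As a rough illustration of cost-driven extraction, the sketch below models an e-class as a plain list of equivalent tuple-encoded forms and picks the cheapest one. The functions size_cost, depth_cost, and extract are hypothetical stand-ins for the ideas behind SizeCost and DepthCost; real extraction works on the e-graph itself, with shared subterms.

```python
def size_cost(e):
    """Number of nodes in a tuple-encoded expression tree."""
    if not isinstance(e, tuple):
        return 1
    return 1 + sum(size_cost(c) for c in e[1:])


def depth_cost(e):
    """Height of the tree, counting a leaf as depth 1."""
    if not isinstance(e, tuple):
        return 1
    return 1 + max(depth_cost(c) for c in e[1:])


def extract(candidates, cost):
    """Pick the cheapest representative from a list of equivalent forms."""
    return min(candidates, key=cost)


factored = ("pow", ("add", "x", 1), 2)                        # (x + 1)^2
expanded = ("add", ("pow", "x", 2), ("mul", 2, "x"), 1)       # x^2 + 2x + 1
chain = ("mul", ("mul", ("mul", "x", "x"), "x"), "x")         # ((x*x)*x)*x
balanced = ("mul", ("mul", "x", "x"), ("mul", "x", "x"))      # (x*x)*(x*x)

print(extract([expanded, factored], size_cost) == factored)   # True: 5 nodes vs 8
print(extract([chain, balanced], depth_cost) == balanced)     # True: depth 3 vs 4
```

The two x^4 forms have the same node count, so only a depth-based cost separates them. This is why the choice of cost function changes which form is extracted.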
Configuration
The e-graph runs until saturation (no new e-class merges) or until a limit is hit:
- node_limit: maximum number of e-nodes. Once reached, saturation stops and extraction runs on the current state.
- iter_limit: maximum number of saturation rounds.
For large or complex expressions, saturation can be expensive. The rule-based simplify is often sufficient and should be preferred on hot paths.
Rule sets in the e-graph
The e-graph uses the same RewriteRule objects as the rule-based engine. By default it loads the arithmetic rules. Domain-specific rules (trig, log/exp) are kept separate to avoid e-class explosions on expressions that do not involve those operations.
Upcoming (v1.1): The default e-graph rule set will include trig identities (sin²+cos²→1) and safe log/exp cancellation out of the box, configurable via SimplifyConfig { include_trig_rules, include_log_exp_rules }.
When e-graphs help
The e-graph is especially powerful when:
- Multiple non-obvious rewrites must be combined in a specific order that is hard to predict.
- The “right” form is not syntactically similar to the input (e.g. factoring followed by cancellation).
- You want the globally cheapest form under a custom cost function, not just any simplified form.
It is less useful when:
- The expression is already in near-canonical form and only identity cleanup is needed.
- You need predictable performance on a hot path.
- The expression is large and associative-commutative, where the e-graph can grow combinatorially.
AC matching in the e-graph
The egglog backend handles associativity and commutativity structurally: Add and Mul children are sorted at pool-insertion time, so there is a single canonical ordering. The e-graph does not need to enumerate permutations.
This is more efficient than classical AC-completion but requires that the canonical ordering is established at construction, which the kernel enforces.
Calculus
Alkahest supports symbolic differentiation and integration with full derivation logging.
Differentiation
diff(expr, var) computes the symbolic derivative of expr with respect to var.
from alkahest import diff, sin, cos, exp, log
pool = ExprPool()
x = pool.symbol("x")
# Polynomial
dr = diff(x**3 + pool.integer(2) * x, x)
print(dr.value) # 3*x^2 + 2
# Chain rule
dr = diff(sin(x**2), x)
print(dr.value) # 2*x*cos(x^2)
# Product rule
dr = diff(x * exp(x), x)
print(dr.value) # exp(x) + x*exp(x)
# Logarithm
dr = diff(log(x**2 + pool.integer(1)), x)
print(dr.value) # 2*x / (x^2 + 1)
Registered primitives
Every primitive in the registry has a differentiation rule. The 23 currently registered primitives include:
sin, cos, tan, asin, acos, atan, atan2, sinh, cosh, tanh, exp, log, sqrt, abs, sign, erf, erfc, gamma, floor, ceil, round, min, max
Derivation log
The DerivedResult returned by diff records every rule application:
dr = diff(sin(x**2), x)
for step in dr.steps:
    print(f" {step['rule']:25s} {step['before']} → {step['after']}")
Forward-mode automatic differentiation
diff_forward computes the derivative using forward-mode AD (dual numbers). It produces the same result as diff but through a different computational path:
from alkahest import diff, diff_forward
sym = diff(x**3, x)
fwd = diff_forward(x**3, x)
# fwd.value == sym.value
Forward mode is useful for checking that the symbolic rules agree with dual-number evaluation.
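For readers unfamiliar with dual numbers, here is a minimal sketch of the idea. The Dual class and the derivative helper are hypothetical and unrelated to how diff_forward is implemented; only addition, multiplication, and non-negative integer powers are supported.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """Toy dual number val + eps*ε with ε² = 0 (hypothetical helper)."""
    val: float   # primal value
    eps: float   # derivative carried alongside the value

    def __add__(self, other):
        other = _lift(other)
        return Dual(self.val + other.val, self.eps + other.eps)

    __radd__ = __add__

    def __mul__(self, other):
        other = _lift(other)
        # Product rule falls out of (a + a'ε)(b + b'ε) with ε² = 0.
        return Dual(self.val * other.val,
                    self.val * other.eps + self.eps * other.val)

    __rmul__ = __mul__

    def __pow__(self, n):
        result = Dual(1.0, 0.0)
        for _ in range(n):       # repeated multiplication; n is a non-negative int
            result = result * self
        return result


def _lift(v):
    return v if isinstance(v, Dual) else Dual(float(v), 0.0)


def derivative(f, x0):
    """Evaluate f at x0 + ε and read off the ε coefficient."""
    return f(Dual(x0, 1.0)).eps


print(derivative(lambda x: x**3, 2.0))           # 12.0, matches 3*x^2 at x = 2
print(derivative(lambda x: x**3 + 2 * x, 2.0))   # 14.0, matches 3*x^2 + 2
```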
Symbolic gradient
symbolic_grad differentiates with respect to multiple variables:
from alkahest import symbolic_grad
pool = ExprPool()
x = pool.symbol("x")
y = pool.symbol("y")
expr = x**2 * y + sin(x * y)
grads = symbolic_grad(expr, [x, y])
# grads[0] = ∂/∂x = 2*x*y + y*cos(x*y)
# grads[1] = ∂/∂y = x^2 + x*cos(x*y)
For the traced-function gradient (composable with jit), see Transformations.
Integration
integrate(expr, var) computes the symbolic antiderivative of expr with respect to var.
from alkahest import integrate, sin, cos, exp
# Polynomials
r = integrate(x**3, x)
print(r.value) # x^4/4
# Known functions
r = integrate(sin(x), x)
print(r.value) # -cos(x)
r = integrate(exp(x), x)
print(r.value) # exp(x)
r = integrate(x**pool.integer(-1), x)
print(r.value) # log(x)
Integration rules
The integration engine applies rules from a table of known forms (Risch subset):
- Power rule: ∫ xⁿ dx = xⁿ⁺¹/(n+1) for integer n ≠ -1
- Logarithm: ∫ 1/x dx = log(x)
- Exponential tower: ∫ exp(a*x + b) dx, ∫ x * exp(x) dx
- Linear substitution: ∫ f(a*x + b) dx
- Trigonometric: ∫ sin(x) dx, ∫ cos(x) dx
- Standard table entries for erf, inverse trig, etc.
If no rule applies, integrate raises an IntegrationError with a remediation hint indicating what class of integrand would be needed (e.g. “algebraic extension required — see v1.1 algebraic Risch”).
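The power rule from the table above can be sketched on a dense coefficient list with exact rationals. The helpers integrate_poly and differentiate_poly are hypothetical. They cover only non-negative exponents, so the n = -1 case never arises.

```python
from fractions import Fraction


def integrate_poly(coeffs):
    """Antiderivative of sum(coeffs[n] * x**n), with integration constant 0."""
    # Power rule term by term: c * x^n  ->  c/(n+1) * x^(n+1)
    return [Fraction(0)] + [Fraction(c, n + 1) for n, c in enumerate(coeffs)]


def differentiate_poly(coeffs):
    """Derivative of sum(coeffs[n] * x**n)."""
    return [n * c for n, c in enumerate(coeffs)][1:]


x_cubed = [0, 0, 0, 1]                  # x^3, constant term first
anti = integrate_poly(x_cubed)
print(anti[4])                          # 1/4, i.e. x^4/4
print(differentiate_poly(anti) == x_cubed)   # True: differentiating recovers x^3
```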
Upcoming (v1.1): Algebraic-function Risch (Trager’s algorithm) will handle integrands involving sqrt(P(x)) and other algebraic extensions.
Verification
A common pattern is to verify an antiderivative by differentiating it back:
antideriv = integrate(expr, x).value
check = simplify(diff(antideriv, x).value)
# check.value should equal expr
Higher derivatives
Chain calls to diff:
d2 = diff(diff(sin(x), x).value, x)
print(d2.value) # -sin(x)
The derivation log of the outer diff does not include the inner steps. If you need the full trace, keep both results and concatenate their logs, for example dr1.steps + dr2.steps, where dr1 is the inner result and dr2 the outer one.
Transformations
Transformations are higher-order operations that take a function and return a new function. They compose freely and operate on traced symbolic representations.
Tracing
trace symbolically executes a Python function by replacing its inputs with symbolic variables and recording the computation as an expression DAG.
import alkahest
from alkahest import ExprPool
pool = ExprPool()
@alkahest.trace(pool)
def f(x, y):
    return x**2 + alkahest.sin(y)
print(f.expr) # x^2 + sin(y)
print(f.symbols) # [x, y]
The decorator takes the pool as an argument. Variable names are inferred from the function signature.
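A minimal sketch of tracing by operator overloading: the function is called once with placeholder objects that record each operation instead of computing it. The names Sym and toy_trace are hypothetical. The real trace records into an ExprPool, while this sketch builds nested tuples.

```python
import inspect


class Sym:
    """Toy traced value that records operations as nested tuples."""

    def __init__(self, node):
        self.node = node

    def __add__(self, other):
        return Sym(("add", self.node, _node(other)))

    def __mul__(self, other):
        return Sym(("mul", self.node, _node(other)))

    def __pow__(self, other):
        return Sym(("pow", self.node, _node(other)))


def _node(v):
    return v.node if isinstance(v, Sym) else v


def toy_trace(fn):
    """Run fn once on symbolic placeholders named after its parameters."""
    names = list(inspect.signature(fn).parameters)
    result = fn(*(Sym(n) for n in names))
    return names, result.node


names, expr = toy_trace(lambda x, y: x**2 + y * 3)
print(names)   # ['x', 'y']
print(expr)    # ('add', ('pow', 'x', 2), ('mul', 'y', 3))
```

Because the Python function runs only once, data-dependent control flow inside it is fixed at trace time in this sketch.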
Numeric evaluation
TracedFn objects are callable with numeric values:
print(f(3.0, 0.0)) # 9.0
import numpy as np
xs = np.linspace(0, 1, 1000)
ys = np.zeros(1000)
result = f(xs, ys) # vectorised automatically
Gradient
grad differentiates a traced function symbolically with respect to all (or a subset of) its inputs:
df = alkahest.grad(f)
# df(x, y) returns [∂f/∂x, ∂f/∂y] = [2*x, cos(y)]
grads = df(1.0, 0.0) # [2.0, 1.0]
Differentiate with respect to a subset:
df_x = alkahest.grad(f, wrt=[f.symbols[0]]) # ∂f/∂x only
JIT compilation
jit wraps a traced function in the LLVM JIT backend. The first call triggers compilation; subsequent calls run the compiled code directly.
fast_f = alkahest.jit(f)
print(fast_f(3.0, 0.0)) # 9.0, via LLVM-compiled code
Vectorised evaluation is automatic when array inputs are detected:
xs = np.linspace(0, 10, 1_000_000)
ys = np.zeros_like(xs)
result = fast_f(xs, ys) # zero-copy batch path
Composing transformations
Transformations stack:
# Compiled gradient
fast_df = alkahest.jit(alkahest.grad(f))
grads = fast_df(xs, ys) # compiled, vectorised gradient
# Second derivative: grad of grad
d2f = alkahest.grad(alkahest.grad(f))
Note that grad returns a GradTracedFn, not a TracedFn. jit can be applied to GradTracedFn when it wraps a single scalar output. For multi-output cases, compile each gradient expression individually with compile_expr.
trace_fn
Functional (non-decorator) version of trace:
from alkahest import trace_fn
fn = trace_fn(lambda x, y: x * alkahest.exp(y), pool)
PyTrees
Transformations work over nested data structures (lists, dicts, tuples, dataclasses). The Python layer flattens and unflattens them automatically:
from alkahest import flatten_exprs, unflatten_exprs, map_exprs
# A system of equations as a list
eqs = [x**2 + y - pool.integer(1), x - y**2 + pool.integer(1)]
# Map simplification over the list
simplified = map_exprs(simplify, eqs)
# Flatten to a list of ExprIds and the structure descriptor
flat, treedef = flatten_exprs(eqs)
restored = unflatten_exprs(flat, treedef)
This follows the JAX pytree pattern. The Rust kernel sees only flat sequences; structure reconstruction is a Python-layer concern.
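To make the flatten/unflatten idea concrete, here is a minimal pure-Python sketch of the pytree pattern. It is not Alkahest's implementation: the functions flatten, unflatten, and map_leaves are hypothetical stand-ins, the treedef encoding is invented for illustration, and dataclasses are left out to keep it short.

```python
def flatten(tree):
    """Return (leaves, treedef) for nested lists, tuples, and dicts."""
    if isinstance(tree, (list, tuple)):
        leaves, child_defs = [], []
        for child in tree:
            child_leaves, child_def = flatten(child)
            leaves.extend(child_leaves)
            child_defs.append(child_def)
        return leaves, (type(tree), child_defs)
    if isinstance(tree, dict):
        leaves, child_defs = [], []
        keys = sorted(tree)  # fixed key order keeps flattening deterministic
        for key in keys:
            child_leaves, child_def = flatten(tree[key])
            leaves.extend(child_leaves)
            child_defs.append(child_def)
        return leaves, (dict, keys, child_defs)
    return [tree], None  # anything else is treated as a leaf

def unflatten(leaves, treedef):
    """Rebuild the original structure from a flat leaf list."""
    it = iter(leaves)

    def build(td):
        if td is None:
            return next(it)
        if td[0] is dict:
            _, keys, child_defs = td
            return {k: build(cd) for k, cd in zip(keys, child_defs)}
        kind, child_defs = td
        return kind(build(cd) for cd in child_defs)

    return build(treedef)

def map_leaves(fn, tree):
    # Apply fn to every leaf while keeping the container structure
    leaves, treedef = flatten(tree)
    return unflatten([fn(leaf) for leaf in leaves], treedef)

tree = {"eqs": [1, 2], "bounds": (3, {"lo": 4})}
flat, treedef = flatten(tree)
scaled = map_leaves(lambda v: v * 10, tree)
print(flat)                              # [3, 4, 1, 2]
print(unflatten(flat, treedef) == tree)  # True
```

Only the flat list would need to cross into a kernel that understands sequences; the treedef stays on the Python side.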
Context manager
alkahest.context sets a default pool and configuration for a block:
with alkahest.context(pool=pool, simplify=True):
    z = alkahest.symbol("z") # uses the active pool
    expr = z**2 + alkahest.sin(z) # auto-simplified
Inside the context, alkahest.symbol(name) creates a symbol in the active pool without passing the pool explicitly. This is a convenience wrapper — the pool is still explicit at the structural level.
Code generation
Alkahest can compile symbolic expressions to fast native or GPU code. Compiled code bypasses Python entirely during evaluation.
The compilation pipeline
Expressions lower through multiple IR levels:
ExprPool (hash-consed DAG)
↓ e-graph extraction + canonicalization
Canonical expression form
↓ alkahest MLIR dialect
High-level MLIR (math-aware ops: horner, poly_eval, interval_eval)
↓ lowering passes
Standard MLIR (arith, math, linalg, gpu)
↓
LLVM IR / PTX / StableHLO (depending on target)
↓
Native machine code / GPU kernel / XLA
The custom alkahest MLIR dialect is where math-aware optimizations happen: Horner’s method for polynomials, fused multiply-add emission, numerically stable rearrangements via StabilityCost.
compile_expr
compile_expr produces a callable from a symbolic expression and a list of input variables:
from alkahest import ExprPool, compile_expr, sin, cos
pool = ExprPool()
x = pool.symbol("x")
y = pool.symbol("y")
f = compile_expr(x**2 + sin(y), [x, y])
print(f([3.0, 0.0])) # 9.0
The callable takes a list of floats (one per variable) and returns a float. For batch evaluation see numpy_eval below.
Without --features jit, a fast Rust tree-walking interpreter is used instead of LLVM. The API is identical.
eval_expr
For one-off evaluation without compiling:
from alkahest import eval_expr
result = eval_expr(x**2 + sin(y), {x: 3.0, y: 0.0})
print(result) # 9.0
eval_expr is slower than a compiled function for repeated evaluation but has no compilation overhead.
numpy_eval
numpy_eval vectorises a compiled function over NumPy arrays via the batch path:
import numpy as np
from alkahest import numpy_eval
f = compile_expr(sin(x) * cos(x), [x])
xs = np.linspace(0, 2 * np.pi, 1_000_000)
ys = numpy_eval(f, xs) # vectorised, zero-copy
numpy_eval also accepts PyTorch CPU tensors and JAX arrays via DLPack.
Horner-form emission
horner rewrites a polynomial expression into Horner’s form, which is numerically better conditioned and faster to evaluate:
from alkahest import horner
# x^3 + 2x^2 + 3x + 4 → x*(x*(x + 2) + 3) + 4
h = horner(x**3 + pool.integer(2)*x**2 + pool.integer(3)*x + pool.integer(4), x)
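The rewrite can be checked by hand with plain numbers. The sketch below is independent of Alkahest and uses two illustrative helpers, eval_naive and eval_horner, to evaluate the same cubic both ways; coefficients are listed from the highest degree down.

```python
def eval_naive(coeffs, x):
    # Sum of c_i * x^(n - i), coefficients ordered highest degree first
    n = len(coeffs) - 1
    return sum(c * x ** (n - i) for i, c in enumerate(coeffs))

def eval_horner(coeffs, x):
    # One multiply and one add per coefficient: ((1*x + 2)*x + 3)*x + 4
    acc = 0
    for c in coeffs:
        acc = acc * x + c
    return acc

coeffs = [1, 2, 3, 4]  # x^3 + 2x^2 + 3x + 4
print(eval_naive(coeffs, 2))   # 26
print(eval_horner(coeffs, 2))  # 26
```

For a degree-n polynomial the Horner loop performs n multiplications, whereas the naive form also has to compute each power of x.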
emit_c emits a C function string for embedding in other projects:
from alkahest import emit_c
c_code = emit_c(expr, [x, y], fn_name="f")
# → "double f(double x, double y) { return ...; }"
MLIR dialect
The alkahest-mlir crate exposes the custom MLIR dialect. The dialect ops are:
| Op | Description |
|---|---|
| alkahest.sym | Symbolic variable reference |
| alkahest.const | Constant value |
| alkahest.add, alkahest.mul | Arithmetic |
| alkahest.pow | Exponentiation |
| alkahest.horner | Horner polynomial evaluation |
| alkahest.poly_eval | Generic polynomial evaluation |
| alkahest.series_taylor | Taylor series evaluation |
| alkahest.interval_eval | Ball arithmetic evaluation |
| alkahest.rational_fn | Rational function evaluation |
Three lowering targets are available:
- ArithMath: lowers to arith + math MLIR dialects; uses math.fma for Horner chains
- StableHlo: lowers to StableHLO ops for XLA/JAX integration
- Llvm: lowers to llvm dialect for LLVM IR / PTX emission
from alkahest import to_stablehlo
# Emit textual MLIR in the StableHLO dialect
mlir_text = to_stablehlo(expr, [x, y], fn_name="my_fn")
print(mlir_text) # valid input to mlir-opt / XLA
GPU codegen (NVPTX)
With --features cuda and an LLVM installation with NVPTX support:
from alkahest import compile_cuda
f_gpu = compile_cuda(expr, [x, y])
result = f_gpu.call_batch(inputs) # runs on the first CUDA device
The GPU compiler:
- Lowers the expression through inkwell to NVPTX LLVM IR for sm_86 (Ampere)
- Links libdevice.10.bc for transcendental functions (__nv_sin, etc.)
- Emits PTX via LLVM’s target machine
- Loads the PTX via the CUDA driver (cudarc)
The benchmark nvptx/nvptx_polynomial_1M shows a 16.2× speedup over the CPU JIT for a 1M-point polynomial evaluation on an RTX 3090.
Upcoming (v1.1): AMD ROCm / amdgcn target (hardware-blocked pending RDNA3 availability).
Caching
Compilation results are cached keyed by the canonical hash of the expression DAG. Compiling the same expression twice returns the cached result. The persistent ExprPool (V1-14) extends this cache across sessions.
Small expressions below a complexity threshold skip LLVM entirely and run through the Rust interpreter, which has lower overhead for trivial expressions.
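The cache lookup described above can be pictured with a small stand-alone sketch. Everything here is an assumption made for illustration: expressions are modelled as nested tuples, canonical_key sorts the operands of commutative operators and hashes the result with SHA-256, and compile_cached stores a placeholder instead of real machine code. The actual hash and cache layout in Alkahest may differ.

```python
import hashlib

_cache = {}
compile_count = 0

def canonical_key(expr):
    # Expressions are nested tuples such as ("add", ("pow", "x", 2), ("sin", "y")).
    # Sorting operands of commutative ops makes x + y and y + x share a key.
    def canon(e):
        if not isinstance(e, tuple):
            return repr(e)
        op, args = e[0], [canon(a) for a in e[1:]]
        if op in ("add", "mul"):
            args.sort()
        return "(" + " ".join([op] + args) + ")"
    return hashlib.sha256(canon(expr).encode()).hexdigest()

def compile_cached(expr):
    global compile_count
    key = canonical_key(expr)
    if key not in _cache:
        compile_count += 1  # stand-in for the expensive compile step
        _cache[key] = ("compiled", key[:8])
    return _cache[key]

a = compile_cached(("add", ("pow", "x", 2), ("sin", "y")))
b = compile_cached(("add", ("sin", "y"), ("pow", "x", 2)))
print(a is b)         # True
print(compile_count)  # 1
```

The second call hits the cache because both spellings of the sum canonicalise to the same key.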
Ball arithmetic
Ball arithmetic provides rigorous enclosures: every operation produces an interval guaranteed to contain the true result. Alkahest uses FLINT’s Arb library as the backend.
ArbBall
An ArbBall represents the real interval [midpoint ± radius]:
from alkahest import ArbBall
a = ArbBall(2.0, 0.5) # [1.5, 2.5]
b = ArbBall(3.0, 0.0) # exactly 3.0
print(a.mid) # 2.0
print(a.rad) # 0.5
print(a.lo) # 1.5
print(a.hi) # 2.5
An ArbBall can also carry a precision (in bits) for the midpoint:
a = ArbBall(2.0, 1e-30, prec=128) # 128-bit midpoint
Ball arithmetic operations
All arithmetic on ArbBall values produces a guaranteed enclosure. The radius grows to account for rounding and operation error:
a = ArbBall(2.0, 0.1)
b = ArbBall(3.0, 0.1)
print(a + b) # [4.8, 5.2] — radius grows by sum of input radii
print(a * b) # guaranteed enclosure of [1.9, 2.1] * [2.9, 3.1]
print(a ** 2) # [3.61, 4.41] (squares the interval [1.9, 2.1])
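The enclosure rule can be illustrated with a toy interval type in plain Python. This is only a sketch of the idea, not Arb: the Ball class below is hypothetical, stores explicit lower and upper bounds, and widens every result by one floating-point step (math.nextafter, Python 3.9+) so that rounding cannot push the true value outside the interval.

```python
import math

class Ball:
    """Toy interval [lo, hi]; far cruder than Arb's mid/rad representation."""

    def __init__(self, mid, rad):
        self.lo, self.hi = mid - rad, mid + rad

    @classmethod
    def from_bounds(cls, lo, hi):
        ball = cls(0.0, 0.0)
        # Widen each bound by one float so rounding cannot lose the true value
        ball.lo = math.nextafter(lo, -math.inf)
        ball.hi = math.nextafter(hi, math.inf)
        return ball

    def __add__(self, other):
        return Ball.from_bounds(self.lo + other.lo, self.hi + other.hi)

    def __mul__(self, other):
        # The extreme products always occur at endpoint combinations
        products = [self.lo * other.lo, self.lo * other.hi,
                    self.hi * other.lo, self.hi * other.hi]
        return Ball.from_bounds(min(products), max(products))

    def __contains__(self, value):
        return self.lo <= value <= self.hi

a = Ball(2.0, 0.1)  # [1.9, 2.1]
b = Ball(3.0, 0.1)  # [2.9, 3.1]
s, p, sq = a + b, a * b, a * a
for name, ball in [("a + b", s), ("a * b", p), ("a * a", sq)]:
    print(name, round(ball.lo, 6), round(ball.hi, 6))
# a + b 4.8 5.2
# a * b 5.51 6.51
# a * a 3.61 4.41
```

Writing the square as a * a is only tight here because the interval is strictly positive; for an interval that straddles zero a dedicated squaring rule gives a narrower result.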
interval_eval
interval_eval evaluates a symbolic expression with ArbBall inputs:
from alkahest import ExprPool, ArbBall, interval_eval, sin, exp
pool = ExprPool()
x = pool.symbol("x")
# sin(1 ± 1e-10) — guaranteed enclosure
result = interval_eval(sin(x), {x: ArbBall(1.0, 1e-10)})
print(result.lo, result.hi)
# Multivariate
y = pool.symbol("y")
expr = sin(x) * exp(y)
result = interval_eval(expr, {
    x: ArbBall(1.0, 0.01),
    y: ArbBall(0.0, 0.01),
})
interval_eval guarantees that the output ball contains the true value for any input in the given input balls, accounting for all rounding in the intermediate computation.
AcbBall
Complex ball arithmetic for expressions over ℂ:
from alkahest import AcbBall
z = AcbBall(1.0, 0.0, 1.0, 0.0) # 1 + i, exact
Use cases
Certified numerical evaluation. Compute a value and prove it lies within a tight bound without symbolic proof:
# Prove sin(1) ∈ [0.841, 0.842]
r = interval_eval(sin(x), {x: ArbBall(1.0, 0.0)})
assert r.lo > 0.841 and r.hi < 0.842
Numerical verification of symbolic results. After deriving a symbolic simplification, verify it numerically with rigorous bounds:
# Verify sin²(x) + cos²(x) = 1 at x = 1
lhs = sin(x)**pool.integer(2) + cos(x)**pool.integer(2)
r = interval_eval(lhs, {x: ArbBall(1.0, 0.0)})
assert 1.0 in r # ball contains 1
Sensitivity analysis. Pass an input ball representing parameter uncertainty and observe how the output uncertainty grows:
# x = 1 ± 0.1 (10% uncertainty)
r = interval_eval(x**pool.integer(3), {x: ArbBall(1.0, 0.1)})
print(r) # output uncertainty
Relationship to Lean certificates
Ball arithmetic and Lean certificate export are complementary:
- Ball arithmetic gives numerical certainty within floating-point computation.
- Lean certificates give symbolic/logical certainty for the rewrite steps applied.
Combining them: certify(interval(differentiate(f))) gives a derivative, evaluated with rigorous interval bounds, with a machine-checkable proof of the symbolic differentiation step.
ODE and DAE modeling
Alkahest provides symbolic infrastructure for ordinary differential equations (ODEs) and differential-algebraic equations (DAEs), including structural analysis and automatic index reduction.
ODE
ODE represents an ordinary differential equation system. Build one from symbolic expressions:
from alkahest import ExprPool, ODE, lower_to_first_order, sin
pool = ExprPool()
t = pool.symbol("t")
x = pool.symbol("x")
v = pool.symbol("v")
# Simple harmonic oscillator: x'' + ω²x = 0
# Represented as a first-order system: [x' = v, v' = -ω²x]
omega = pool.integer(1)
ode = ODE(
    state=[x, v],
    derivatives=[v, pool.integer(-1) * omega**pool.integer(2) * x],
    independent=t,
)
Lowering to first order
Higher-order ODEs are automatically lowered to first-order form:
from alkahest import lower_to_first_order
first_order_system = lower_to_first_order(higher_order_ode)
DAE
DAE represents a differential-algebraic system where some equations are algebraic constraints rather than differential equations.
from alkahest import DAE, pantelides
pool = ExprPool()
t = pool.symbol("t")
x = pool.symbol("x") # differential variable
y = pool.symbol("y") # algebraic variable (constrained)
lam = pool.symbol("lam") # Lagrange multiplier
# Pendulum: differential equations + constraint
dae = DAE(
    equations=[...], # system of equations
    variables=[x, y, lam],
    independent=t,
)
Pantelides algorithm
The Pantelides algorithm performs structural index reduction on DAEs. It identifies which equations need to be differentiated to make the system structurally regular:
from alkahest import pantelides
reduced = pantelides(dae)
print(reduced.index) # structural index of the reduced system
print(reduced.differentiated) # which equations were differentiated
Index reduction converts a high-index DAE (index > 1) into an index-1 system that ODE solvers can handle. The result includes the differentiated equations as symbolic expressions.
Sensitivity analysis
Sensitivity analysis computes how solutions depend on parameters:
from alkahest import sensitivity_system, adjoint_system
pool = ExprPool()
p = pool.symbol("p") # parameter
# Forward sensitivity: generates ∂x/∂p equations alongside the ODE
sens = sensitivity_system(ode, [p])
# Adjoint method: more efficient for many parameters, one output
adj = adjoint_system(ode, output_expr, [p])
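For reference, the forward sensitivity equations take the following textbook form for a system with right-hand side f and parameter p. Whether sensitivity_system lays its output out in exactly this way is an assumption. The second pair of equations applies the rule to the harmonic oscillator from the ODE section, treating ω as the parameter purely for illustration (the earlier example fixes it to 1).

```latex
\begin{aligned}
\dot{x} &= f(x, p, t), \qquad s = \frac{\partial x}{\partial p}, \\
\dot{s} &= \frac{\partial f}{\partial x}\, s + \frac{\partial f}{\partial p},
\qquad s(t_0) = \frac{\partial x_0}{\partial p}. \\[1ex]
\text{Oscillator } (\dot{x} = v,\ \dot{v} = -\omega^2 x):\quad
\dot{s}_x &= s_v, \qquad \dot{s}_v = -\omega^2 s_x - 2\omega x .
\end{aligned}
```

Each parameter adds one copy of the state dimension to the system, which is why the adjoint formulation tends to be preferred when there are many parameters and a single output.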
Acausal modeling
Acausal component modeling lets you describe physical systems by their component equations without manually choosing which direction information flows:
from alkahest import AcausalSystem, Port, resistor
# Build a simple RC circuit symbolically
pool = ExprPool()
circuit = AcausalSystem()
R = resistor(pool, resistance=pool.symbol("R"))
# Connect components via ports
circuit.add(R)
circuit.connect(R.port_pos, ...)
# Extract the DAE from the connected system
dae = circuit.to_dae()
Built-in components (resistor, and others registered via the component API) generate their constitutive equations automatically. The system then assembles them into a DAE that Pantelides can reduce.
Hybrid systems
HybridODE adds event handling to an ODE: at a crossing event, the state is reset and integration resumes with a new ODE:
from alkahest import HybridODE, Event
# Bouncing ball: velocity reverses at floor contact
bounce_event = Event(
    condition=x, # fires when x = 0
    reset={v: pool.integer(-1) * v}, # reverse velocity
)
hybrid = HybridODE(ode=base_ode, events=[bounce_event])
Polynomial system solving
Alkahest solves systems of polynomial equations symbolically using Gröbner bases.
solve
solve finds the solutions of a system of polynomial equations in a list of variables. It requires the groebner feature.
from alkahest import ExprPool, solve, sqrt
pool = ExprPool()
x = pool.symbol("x")
y = pool.symbol("y")
# Linear system
solutions = solve([x + y - pool.integer(1), x - y], [x, y])
# → [{x: 1/2, y: 1/2}]
# Circle intersected with a line: irrational solutions
solutions = solve(
    [x**2 + y**2 - pool.integer(1), y - x],
    [x, y]
)
# → [{x: sqrt(2)/2, y: sqrt(2)/2}, {x: -sqrt(2)/2, y: -sqrt(2)/2}]
Solutions are symbolic: irrational roots are returned as Expr trees (e.g. sqrt(2)/2) rather than floats. Quadratic elimination produces exact symbolic answers.
Solution types
The return value is a list of dicts mapping Expr variable → Expr solution:
from alkahest import eval_expr
for sol in solutions:
    for var, val in sol.items():
        print(f"{var} = {val}")
        # Evaluate numerically if needed
        numeric = eval_expr(val, {})
solve returns an empty list for inconsistent systems and a GroebnerBasis handle for parametric families (infinite solution sets).
Upcoming (v1.1): solve will return dict[Expr, Expr] (fully symbolic) by default. A numeric=True keyword argument will convert to float as in the current behavior.
GroebnerBasis
A GroebnerBasis can be constructed directly for ideal-theoretic operations:
from alkahest import GroebnerBasis, GbPoly
# Compute a Gröbner basis under GrLex order
polys = [x**2 + y**2 - pool.integer(1), x - y]
gb = GroebnerBasis.compute(polys, [x, y])
# Check ideal membership
print(gb.contains(x - pool.rational(1, 2))) # False
# Reduce a polynomial modulo the ideal
reduced = gb.reduce(x**3 + y**3)
Upcoming (v1.1): The Python GroebnerBasis.compute() binding will be added (the Rust implementation is already shipped).
Monomial orders
Supported orders: Lex (lexicographic), GrLex (graded lexicographic), GRevLex (graded reverse lexicographic). GRevLex is generally fastest for basis computation; Lex is required for elimination.
Parallel F4
With --features "groebner parallel", Gröbner basis computation uses Rayon for parallel S-polynomial reduction via the F4 algorithm.
GPU-accelerated Macaulay matrix (groebner-cuda)
With --features "groebner-cuda", the mod-p row reduction of the Macaulay matrix is offloaded to CUDA. Multi-prime CRT lifts reconstruct rational coefficients. The computation falls back to the pure-Rust path when no CUDA device is present.
Elimination ideals
GroebnerBasis.eliminate computes the elimination ideal by dropping generators involving specified variables:
# Eliminate y to get a univariate ideal in x
x_ideal = gb.eliminate([y])
This is the algebraic geometry operation underlying implicitization of parametric curves and surfaces.
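As a worked instance, take the circle-and-line ideal used earlier and eliminate y by hand. The derivation below assumes a Lex order with y > x and states the basis up to scaling of its elements; it is a pencil-and-paper check, not output produced by the library.

```latex
\begin{aligned}
I &= \langle x^2 + y^2 - 1,\; x - y \rangle \subset \mathbb{Q}[x, y], \\
x^2 + y^2 - 1 &= (y - x)(y + x) + (2x^2 - 1), \\
G_{\mathrm{Lex},\, y > x} &= \{\, y - x,\; 2x^2 - 1 \,\}, \\
I \cap \mathbb{Q}[x] &= \langle 2x^2 - 1 \rangle .
\end{aligned}
```

The roots of 2x² − 1 are x = ±√2/2, which agrees with the solutions listed for the circle intersected with the line.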
Performance
On the solve_circle_line benchmark (2-variable quadratic system), Alkahest is approximately 40× faster than SymPy due to the FLINT-backed polynomial arithmetic and the compiled F4 core.
Upcoming (v2.0): F5 / signature-based Gröbner basis, real root isolation, primary decomposition, and other advanced algorithms.
Interoperability
Alkahest integrates with the Python numerical ecosystem at well-defined boundaries.
NumPy
Batch evaluation
numpy_eval vectorises a compiled function over NumPy arrays with zero unnecessary copies:
import numpy as np
from alkahest import ExprPool, compile_expr, numpy_eval, sin
pool = ExprPool()
x = pool.symbol("x")
f = compile_expr(sin(x) ** 2 + x, [x])
xs = np.linspace(0, 2 * np.pi, 1_000_000)
ys = numpy_eval(f, xs) # returns a NumPy array, shape (1_000_000,)
Inputs are converted to f64 arrays via DLPack or __array__. The call is vectorised through call_batch_raw in Rust — no Python loop.
Array protocol
CompiledFn objects implement __array__ for direct NumPy coercion:
result = np.asarray(f([1.0])) # scalar result as a 0-d array
PyTorch
PyTorch CPU tensors are accepted wherever NumPy arrays are (via __dlpack__):
import torch
xs = torch.linspace(0, 1, 10_000)
ys = numpy_eval(f, xs) # returns a NumPy array
For GPU tensors, use the compile_cuda path (requires --features cuda), which accepts device pointers via call_device_ptrs.
JAX
numpy_eval with JAX arrays
JAX arrays implement __dlpack__ and are accepted by numpy_eval:
import jax.numpy as jnp
xs = jnp.linspace(0, 1, 10_000)
ys = numpy_eval(f, xs)
JAX primitive source (to_jax)
to_jax registers a symbolic expression as a JAX primitive, making it callable inside JAX computations including jax.jit, jax.grad, and jax.vmap:
from alkahest import to_jax, ExprPool, sin
pool = ExprPool()
x = pool.symbol("x")
jax_fn = to_jax(sin(x) ** 2, [x])
import jax
import jax.numpy as jnp
# Use inside jax.jit / jax.grad
jit_fn = jax.jit(jax_fn)
grad_fn = jax.grad(lambda x: jax_fn(x).sum())
The primitive registers:
- A concrete def_impl that calls the Rust evaluator
- An abstract evaluation rule for shape/dtype propagation
- A JVP (forward-mode) rule derived from the symbolic gradient
- A vmap batching rule
StableHLO / XLA
to_stablehlo emits textual MLIR in the StableHLO dialect, which XLA and JAX’s XLA backend can compile:
from alkahest import to_stablehlo
mlir_text = to_stablehlo(expr, [x, y], fn_name="my_kernel")
# Pass to xla_client.compile() or save to .mlir file
SymPy interop
Alkahest does not import SymPy at runtime. The integration is one-way for validation: the test oracle in tests/test_oracle.py uses SymPy as a ground truth reference. The recommended pattern for mixed workflows is to convert to/from string representation.
DLPack
All DLPack-compatible arrays (NumPy, PyTorch, JAX, CuPy) are accepted at the numpy_eval and call_device_ptrs boundaries. The DLPack conversion is zero-copy for CPU arrays with matching dtypes.
Exporting C code
emit_c generates a standalone C function for embedding in other projects:
from alkahest import emit_c
c_code = emit_c(
    sin(x) * exp(pool.integer(-1) * x),
    [x],
    var_name="x",
    fn_name="damped_sin",
)
print(c_code)
# double damped_sin(double x) { return sin(x) * exp(-x); }
The emitted code uses only standard <math.h> functions and has no Alkahest dependency.
Derivation logs
Every transformation in Alkahest returns a DerivedResult that records the exact sequence of rewrite steps applied. This log is the foundation for both human inspection and Lean proof export.
DerivedResult
DerivedResult is the return type of diff, simplify, integrate, and all top-level operations:
from alkahest import ExprPool, diff, sin
pool = ExprPool()
x = pool.symbol("x")
dr = diff(sin(x**2), x)
Attributes
| Attribute | Type | Description |
|---|---|---|
| .value | Expr | The result expression |
| .steps | list[dict] | Ordered list of rewrite steps |
| .certificate | str \| None | Lean 4 proof term, if exported |
| .assumptions | list | Side conditions that were verified |
| .warnings | list[str] | Non-fatal issues (e.g. branch cut warning) |
Rewrite steps
Each step in .steps is a dict with:
| Key | Value |
|---|---|
| rule | Rule name (string) |
| before | Expression before the rewrite |
| after | Expression after the rewrite |
| subst | Variable substitution, if any |
| side_condition | Side condition that was checked |
for step in dr.steps:
    print(f" {step['rule']:25s} {step['before']} → {step['after']}")
Side conditions
A side condition is a predicate that must hold for a rewrite to be sound:
- Positive(x): x must be positive (e.g. for sqrt(x²) → x)
- NonZero(x): x must be non-zero (e.g. for x/x → 1)
- Integer(n): n must be an integer (e.g. for some power rules)
- BranchCut(f, x): records that f may have a branch cut at x
Side conditions propagate into the derivation log as SideCondition entries. When a side condition is not provable from the symbol’s domain, the step is still recorded but flagged. Lean export only produces a verifiable proof for steps where all side conditions are proved.
Inspecting a derivation
dr = diff(sin(x**2), x)
print(f"Result: {dr.value}")
print(f"Steps ({len(dr.steps)}):")
for step in dr.steps[:5]:
    rule = step['rule']
    before = step['before']
    after = step['after']
    print(f" [{rule}]: {before} → {after}")
    if step.get('side_condition'):
        print(f" side_condition: {step['side_condition']}")
DerivationLog overhead
Logging is always on and is cheap — a Vec<RewriteStep> appended to during traversal. The benchmark group log_overhead in alkahest-core/benches/alkahest_bench.rs measures logging cost separately from computation.
For production workloads where you only need .value, the steps list is still populated but you can ignore it. There is no way to disable logging in the current API (disabling it would compromise the Lean certificate pipeline).
Combining logs
When you chain operations, the logs are separate:
simplified = simplify(expr)
derived = diff(simplified.value, x)
# Full derivation: simplify steps first, then diff steps
all_steps = simplified.steps + derived.steps
For operations like integrate that internally call simplify, the log includes the simplification sub-steps interleaved with the integration steps.
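Because each step is a plain dict, combined logs can be post-processed with ordinary Python. The sketch below uses made-up sample records in place of real DerivedResult.steps output: the keys follow the table in this section, but the expressions are invented and the rule name div_self is hypothetical.

```python
from collections import Counter

# Hypothetical step records; real entries would come from DerivedResult.steps
simplify_steps = [
    {"rule": "mul_one", "before": "1*sin(x^2)", "after": "sin(x^2)",
     "subst": None, "side_condition": None},
    {"rule": "div_self", "before": "x/x", "after": "1",
     "subst": None, "side_condition": "NonZero(x)"},
]
diff_steps = [
    {"rule": "diff_chain", "before": "d/dx sin(x^2)",
     "after": "cos(x^2) * d/dx x^2", "subst": None, "side_condition": None},
]

# Concatenate in execution order, then summarise
all_steps = simplify_steps + diff_steps
rule_counts = Counter(step["rule"] for step in all_steps)
conditional = [s["rule"] for s in all_steps if s.get("side_condition")]

print(len(all_steps))          # 3
print(rule_counts["mul_one"])  # 1
print(conditional)             # ['div_self']
```

Filtering on side_condition in this way gives a quick list of the steps whose soundness depends on an assumption about the input.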
Lean certificates
Alkahest can export machine-checkable Lean 4 proofs for a subset of computations.
Three levels of evidence
Derivation logs — always on, always cheap. Records every rewrite rule applied, with rule name and arguments. Human-readable; machine-parseable; forms the basis for Lean export.
Lean certificate export — for computations expressible as sequences of rewrites tagged with Lean theorem names. The library emits a .lean file containing a proof term. Lean checks it independently.
Algorithmic certificates — for operations where rewrite sequences do not work (polynomial factoring, integration by the Risch algorithm). The library emits a verifiable witness instead. For factoring, the witness is the claimed factorization, which Lean verifies by multiplying out.
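The witness idea for factoring can be sketched in a few lines of plain Python: the checker never needs to know how the factors were found, only that their product equals the input. The helpers poly_mul and check_factorization are illustrative and represent polynomials as integer coefficient lists, lowest degree first; the Lean-side check is a separate mechanism.

```python
def poly_mul(p, q):
    # Coefficient lists, lowest degree first
    out = [0] * (len(p) + len(q) - 1)
    for i, a in enumerate(p):
        for j, b in enumerate(q):
            out[i + j] += a * b
    return out

def check_factorization(poly, factors):
    # Multiply the claimed factors out and compare with the original
    product = [1]
    for factor in factors:
        product = poly_mul(product, factor)
    return product == poly

# x^2 - 1 = (x - 1)(x + 1)
print(check_factorization([-1, 0, 1], [[-1, 1], [1, 1]]))  # True
# A wrong witness is rejected: (x - 1)(x + 2) = x^2 + x - 2
print(check_factorization([-1, 0, 1], [[-1, 1], [2, 1]]))  # False
```

Verification costs a handful of multiplications even when finding the factors required a much more involved algorithm.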
Theorem mapping
Every primitive in the registry is tagged with a Lean 4 / Mathlib theorem name:
| Primitive rule | Mathlib theorem |
|---|---|
| diff_sin | Real.hasDerivAt_sin |
| diff_exp | Real.hasDerivAt_exp |
| diff_log | Real.hasDerivAt_log |
| diff_chain | HasDerivAt.comp |
| diff_add | HasDerivAt.add |
| diff_mul | HasDerivAt.mul |
| add_zero | add_zero |
| mul_one | mul_one |
The full mapping lives in alkahest-core/src/lean/.
Exporting a certificate
from alkahest import ExprPool, diff, sin
pool = ExprPool()
x = pool.symbol("x", "real")
dr = diff(sin(x**2), x)
# The certificate is in dr.certificate when Lean export is enabled
if dr.certificate:
    with open("proof.lean", "w") as f:
        f.write(dr.certificate)
The emitted .lean file imports Mathlib and contains a proof term that Lean can verify:
import Mathlib.Analysis.SpecialFunctions.Trigonometric.Deriv
-- Alkahest certificate: d/dx sin(x²) = 2*x*cos(x²)
theorem alkahest_diff_sin_sq (x : ℝ) :
    HasDerivAt (fun x => Real.sin (x ^ 2)) (2 * x * Real.cos (x ^ 2)) x := by
  have h1 : HasDerivAt (fun x => x ^ 2) (2 * x) x := ...
  exact (Real.hasDerivAt_sin _).comp x h1
Lean CI
The CI pipeline (.github/workflows/lean.yml) runs on every change to lean/-tagged files:
- Generates proof files via tests/lean_corpus.py
- Compiles them with the lean compiler (with Mathlib cached)
- Fails the build if any proof does not typecheck
Coverage
Lean export works for computations that decompose into the tagged primitive set. Operations that currently produce certificates:
- Polynomial differentiation (all degrees)
- Trigonometric differentiation (sin, cos, tan, and chain compositions)
- Exponential and logarithm differentiation
- Basic arithmetic rewrites (add_zero, mul_one, mul_comm, etc.)
Operations that use algorithmic certificates (witness-based):
- Polynomial factoring: the claimed factorization is verified by ring_nf in Lean
- Polynomial GCD: verified by showing that gcd divides both inputs and is a linear combination of them
Operations without Lean certificates (Lean theorem not yet mapped, or algorithm not expressible as rewrites):
- Integration by the Risch algorithm
- E-graph extraction steps
Upcoming (v2.0): Deeper Mathlib coverage including limits (Filter.Tendsto), series (HasSum), and real algebraic geometry (Polynomial.roots).
Side conditions in proofs
Side conditions (domain constraints, branch cut restrictions) are propagated into the emitted Lean proof as hypotheses. A certificate that depends on x > 0 will include hx : 0 < x as an assumption in the proof term. This makes the trust boundary explicit: the certificate is only verifiable when the side conditions hold.
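A certificate of this kind might look like the sketch below for the rewrite sqrt(x²) → x under the side condition Positive(x). The theorem name and the exact shape are assumptions made for illustration, not output emitted by Alkahest; the proof relies on Mathlib's Real.sqrt_sq, which needs only 0 ≤ x.

```lean
import Mathlib

-- Hypothetical certificate shape: the side condition Positive(x) appears as hx
theorem alkahest_sqrt_sq_example (x : ℝ) (hx : 0 < x) :
    Real.sqrt (x ^ 2) = x :=
  Real.sqrt_sq hx.le
```

A caller who cannot supply hx cannot use the theorem, which is how the trust boundary shows up on the Lean side.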
Error handling
Alkahest uses a structured exception hierarchy. Every error carries a stable diagnostic code, a human-readable message, an optional source span, and an optional remediation hint.
Exception hierarchy
AlkahestError (base)
├── ConversionError (E-POLY-*) — expression → polynomial/rational conversion
├── DomainError (E-DOMAIN-*) — mathematical side conditions violated
├── DiffError (E-DIFF-*) — differentiation failed
├── IntegrationError (E-INT-*) — integration failed
├── MatrixError (E-MAT-*) — linear algebra errors
├── OdeError (E-ODE-*) — ODE construction or lowering
├── DaeError (E-DAE-*) — DAE structural analysis
├── SolverError (E-SOLVE-*) — polynomial system solving
├── JitError (E-JIT-*) — LLVM/JIT codegen
├── CudaError (E-CUDA-*) — CUDA kernel launch or driver
└── PoolError (E-POOL-*) — ExprPool misuse
Error attributes
Every exception instance exposes:
| Attribute | Type | Description |
|---|---|---|
| .code | str | Stable diagnostic code, e.g. "E-POLY-001" |
| .message | str | Human-readable description |
| .remediation | str \| None | What the user should try |
| .span | tuple[int, int] \| None | Character offset range in source expression |
import alkahest
from alkahest import ExprPool, UniPoly, ConversionError
pool = ExprPool()
x = pool.symbol("x")
try:
    # sin(x) cannot be represented as a polynomial
    p = UniPoly.from_symbolic(alkahest.sin(x), x)
except ConversionError as e:
    print(e.code) # E-POLY-001
    print(e.message) # "expression contains non-polynomial term: sin(x)"
    print(e.remediation) # "Use Expr directly, or expand sin(x) as a series first"
Common errors and remediations
ConversionError (E-POLY-*)
Raised when an expression cannot be converted to a polynomial or rational function.
| Code | Cause | Remediation |
|---|---|---|
| E-POLY-001 | Non-polynomial term (e.g. sin) | Use Expr directly; or expand as series |
| E-POLY-002 | Non-integer exponent | Algebraic extension not yet supported |
| E-POLY-003 | Symbolic exponent (variable in exponent) | Use Expr.pow, not UniPoly |
DomainError (E-DOMAIN-*)
Raised when a mathematical side condition is violated.
| Code | Cause | Remediation |
|---|---|---|
| E-DOMAIN-001 | Division by zero | Check denominator before dividing |
| E-DOMAIN-002 | log(0) or log(negative) | Ensure argument is positive; use complex domain if needed |
| E-DOMAIN-003 | sqrt(negative) | Use AcbBall or declare complex domain |
IntegrationError (E-INT-*)
| Code | Cause | Remediation |
|---|---|---|
| E-INT-001 | No integration rule matches | Result may not have an elementary antiderivative |
| E-INT-002 | Algebraic extension required | Planned for v1.1 (algebraic Risch) |
| E-INT-003 | Risch gave up (transcendental tower too deep) | Try numerical integration |
SolverError (E-SOLVE-*)
| Code | Cause | Remediation |
|---|---|---|
| E-SOLVE-001 | System is inconsistent | No solutions exist |
| E-SOLVE-002 | High-degree univariate factor (> 2) | Symbolic solution not supported; use numerical solve |
| E-SOLVE-003 | Gröbner basis did not terminate | Increase node/iteration limits |
Catching errors by code
For programmatic error handling:
try:
    result = alkahest.integrate(expr, x)
except alkahest.AlkahestError as e:
    if e.code.startswith("E-INT-"):
        print(f"Integration failed: {e.remediation}")
    else:
        raise
Error taxonomy
Every error is classified on two independent axes: subsystem (determines the code prefix and exception class) and cause (informs the remediation hint).
Subsystem axis
| Prefix | Class | Scope |
|---|---|---|
| E-POLY-* | ConversionError | Expression → polynomial/rational-function conversion |
| E-DOMAIN-* | DomainError | Side-condition violations (div-by-zero, log of 0, sqrt of negative) |
| E-DIFF-* | DiffError | Forward/reverse differentiation, unknown derivatives |
| E-INT-* | IntegrationError | Symbolic integration (Risch, heuristic, table) |
| E-MAT-* | MatrixError | Linear algebra (shape, singular, non-invertible) |
| E-ODE-* | OdeError | ODE construction, lowering, event handling |
| E-DAE-* | DaeError | DAE structural analysis (Pantelides, index reduction) |
| E-SOLVE-* | SolverError | Polynomial system solving, Gröbner basis |
| E-JIT-* | JitError | LLVM/Cranelift codegen and linking |
| E-CUDA-* | CudaError | NVPTX compile, kernel launch, driver/runtime failures |
| E-POOL-* | PoolError | ExprPool misuse (closed, cross-pool, persisted-handle mismatch) |
| E-PARSE-* | ParseError (reserved) | Parser integration (owns span() by default) |
| E-IO-* | IoError (reserved) | Checkpoint/serde paths (PoolPersistError) |
Cause axis
- User-input: the expression or argument is outside the supported fragment. Always has a remediation; carries a span once parsing lands.
- Domain: input is syntactically fine but violates a mathematical side condition. Remediation is “substitute a different value,” not “reformulate.”
- Unsupported: the operation is not implemented for this case. Must name the missing capability so users can file a feature request.
- Resource/environment: CUDA device absent, out-of-memory, JIT target mismatch, pool closed. Typically no span; remediation references the environment, not the expression.
- Internal invariant: a bug. Should never reach users in release; in debug it carries a backtrace. Use E-INTERNAL-001.
Adding a new error code
- Does it fit an existing subsystem? Add a variant and a code one higher than the current max for that prefix.
- Does it name a new subsystem? Add a prefix, a class, and an entry in REGISTRY in the same PR. Do not reuse prefixes across unrelated subsystems.
- Write the remediation before the message: if you cannot say what the user should do, the taxonomy is telling you this is an internal bug, not a user error.
Users match on subsystem (the exception class); triagers filter on cause (the code suffix and remediation text).
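The "one higher than the current max" rule for existing prefixes is easy to automate. The helper below is a hypothetical sketch, not part of Alkahest: it assumes codes follow the PREFIX-NNN shape shown in the tables and that the existing codes are available as a plain list.

```python
import re

def next_code(prefix, existing_codes):
    """Return the next free code for a prefix such as 'E-POLY'."""
    pattern = re.compile(r"^" + re.escape(prefix) + r"-(\d{3})$")
    numbers = []
    for code in existing_codes:
        match = pattern.match(code)
        if match:
            numbers.append(int(match.group(1)))
    # Start at 001 when the prefix has no codes yet
    next_number = max(numbers, default=0) + 1
    return f"{prefix}-{next_number:03d}"

codes = ["E-POLY-001", "E-POLY-002", "E-POLY-003", "E-INT-001"]
print(next_code("E-POLY", codes))  # E-POLY-004
print(next_code("E-MAT", codes))   # E-MAT-001
```

Numbering from the current maximum rather than filling gaps keeps retired numbers from being reused.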
Stability policy
Alkahest follows semantic versioning starting at 1.0.
Stable surface
The stable surface is the API Alkahest commits to maintaining without breaking changes across a major version:
- Rust: everything re-exported from alkahest_core::stable
- Python: every name in alkahest.__all__ at release time
Breaking changes to the stable surface require a major-version bump (e.g. 1.x → 2.0).
Experimental surface
- Rust: alkahest_core::experimental::*, plus anything not in stable
- Python: alkahest.experimental.*, plus anything re-exported from the native module but not in __all__
Experimental APIs may change in any minor release. Pin a specific point release if you depend on them.
Deprecation policy
Stable symbols slated for removal are kept as #[deprecated] shims for one full major cycle before deletion:
- Symbol is deprecated in 1.x with #[deprecated(since = "1.x", note = "use Y instead")]
- Symbol is removed in 2.0
Python deprecations emit DeprecationWarning from the point of deprecation.
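On the Python side, a deprecation shim usually amounts to a thin wrapper that warns and forwards to the replacement. The sketch below shows that general pattern with the standard warnings module; old_solve and new_solve are placeholder names, not Alkahest functions.

```python
import warnings

def new_solve(system):
    # Placeholder for the replacement API
    return sorted(system)

def old_solve(system):
    # Shim kept during the deprecation window; forwards to the replacement
    warnings.warn(
        "old_solve is deprecated; use new_solve instead",
        DeprecationWarning,
        stacklevel=2,  # attribute the warning to the caller, not the shim
    )
    return new_solve(system)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = old_solve([3, 1, 2])

print(result)                       # [1, 2, 3]
print(caught[0].category.__name__)  # DeprecationWarning
```

Python hides DeprecationWarning by default outside of code run in __main__, so test suites typically enable it explicitly, for example with python -W error::DeprecationWarning.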
Enforcement
- cargo semver-checks: runs on every PR via .github/workflows/alkahest-semver-check.yml. Fails the PR if any stable Rust API breaks.
- scripts/check_api_freeze.py: guards against removals from alkahest.__all__ within a major cycle.
- CHANGELOG.md: Keep-a-Changelog format; every release documents additions, deprecations, and (in major bumps) removals.
Error codes
Diagnostic error codes (e.g. E-POLY-001) are also stable. A code introduced in 1.x will not be renumbered or removed until 2.0. New codes are added by incrementing within the existing prefix.
Diagnostic codes and their stability
Error codes are part of the stable surface from the version they first appear. See Error handling for the current code table.