Transformations
Transformations are higher-order operations that take a function and return a new function. They compose freely and operate on traced symbolic representations.
Tracing
trace symbolically executes a Python function by replacing its inputs with symbolic variables and recording the computation as an expression DAG.
import alkahest
from alkahest import ExprPool
pool = ExprPool()
@alkahest.trace(pool)
def f(x, y):
    return x**2 + alkahest.sin(y)
print(f.expr) # x^2 + sin(y)
print(f.symbols) # [x, y]
The decorator takes the pool as an argument. Variable names are inferred from the function signature.
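To make the idea of tracing concrete, here is a minimal toy sketch in plain Python. It is not alkahest's implementation: `Node`, `toy_sin`, `toy_trace`, and `evaluate` are hypothetical names invented for illustration, and there is no pool or Rust kernel here. It only shows the general technique: read parameter names from the signature, call the function with symbolic placeholders, and let operator overloading record the expression.

```python
import inspect
import math


class Node:
    """One node of a toy expression DAG (hypothetical, not an alkahest type)."""

    def __init__(self, op, args=(), name=None, value=None):
        self.op = op
        self.args = tuple(args)
        self.name = name
        self.value = value

    @staticmethod
    def wrap(v):
        # Promote plain Python numbers to constant nodes.
        return v if isinstance(v, Node) else Node("const", value=v)

    def __add__(self, other):
        return Node("add", (self, Node.wrap(other)))

    def __radd__(self, other):
        return Node("add", (Node.wrap(other), self))

    def __pow__(self, other):
        return Node("pow", (self, Node.wrap(other)))

    def __repr__(self):
        if self.op == "sym":
            return self.name
        if self.op == "const":
            return repr(self.value)
        if self.op == "sin":
            return f"sin({self.args[0]!r})"
        joiner = {"add": " + ", "pow": "^"}[self.op]
        return joiner.join(repr(a) for a in self.args)


def toy_sin(x):
    # Records a sin node instead of computing a number.
    return Node("sin", (Node.wrap(x),))


def toy_trace(fn):
    # Infer variable names from the signature, then run fn on symbols.
    names = list(inspect.signature(fn).parameters)
    symbols = [Node("sym", name=n) for n in names]
    return fn(*symbols), symbols


def evaluate(node, env):
    # Walk the recorded expression with concrete values for the symbols.
    if node.op == "sym":
        return env[node.name]
    if node.op == "const":
        return node.value
    vals = [evaluate(a, env) for a in node.args]
    if node.op == "add":
        return vals[0] + vals[1]
    if node.op == "pow":
        return vals[0] ** vals[1]
    if node.op == "sin":
        return math.sin(vals[0])
    raise ValueError(f"unknown op: {node.op}")


expr, symbols = toy_trace(lambda x, y: x**2 + toy_sin(y))
value = evaluate(expr, {"x": 3.0, "y": 0.0})
```

The function body runs exactly once, at trace time; afterwards only the recorded expression is evaluated.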
Numeric evaluation
TracedFn objects are callable with numeric values:
print(f(3.0, 0.0)) # 9.0
import numpy as np
xs = np.linspace(0, 1, 1000)
ys = np.zeros(1000)
result = f(xs, ys) # vectorised automatically
Gradient
grad differentiates a traced function symbolically with respect to all (or a subset of) its inputs:
df = alkahest.grad(f)
# df(x, y) returns [∂f/∂x, ∂f/∂y] = [2*x, cos(y)]
grads = df(1.0, 0.0) # [2.0, 1.0]
Differentiate with respect to a subset:
df_x = alkahest.grad(f, wrt=[f.symbols[0]]) # ∂f/∂x only
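As a sanity check on the gradient stated above, the following sketch compares the analytic partial derivatives `[2*x, cos(y)]` against a central finite difference in plain Python. It does not use alkahest; `analytic_grad` and `numeric_grad` are hypothetical helpers written only for this check.

```python
import math


def f(x, y):
    # Plain-Python version of the traced function from this page.
    return x**2 + math.sin(y)


def analytic_grad(x, y):
    # [df/dx, df/dy] as stated in the text.
    return [2 * x, math.cos(y)]


def numeric_grad(fn, point, h=1e-6):
    # Central difference along each coordinate in turn.
    grads = []
    for i in range(len(point)):
        hi = list(point)
        lo = list(point)
        hi[i] += h
        lo[i] -= h
        grads.append((fn(*hi) - fn(*lo)) / (2 * h))
    return grads


exact = analytic_grad(1.0, 0.0)
approx = numeric_grad(f, (1.0, 0.0))
```

A symbolic gradient is exact, whereas the finite difference only approximates it to within truncation and rounding error, so the two agree to a tolerance rather than bit for bit.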
JIT compilation
jit wraps a traced function in the LLVM JIT backend. The first call triggers compilation; subsequent calls run the compiled code directly.
fast_f = alkahest.jit(f)
print(fast_f(3.0, 0.0)) # 9.0, via LLVM-compiled code
Vectorised evaluation is automatic when array inputs are detected:
xs = np.linspace(0, 10, 1_000_000)
ys = np.zeros_like(xs)
result = fast_f(xs, ys) # zero-copy batch path
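The compile-on-first-call behaviour can be illustrated with a small plain-Python analogy. This is only a sketch of the caching pattern, assuming nothing about alkahest internals: `LazyCompiled` is a hypothetical class, and Python's built-in `eval` stands in for the LLVM backend.

```python
import math


class LazyCompiled:
    """Toy wrapper: compile on the first call, reuse the result afterwards."""

    def __init__(self, source, arg_names):
        self.source = source
        self.arg_names = list(arg_names)
        self._compiled = None
        self.compile_count = 0  # how many times the "backend" ran

    def _compile(self):
        # Stand-in for an expensive backend step: build a lambda from source.
        self.compile_count += 1
        code = f"lambda {', '.join(self.arg_names)}: {self.source}"
        return eval(code, {"sin": math.sin})

    def __call__(self, *args):
        if self._compiled is None:
            self._compiled = self._compile()  # first call only
        return self._compiled(*args)


fast = LazyCompiled("x**2 + sin(y)", ["x", "y"])
count_before = fast.compile_count
first = fast(3.0, 0.0)   # triggers compilation
second = fast(2.0, 0.0)  # reuses the cached callable
```

The first call pays the compilation cost once; every later call goes straight to the cached callable.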
Composing transformations
Transformations stack:
# Compiled gradient
fast_df = alkahest.jit(alkahest.grad(f))
grads = fast_df(xs, ys) # compiled, vectorised gradient
# Second derivative: grad of grad
d2f = alkahest.grad(alkahest.grad(f))
Note that grad returns a GradTracedFn, not a TracedFn. jit can be applied to GradTracedFn when it wraps a single scalar output. For multi-output cases, compile each gradient expression individually with compile_expr.
trace_fn
Functional (non-decorator) version of trace:
from alkahest import trace_fn
fn = trace_fn(lambda x, y: x * alkahest.exp(y), pool)
PyTrees
Transformations work over nested data structures (lists, dicts, tuples, dataclasses). The Python layer flattens and unflattens them automatically:
from alkahest import flatten_exprs, unflatten_exprs, map_exprs
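# x and y are symbols in the pool; here they are taken from the traced f above
x, y = f.symbols
# A system of equations as a list
eqs = [x**2 + y - pool.integer(1), x - y**2 + pool.integer(1)]
# Map simplification over the list (simplify is the library's
# simplification function, assumed to be in scope)
simplified = map_exprs(simplify, eqs)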
# Flatten to a list of ExprIds and the structure descriptor
flat, treedef = flatten_exprs(eqs)
restored = unflatten_exprs(flat, treedef)
This follows the JAX pytree pattern. The Rust kernel sees only flat sequences; structure reconstruction is a Python-layer concern.
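The flatten/unflatten round trip can be sketched in a few lines of plain Python. This is a toy model of the pytree pattern, not the code behind `flatten_exprs` or `unflatten_exprs`: `toy_flatten`, `toy_unflatten`, and `toy_map` are hypothetical names, and only lists, tuples, and dicts are handled.

```python
def toy_flatten(tree):
    """Return (leaves, treedef) for nested lists, tuples, and dicts."""
    if isinstance(tree, (list, tuple)):
        leaves, children = [], []
        for item in tree:
            sub_leaves, sub_def = toy_flatten(item)
            leaves.extend(sub_leaves)
            children.append(sub_def)
        return leaves, (type(tree), children)
    if isinstance(tree, dict):
        leaves, children = [], []
        for key in tree:  # insertion order
            sub_leaves, sub_def = toy_flatten(tree[key])
            leaves.extend(sub_leaves)
            children.append((key, sub_def))
        return leaves, (dict, children)
    return [tree], None  # anything else is a leaf


def toy_unflatten(leaves, treedef):
    """Rebuild the original structure from flat leaves and a treedef."""
    it = iter(leaves)

    def build(d):
        if d is None:
            return next(it)
        kind, children = d
        if kind is dict:
            return {k: build(c) for k, c in children}
        return kind(build(c) for c in children)

    return build(treedef)


def toy_map(fn, tree):
    # Flatten, transform each leaf, then restore the structure.
    leaves, treedef = toy_flatten(tree)
    return toy_unflatten([fn(leaf) for leaf in leaves], treedef)


tree = {"eqs": [1, 2], "bounds": (3, 4)}
flat, treedef = toy_flatten(tree)
restored = toy_unflatten(flat, treedef)
scaled = toy_map(lambda v: v * 10, tree)
```

Because the structure descriptor is kept separate from the leaves, a kernel that only understands flat sequences never needs to know about the nesting.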
Context manager
alkahest.context sets a default pool and configuration for a block:
with alkahest.context(pool=pool, simplify=True):
    z = alkahest.symbol("z")  # uses the active pool
    expr = z**2 + alkahest.sin(z)  # auto-simplified
Inside the context, alkahest.symbol(name) creates a symbol in the active pool without passing the pool explicitly. This is a convenience wrapper — the pool is still explicit at the structural level.
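One common way to implement an "active default" like this in Python is a context variable that is set on entry and restored on exit. The sketch below shows that pattern with the standard library only. It is an assumption about how such a wrapper could work, not a description of alkahest's code; `ToyPool`, `toy_context`, and `toy_symbol` are hypothetical names.

```python
import contextvars
from contextlib import contextmanager

# Holds the pool that is active for the current block, if any.
_active_pool = contextvars.ContextVar("active_pool", default=None)


class ToyPool:
    """Stand-in pool that just remembers the symbol names it created."""

    def __init__(self):
        self.names = []

    def symbol(self, name):
        self.names.append(name)
        return name


@contextmanager
def toy_context(pool):
    token = _active_pool.set(pool)
    try:
        yield pool
    finally:
        _active_pool.reset(token)  # restore whatever was active before


def toy_symbol(name):
    # Convenience wrapper: look up the active pool, then delegate to it.
    pool = _active_pool.get()
    if pool is None:
        raise RuntimeError("no active pool; wrap the call in toy_context(pool)")
    return pool.symbol(name)


pool = ToyPool()
with toy_context(pool):
    z = toy_symbol("z")  # created in the active pool
```

The wrapper never owns any state of its own: every symbol still lives in a specific pool object, which is the sense in which the pool remains explicit.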