diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1d49b22..0d715d0 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -114,10 +114,36 @@ jobs: - name: Run sandbox integration tests run: sudo -E go test -v -run 'TestSandboxedWorker' ./internal/execution/worker/... + benchmark-wasm-e2e: + name: WASM E2E Benchmark + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: ./go.mod + + - name: Install Dependencies + run: go mod download + + - name: Run WASM end-to-end benchmark + run: make benchmark-wasm-e2e BENCH_ARGS='--iterations 5 --warmup 2 --json-output benchmark-wasm-e2e.json' + + - name: Upload benchmark results + uses: actions/upload-artifact@v4 + with: + name: wasm-e2e-benchmark + path: benchmark-wasm-e2e.json + build_docker: name: Build Docker Image runs-on: ubuntu-latest - needs: [test, test-sandbox, build] + needs: [test, test-sandbox, benchmark-wasm-e2e, build] concurrency: group: ${{ github.ref }} cancel-in-progress: ${{ github.event_name == 'pull_request' || github.ref_name != github.event.repository.default_branch }} diff --git a/.gitignore b/.gitignore index 83e00e5..f60a2ab 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,10 @@ bin/ *.out lcov.info +# Python bytecode from local adapter/demo tests +__pycache__/ +*.py[cod] + # Dependency directories (remove the comment below to include it) # vendor/ diff --git a/Makefile b/Makefile index cd571ce..e9d59e2 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ GOFLAGS = -ldflags "$(GOLDFLAGS)" BINARY_NAME ?= shimmy CONTAINER_ENGINE ?= docker -.PHONY: all build test test-unit test-sandbox lcov install generate-mocks update-schema +.PHONY: all build test test-unit test-sandbox benchmark-wasm-e2e lcov install generate-mocks update-schema all: build @@ -33,6 +33,9 @@ test-sandbox: -w /workspace \ shimmy-test-sandbox \ go test -v 
-run 'TestSandboxedWorker' ./internal/execution/worker/... + +benchmark-wasm-e2e: + scripts/benchmark-wasm-e2e.py $(BENCH_ARGS) lcov: gcov2lcov -infile=coverage.out -outfile=lcov.info diff --git a/README.md b/README.md index 2cd6fda..d9ee72d 100644 --- a/README.md +++ b/README.md @@ -59,10 +59,10 @@ GLOBAL OPTIONS: function --arg value, -a value [ --arg value, -a value ] additional arguments for to the worker process. [$FUNCTION_ARGS] - --command value, -c value the command to invoke to start the worker process. [$FUNCTION_COMMAND] + --command value, -c value the command to invoke to start the worker process, or the WASM module path when --interface=wasm. [$FUNCTION_COMMAND] --cwd value, -d value the working directory for the worker process. [$FUNCTION_WORKING_DIR] --env value, -e value [ --env value, -e value ] additional environment variables for the worker process. [$FUNCTION_ENV] - --interface value, -i value the interface to use for worker process communication. Options: rpc, file. (default: "rpc") [$FUNCTION_INTERFACE] + --interface value, -i value the interface to use for worker communication. Options: rpc, file, wasm. (default: "rpc") [$FUNCTION_INTERFACE] --max-workers value, -n value the maximum number of worker processes to run concurrently. (default: number of CPU cores) [$FUNCTION_MAX_PROCS] rpc @@ -245,6 +245,113 @@ For example, a Wolfram Language evaluation function in `evaluation.wl` would be wolframscript -file evaluation.wl /tmp/shimmy/abc/request-data-123 /tmp/shimmy/abc/response-data-456 ``` +#### WebAssembly (`--interface wasm`, opt-in) + +The WASM interface executes a pre-built WebAssembly module in-process using +wazero. The module can be a WASI module or a small freestanding module as long +as it exports the Shimmy adapter ABI below. This is an execution backend only: +Shimmy still owns the public HTTP/API contract, request validation, command +routing, cases, and response handling. 
+ +Shimmy does not compile evaluator source code at request time and does not infer +a source language from dependency files. Language-specific work belongs in build +or deployment recipes that produce an `eval.wasm` artifact. + +A generic WASM evaluator module must export: + +| Export | Purpose | +|--------|---------| +| `memory` | Guest linear memory. | +| `alloc(size: i32) -> i32` | Reserves memory where Shimmy writes the request JSON. | +| `evaluate(ptr: i32, len: i32) -> i32` | Executes one command and returns a response pointer. | + +Shimmy writes this internal adapter envelope into guest memory: + +```json +{ + "method": "eval", + "params": { + "response": "...", + "answer": "...", + "params": {} + } +} +``` + +The response pointer returned by `evaluate` must point at: + +```text +[p:p+4] little-endian uint32 JSON length +[p+4:p+4+len] JSON object bytes +``` + +Run a pre-built WASI module with: + +```shell +FUNCTION_INTERFACE=wasm \ +FUNCTION_WASM_MODULE=/path/to/eval.wasm \ +FUNCTION_MAX_PROCS=1 \ +shimmy serve +``` + +`FUNCTION_COMMAND=/path/to/eval.wasm` is also accepted for compatibility, but +`FUNCTION_WASM_MODULE` is clearer for new deployments. + +Example build recipes, including package-shaped evaluators: + +```shell +# Go package/module +cd examples/demo-go-package +GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o eval.wasm ./cmd/evaluator + +# Rust crate/package +cd examples/demo-rust-package +cargo build --target wasm32-unknown-unknown --release + +# C++ package with Makefile + Zig's clang driver +cd examples/demo-cpp-package +make wasm OUT=eval.wasm + +# C/C++ with wasi-sdk also works when the project exposes the same ABI +/opt/wasi-sdk/bin/clang --target=wasm32-wasip1 ... -o eval.wasm +/opt/wasi-sdk/bin/clang++ --target=wasm32-wasip1 ... -o eval.wasm +``` + +Shimmy intentionally does not run these build commands from `shimmy serve`. 
+Build recipes can be overridden in Makefiles, CI, Dockerfiles, or deployment +scripts; the runtime boundary remains the pre-built `eval.wasm` module. + +The backend keeps a warm module instance pool and restores a full linear-memory +snapshot after each request. This gives warm reuse without leaking guest mutable +state between requests. Dirty-page restore, Python runtimes, Pyodide, and package +bundling are intentionally out of scope for this generic backend. + +Try the state-isolation examples. Linux, or a Linux container, is the reference +environment for evaluator build/test recipes. The scripts also run on macOS when +the same toolchain is installed, but CI/reviewer instructions should assume +Linux by default. These are intentionally small synthetic evaluators for the +Go/C++ artifact path; real language/runtime packaging such as Pyodide is a +separate profile/follow-up. + +Minimum toolchains for the example commands below: + +- `scripts/demo-wasm.sh`: Go with `GOOS=wasip1 GOARCH=wasm` support, `curl`, + and `python3`. +- `scripts/demo-cpp-wasm.sh`: the same tools plus `zig` and `file`. +- `scripts/benchmark-wasm-e2e.py`: Go with `GOOS=wasip1 GOARCH=wasm` support + and `python3`; it builds the stateful demo evaluator, starts a real Shimmy + HTTP server, and measures short eval, incorrect eval, large string payload, + host-side cases, and preview payload classes. +- Rust example tests: `rustc`/`cargo` plus + `rustup target add wasm32-unknown-unknown`. + +```shell +scripts/demo-wasm.sh +scripts/demo-cpp-wasm.sh +scripts/benchmark-wasm-e2e.py --iterations 25 --warmup 3 +go test ./internal/execution/wasm -run 'Test(GoStateful|RustCompare|CppCompare|GoPackage|RustPackage|CppPackage)Example_CompilesAndRunsThroughDispatcher' -v +``` + ### Sandboxed Execution (Linux only, experimental) Shimmy can wrap each worker process in an [nsjail](https://github.com/google/nsjail) sandbox to safely execute arbitrary, untrusted code. 
The sandbox provides: diff --git a/cmd/root.go b/cmd/root.go index 275c31e..f8b9f0b 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -47,7 +47,7 @@ functions on arbitrary, serverless platforms.` &cli.StringFlag{ Name: "interface", Aliases: []string{"i"}, - Usage: "the interface to use for worker process communication. Options: rpc, file.", + Usage: "the interface to use for worker communication. Options: rpc, file, wasm.", Value: "rpc", Category: "function", EnvVars: []string{"FUNCTION_INTERFACE"}, @@ -55,10 +55,9 @@ functions on arbitrary, serverless platforms.` &cli.StringFlag{ Name: "command", Aliases: []string{"c"}, - Usage: "the command to invoke to start the worker process.", + Usage: "the command to invoke to start the worker process, or the WASM module path when --interface=wasm.", Category: "function", EnvVars: []string{"FUNCTION_COMMAND"}, - Required: true, }, &cli.StringFlag{ Name: "cwd", diff --git a/examples/demo-cpp-compare/README.md b/examples/demo-cpp-compare/README.md new file mode 100644 index 0000000..510b1aa --- /dev/null +++ b/examples/demo-cpp-compare/README.md @@ -0,0 +1,69 @@ +# C++ Compare Evaluator for Shimmy WASM + +This example is a minimal, self-contained C++ evaluator that compiles to a +WebAssembly module and runs through Shimmy's opt-in WASM backend. It is not a +port of an existing Lambda Feedback repository; it is a small Go/C++-style +artifact example for validating the generic WASM execution path. + +It intentionally mirrors the shape of a simple Lambda Feedback evaluator: + +- input: `response`, `answer`, and optional feedback strings in `params` +- output: `{ "command": "eval", "result": { "is_correct", "feedback" } }` + +The evaluator also reports `guest_invocation_count` and `snapshot_isolation_ok` +so the demo and integration test can prove that Shimmy reuses a warm WASM +instance while restoring guest memory after each request. 
+ +## Build + +The reference environment for this example is Linux, or a Linux container, with +Zig installed. The same command also works on macOS when Zig is installed; the +point is to rely on an explicit WASM-capable toolchain rather than the host's +default C++ compiler. + +```bash +zig c++ \ + -target wasm32-freestanding \ + -Oz \ + -nostdlib \ + -fno-exceptions \ + -fno-rtti \ + -Wl,--no-entry \ + -Wl,--export=alloc \ + -Wl,--export=evaluate \ + -Wl,--export-memory \ + -Wl,--initial-memory=2097152 \ + -o eval.wasm \ + evaluator.cpp +``` + +The source avoids libc/libc++ and implements only the small amount of JSON +handling needed for this evaluator, so it can be built as a small freestanding +WebAssembly module. Real C++ evaluators can use a richer build setup, but still +need to expose the same Shimmy WASM ABI: + +```text +memory +alloc(size: i32) -> i32 +evaluate(req_ptr: i32, req_len: i32) -> i32 +``` + +`evaluate` returns a pointer to `[uint32 little-endian response_len][response JSON bytes]`. + +## Run the end-to-end demo + +From the repository root: + +```bash +./scripts/demo-cpp-wasm.sh +``` + +The script builds Shimmy, compiles this evaluator to `eval.wasm`, starts Shimmy +with `FUNCTION_INTERFACE=wasm`, sends two HTTP requests, and asserts that both +requests see `guest_invocation_count == 1`. + +The Go test suite also compiles this example when `zig` is available: + +```bash +go test ./internal/execution/wasm -run TestCppCompareExample_CompilesAndRunsThroughDispatcher -v +``` diff --git a/examples/demo-cpp-compare/evaluator.cpp b/examples/demo-cpp-compare/evaluator.cpp new file mode 100644 index 0000000..64704c9 --- /dev/null +++ b/examples/demo-cpp-compare/evaluator.cpp @@ -0,0 +1,203 @@ +// A tiny C++ evaluation function that can be compiled directly to WebAssembly. +// +// It intentionally avoids libc/libc++ so the build is just a single freestanding +// C++ source file plus Zig's wasm-capable clang driver. 
The evaluator follows +// Shimmy's internal WASM ABI: +// - export memory +// - alloc(size) -> request pointer +// - evaluate(ptr, len) -> pointer to [u32 little-endian length][JSON bytes] +// +// The business logic is deliberately Lambda Feedback-shaped: compare the +// submitted response with the answer and return feedback from params. + +using u32 = unsigned int; +using i32 = int; +using usize = __SIZE_TYPE__; +using uintptr = __UINTPTR_TYPE__; + +alignas(16) static char request_buffer[256 * 1024]; +alignas(16) static char response_buffer[256 * 1024]; + +// Mutable guest state. Shimmy should restore the warm instance snapshot after +// every request, so this must be 1 for every HTTP call when FUNCTION_MAX_PROCS=1. +static u32 invocation_count = 0; + +static usize cstr_len(const char *s) { + usize n = 0; + while (s[n] != 0) n++; + return n; +} + +static bool bytes_equal(const char *a, usize a_len, const char *b, usize b_len) { + if (a_len != b_len) return false; + for (usize i = 0; i < a_len; i++) { + if (a[i] != b[i]) return false; + } + return true; +} + +static void copy_bytes(char *dst, usize &pos, const char *src, usize len) { + for (usize i = 0; i < len; i++) dst[pos++] = src[i]; +} + +static void append_cstr(char *dst, usize &pos, const char *src) { + copy_bytes(dst, pos, src, cstr_len(src)); +} + +static void append_u32(char *dst, usize &pos, u32 value) { + char tmp[10]; + usize n = 0; + do { + tmp[n++] = char('0' + (value % 10)); + value /= 10; + } while (value != 0); + while (n > 0) dst[pos++] = tmp[--n]; +} + +static void append_json_string(char *dst, usize &pos, const char *src, usize len) { + dst[pos++] = '"'; + for (usize i = 0; i < len; i++) { + char c = src[i]; + if (c == '"' || c == '\\') { + dst[pos++] = '\\'; + dst[pos++] = c; + } else if (c == '\n') { + dst[pos++] = '\\'; + dst[pos++] = 'n'; + } else { + dst[pos++] = c; + } + } + dst[pos++] = '"'; +} + +static bool match_at(const char *json, usize len, usize pos, const char *needle) { + for 
(usize i = 0; needle[i] != 0; i++) { + if (pos + i >= len || json[pos + i] != needle[i]) return false; + } + return true; +} + +static bool find_json_string(const char *json, usize len, const char *key, + const char *&value, usize &value_len) { + char quoted_key[96]; + usize key_pos = 0; + quoted_key[key_pos++] = '"'; + for (usize i = 0; key[i] != 0 && key_pos + 2 < sizeof(quoted_key); i++) { + quoted_key[key_pos++] = key[i]; + } + quoted_key[key_pos++] = '"'; + quoted_key[key_pos] = 0; + + for (usize i = 0; i < len; i++) { + if (!match_at(json, len, i, quoted_key)) continue; + i += key_pos; + while (i < len && (json[i] == ' ' || json[i] == '\n' || json[i] == '\r' || json[i] == '\t')) i++; + if (i >= len || json[i] != ':') continue; + i++; + while (i < len && (json[i] == ' ' || json[i] == '\n' || json[i] == '\r' || json[i] == '\t')) i++; + if (i >= len || json[i] != '"') continue; + i++; + + usize start = i; + while (i < len) { + if (json[i] == '\\') { + i += 2; + continue; + } + if (json[i] == '"') { + value = json + start; + value_len = i - start; + return true; + } + i++; + } + } + return false; +} + +static i32 write_error(const char *message) { + usize pos = 4; + append_cstr(response_buffer, pos, "{\"error\":{\"message\":"); + append_json_string(response_buffer, pos, message, cstr_len(message)); + append_cstr(response_buffer, pos, "}}" + ); + u32 len = u32(pos - 4); + response_buffer[0] = char(len & 0xff); + response_buffer[1] = char((len >> 8) & 0xff); + response_buffer[2] = char((len >> 16) & 0xff); + response_buffer[3] = char((len >> 24) & 0xff); + return i32(uintptr(response_buffer)); +} + +static i32 write_eval_response(bool is_correct, + const char *feedback, + usize feedback_len) { + usize pos = 4; + append_cstr(response_buffer, pos, "{\"command\":\"eval\",\"result\":{"); + append_cstr(response_buffer, pos, "\"is_correct\":"); + append_cstr(response_buffer, pos, is_correct ? 
"true" : "false"); + append_cstr(response_buffer, pos, ",\"feedback\":"); + append_json_string(response_buffer, pos, feedback, feedback_len); + append_cstr(response_buffer, pos, ",\"guest_invocation_count\":"); + append_u32(response_buffer, pos, invocation_count); + append_cstr(response_buffer, pos, ",\"snapshot_isolation_ok\":"); + append_cstr(response_buffer, pos, invocation_count == 1 ? "true" : "false"); + append_cstr(response_buffer, pos, "}}" + ); + + u32 len = u32(pos - 4); + response_buffer[0] = char(len & 0xff); + response_buffer[1] = char((len >> 8) & 0xff); + response_buffer[2] = char((len >> 16) & 0xff); + response_buffer[3] = char((len >> 24) & 0xff); + return i32(uintptr(response_buffer)); +} + +extern "C" i32 alloc(i32 size) { + if (size <= 0 || usize(size) > sizeof(request_buffer)) return 0; + return i32(uintptr(request_buffer)); +} + +extern "C" i32 evaluate(i32 req_ptr, i32 req_len) { + if (req_ptr == 0 || req_len <= 0) return write_error("empty request"); + + const char *json = reinterpret_cast<const char *>(uintptr(req_ptr)); // view the request JSON bytes at req_ptr + usize len = usize(req_len); + + const char *method = nullptr; + const char *response = nullptr; + const char *answer = nullptr; + const char *correct_feedback = nullptr; + const char *incorrect_feedback = nullptr; + usize method_len = 0; + usize response_len = 0; + usize answer_len = 0; + usize correct_feedback_len = 0; + usize incorrect_feedback_len = 0; + + if (!find_json_string(json, len, "method", method, method_len)) return write_error("missing method"); + if (!bytes_equal(method, method_len, "eval", 4)) return write_error("unsupported method"); + if (!find_json_string(json, len, "response", response, response_len)) return write_error("missing response"); + if (!find_json_string(json, len, "answer", answer, answer_len)) return write_error("missing answer"); + + bool has_correct_feedback = find_json_string(json, len, "correct_response_feedback", correct_feedback, correct_feedback_len); + 
bool has_incorrect_feedback = 
find_json_string(json, len, "incorrect_response_feedback", incorrect_feedback, incorrect_feedback_len); + + invocation_count++; + + bool is_correct = bytes_equal(response, response_len, answer, answer_len); + if (is_correct) { + if (!has_correct_feedback) { + correct_feedback = "Correct"; + correct_feedback_len = 7; + } + return write_eval_response(true, correct_feedback, correct_feedback_len); + } + + if (!has_incorrect_feedback) { + incorrect_feedback = "Incorrect"; + incorrect_feedback_len = 9; + } + return write_eval_response(false, incorrect_feedback, incorrect_feedback_len); +} diff --git a/examples/demo-cpp-package/Makefile b/examples/demo-cpp-package/Makefile new file mode 100644 index 0000000..12905e1 --- /dev/null +++ b/examples/demo-cpp-package/Makefile @@ -0,0 +1,23 @@ +OUT ?= eval.wasm +ZIG ?= zig + +.PHONY: wasm clean + +wasm: + $(ZIG) c++ \ + -target wasm32-freestanding \ + -Oz \ + -nostdlib \ + -fno-exceptions \ + -fno-rtti \ + -Iinclude \ + -Wl,--no-entry \ + -Wl,--export=alloc \ + -Wl,--export=evaluate \ + -Wl,--export-memory \ + -Wl,--initial-memory=2097152 \ + -o $(OUT) \ + src/evaluator.cpp src/compare.cpp + +clean: + rm -f $(OUT) diff --git a/examples/demo-cpp-package/README.md b/examples/demo-cpp-package/README.md new file mode 100644 index 0000000..f40d315 --- /dev/null +++ b/examples/demo-cpp-package/README.md @@ -0,0 +1,35 @@ +# C++ Package Evaluator for Shimmy WASM + +This example is intentionally package-shaped rather than a single source file: + +```text +Makefile +include/compare.hpp +src/evaluator.cpp +src/compare.cpp +``` + +It demonstrates the intended boundary for real evaluators: the package build +recipe emits an `eval.wasm` artifact, and Shimmy's WASM backend only loads that +pre-built module. + +## Build + +```bash +make wasm +``` + +The Makefile uses Zig's clang-compatible C++ driver to produce a freestanding +WebAssembly module. 
Override the output path with: + +```bash +make wasm OUT=/tmp/eval.wasm +``` + +## Test + +From the repository root: + +```bash +go test ./internal/execution/wasm -run TestCppPackageExample_CompilesAndRunsThroughDispatcher -v +``` diff --git a/examples/demo-cpp-package/include/compare.hpp b/examples/demo-cpp-package/include/compare.hpp new file mode 100644 index 0000000..1f60238 --- /dev/null +++ b/examples/demo-cpp-package/include/compare.hpp @@ -0,0 +1,11 @@ +#pragma once + +using usize = __SIZE_TYPE__; + +struct TextView { + const char *ptr; + usize len; +}; + +bool bytes_equal(TextView a, TextView b); +TextView feedback_for(bool is_correct, TextView correct_feedback, TextView incorrect_feedback); diff --git a/examples/demo-cpp-package/src/compare.cpp b/examples/demo-cpp-package/src/compare.cpp new file mode 100644 index 0000000..3869e63 --- /dev/null +++ b/examples/demo-cpp-package/src/compare.cpp @@ -0,0 +1,21 @@ +#include "compare.hpp" + +static const char kCorrect[] = "Correct!"; +static const char kIncorrect[] = "Try again."; + +bool bytes_equal(TextView a, TextView b) { + if (a.len != b.len) return false; + for (usize i = 0; i < a.len; i++) { + if (a.ptr[i] != b.ptr[i]) return false; + } + return true; +} + +TextView feedback_for(bool is_correct, TextView correct_feedback, TextView incorrect_feedback) { + if (is_correct) { + if (correct_feedback.ptr != nullptr) return correct_feedback; + return TextView{kCorrect, sizeof(kCorrect) - 1}; + } + if (incorrect_feedback.ptr != nullptr) return incorrect_feedback; + return TextView{kIncorrect, sizeof(kIncorrect) - 1}; +} diff --git a/examples/demo-cpp-package/src/evaluator.cpp b/examples/demo-cpp-package/src/evaluator.cpp new file mode 100644 index 0000000..956bdf5 --- /dev/null +++ b/examples/demo-cpp-package/src/evaluator.cpp @@ -0,0 +1,139 @@ +#include "compare.hpp" + +using u32 = unsigned int; +using i32 = int; +using uintptr = __UINTPTR_TYPE__; + +alignas(16) static char request_buffer[256 * 1024]; 
+alignas(16) static char response_buffer[256 * 1024]; +static u32 invocation_count = 0; + +static usize cstr_len(const char *s) { + usize n = 0; + while (s[n] != 0) n++; + return n; +} + +static void copy_bytes(char *dst, usize &pos, const char *src, usize len) { + for (usize i = 0; i < len; i++) dst[pos++] = src[i]; +} + +static void append_cstr(char *dst, usize &pos, const char *src) { + copy_bytes(dst, pos, src, cstr_len(src)); +} + +static void append_u32(char *dst, usize &pos, u32 value) { + char tmp[10]; + usize n = 0; + do { + tmp[n++] = char('0' + (value % 10)); + value /= 10; + } while (value != 0); + while (n > 0) dst[pos++] = tmp[--n]; +} + +static void append_json_string(char *dst, usize &pos, TextView src) { + dst[pos++] = '"'; + for (usize i = 0; i < src.len; i++) { + char c = src.ptr[i]; + if (c == '"' || c == '\\') { + dst[pos++] = '\\'; + dst[pos++] = c; + } else if (c == '\n') { + dst[pos++] = '\\'; + dst[pos++] = 'n'; + } else { + dst[pos++] = c; + } + } + dst[pos++] = '"'; +} + +static bool match_at(const char *json, usize len, usize pos, const char *needle) { + for (usize i = 0; needle[i] != 0; i++) { + if (pos + i >= len || json[pos + i] != needle[i]) return false; + } + return true; +} + +static bool find_json_string(const char *json, usize len, const char *key, TextView &value) { + char quoted_key[96]; + usize key_pos = 0; + quoted_key[key_pos++] = '"'; + for (usize i = 0; key[i] != 0 && key_pos + 2 < sizeof(quoted_key); i++) { + quoted_key[key_pos++] = key[i]; + } + quoted_key[key_pos++] = '"'; + quoted_key[key_pos] = 0; + + for (usize i = 0; i < len; i++) { + if (!match_at(json, len, i, quoted_key)) continue; + i += key_pos; + while (i < len && (json[i] == ' ' || json[i] == '\n' || json[i] == '\r' || json[i] == '\t')) i++; + if (i >= len || json[i] != ':') continue; + i++; + while (i < len && (json[i] == ' ' || json[i] == '\n' || json[i] == '\r' || json[i] == '\t')) i++; + if (i >= len || json[i] != '"') continue; + i++; + + usize start = 
i; + while (i < len) { + if (json[i] == '\\') { + i += 2; + continue; + } + if (json[i] == '"') { + value = TextView{json + start, i - start}; + return true; + } + i++; + } + } + return false; +} + +static i32 write_eval_response(bool is_correct, TextView feedback) { + usize pos = 4; + append_cstr(response_buffer, pos, "{\"command\":\"eval\",\"result\":{"); + append_cstr(response_buffer, pos, "\"is_correct\":"); + append_cstr(response_buffer, pos, is_correct ? "true" : "false"); + append_cstr(response_buffer, pos, ",\"feedback\":"); + append_json_string(response_buffer, pos, feedback); + append_cstr(response_buffer, pos, ",\"guest_invocation_count\":"); + append_u32(response_buffer, pos, invocation_count); + append_cstr(response_buffer, pos, ",\"snapshot_isolation_ok\":"); + append_cstr(response_buffer, pos, invocation_count == 1 ? "true" : "false"); + append_cstr(response_buffer, pos, "}}"); + + u32 len = u32(pos - 4); + response_buffer[0] = char(len & 0xff); + response_buffer[1] = char((len >> 8) & 0xff); + response_buffer[2] = char((len >> 16) & 0xff); + response_buffer[3] = char((len >> 24) & 0xff); + return i32(uintptr(response_buffer)); +} + +extern "C" i32 alloc(i32 size) { + if (size <= 0 || usize(size) > sizeof(request_buffer)) return 0; + return i32(uintptr(request_buffer)); +} + +extern "C" i32 evaluate(i32 req_ptr, i32 req_len) { + if (req_ptr == 0 || req_len <= 0) return 0; + const char *json = reinterpret_cast<const char *>(uintptr(req_ptr)); // view the request JSON bytes at req_ptr + usize len = usize(req_len); + + TextView response{nullptr, 0}; + TextView answer{nullptr, 0}; + TextView correct_feedback{nullptr, 0}; + TextView incorrect_feedback{nullptr, 0}; + + if (!find_json_string(json, len, "response", response)) return 0; + if (!find_json_string(json, len, "answer", answer)) return 0; + find_json_string(json, len, "correct_response_feedback", correct_feedback); + find_json_string(json, len, "incorrect_response_feedback", incorrect_feedback); + + 
invocation_count++; + bool is_correct = 
bytes_equal(response, answer); + return write_eval_response(is_correct, feedback_for(is_correct, correct_feedback, incorrect_feedback)); +} diff --git a/examples/demo-go-package/README.md b/examples/demo-go-package/README.md new file mode 100644 index 0000000..b3710c4 --- /dev/null +++ b/examples/demo-go-package/README.md @@ -0,0 +1,27 @@ +# Go Package Evaluator for Shimmy WASM + +This example is intentionally package-shaped rather than a single source file: + +```text +go.mod +cmd/evaluator/main.go +internal/compare/compare.go +``` + +It demonstrates the intended boundary for real evaluators: the package build +recipe emits an `eval.wasm` artifact, and Shimmy's WASM backend only loads that +pre-built module. + +## Build + +```bash +GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o eval.wasm ./cmd/evaluator +``` + +## Test + +From the repository root: + +```bash +go test ./internal/execution/wasm -run TestGoPackageExample_CompilesAndRunsThroughDispatcher -v +``` diff --git a/examples/demo-go-package/cmd/evaluator/main.go b/examples/demo-go-package/cmd/evaluator/main.go new file mode 100644 index 0000000..143259b --- /dev/null +++ b/examples/demo-go-package/cmd/evaluator/main.go @@ -0,0 +1,77 @@ +//go:build wasip1 + +package main + +import ( + "encoding/binary" + "encoding/json" + "unsafe" + + "demo-go-package/internal/compare" +) + +var reqBuf [256 * 1024]byte +var respBuf [256 * 1024]byte +var invocationCount uint32 + +//go:wasmexport alloc +func alloc(size int32) int32 { + // Reject sizes that do not fit the fixed request buffer (same contract as the C++ examples). + if size <= 0 || int(size) > len(reqBuf) { + return 0 + } + return int32(uintptr(unsafe.Pointer(&reqBuf[0]))) +} + +//go:wasmexport evaluate +func evaluate(reqPtr int32, reqLen int32) int32 { + _ = reqPtr + // An out-of-range length would make the reqBuf slice below panic. 
+ if reqLen <= 0 || int(reqLen) > len(reqBuf) { + writeResp(map[string]any{"error": map[string]any{"message": "invalid request length"}}) + return int32(uintptr(unsafe.Pointer(&respBuf[0]))) + } + var req struct { + Method string `json:"method"` + Params struct { + Response string `json:"response"` + Answer string `json:"answer"` + Params map[string]any `json:"params"` + } `json:"params"` + } + if err := json.Unmarshal(reqBuf[:reqLen], &req); err != nil { + writeResp(map[string]any{"error": map[string]any{"message": 
err.Error()}}) + return int32(uintptr(unsafe.Pointer(&respBuf[0]))) + } + if req.Method != "eval" { + writeResp(map[string]any{"error": map[string]any{"message": "unsupported method"}}) + return int32(uintptr(unsafe.Pointer(&respBuf[0]))) + } + + invocationCount++ + isCorrect := compare.IsCorrect(req.Params.Response, req.Params.Answer) + correctFeedback, _ := req.Params.Params["correct_response_feedback"].(string) + incorrectFeedback, _ := req.Params.Params["incorrect_response_feedback"].(string) + feedback := compare.Feedback(isCorrect, correctFeedback, incorrectFeedback) + writeResp(map[string]any{ + "command": "eval", + "result": map[string]any{ + "is_correct": isCorrect, + "feedback": feedback, + "guest_invocation_count": invocationCount, + "snapshot_isolation_ok": invocationCount == 1, + }, + }) + return int32(uintptr(unsafe.Pointer(&respBuf[0]))) +} + +func writeResp(v map[string]any) { + data, err := json.Marshal(v) + if err != nil || len(data) > len(respBuf)-4 { // payload must fit after the 4-byte length prefix + data = []byte(`{"error":{"message":"marshal failed or response too large"}}`) + } + binary.LittleEndian.PutUint32(respBuf[:4], uint32(len(data))) + copy(respBuf[4:], data) +} + +func main() {} diff --git a/examples/demo-go-package/go.mod b/examples/demo-go-package/go.mod new file mode 100644 index 0000000..f04ed4f --- /dev/null +++ b/examples/demo-go-package/go.mod @@ -0,0 +1,3 @@ +module demo-go-package + +go 1.24 diff --git a/examples/demo-go-package/internal/compare/compare.go b/examples/demo-go-package/internal/compare/compare.go new file mode 100644 index 0000000..8978f98 --- /dev/null +++ b/examples/demo-go-package/internal/compare/compare.go @@ -0,0 +1,18 @@ +package compare + +func IsCorrect(response, answer string) bool { + return response == answer +} + +func Feedback(isCorrect bool, correctFeedback, incorrectFeedback string) string { + if isCorrect { + if correctFeedback != "" { + return correctFeedback + } + return "Correct" + } + if 
incorrectFeedback != "" { + return incorrectFeedback + } + return "Incorrect" +} diff --git 
a/examples/demo-rust-compare/README.md b/examples/demo-rust-compare/README.md new file mode 100644 index 0000000..46d9c9c --- /dev/null +++ b/examples/demo-rust-compare/README.md @@ -0,0 +1,55 @@ +# Rust Compare Evaluator for Shimmy WASM + +This example is a minimal, self-contained Rust evaluator that compiles to a +WebAssembly module and runs through Shimmy's opt-in WASM backend. It is not a +port of an existing Lambda Feedback repository; it is a small Rust artifact +example for validating the generic WASM execution path when a Rust WASM target +is installed. + +It intentionally mirrors the shape of a simple Lambda Feedback evaluator: + +- input: `response`, `answer`, and optional feedback strings in `params` +- output: `{ "command": "eval", "result": { "is_correct", "feedback" } }` + +The evaluator also reports `guest_invocation_count` and `snapshot_isolation_ok` +so the integration test can prove that Shimmy reuses a warm WASM instance while +restoring guest memory after each request. + +## Build + +The reference environment is Linux, or a Linux container, with Rust installed and +the `wasm32-unknown-unknown` target available: + +```bash +cd examples/demo-rust-compare +rustup target add wasm32-unknown-unknown +rustc \ + --target wasm32-unknown-unknown \ + --crate-type cdylib \ + -C panic=abort \ + -O \ + -o eval.wasm \ + evaluator.rs +``` + +The source is `#![no_std]` and implements only the small amount of JSON handling +needed for this evaluator. Real Rust evaluators can use richer build setup, but +still need to expose the same Shimmy WASM ABI: + +```text +memory +alloc(size: i32) -> i32 +evaluate(req_ptr: i32, req_len: i32) -> i32 +``` + +`evaluate` returns a pointer to `[uint32 little-endian response_len][response JSON bytes]`. 
+ +## Test + +From the repository root: + +```bash +go test ./internal/execution/wasm -run TestRustCompareExample_CompilesAndRunsThroughDispatcher -v +``` + +The test skips when `rustc` or the `wasm32-unknown-unknown` target is unavailable. diff --git a/examples/demo-rust-compare/evaluator.rs b/examples/demo-rust-compare/evaluator.rs new file mode 100644 index 0000000..1ef5d6e --- /dev/null +++ b/examples/demo-rust-compare/evaluator.rs @@ -0,0 +1,178 @@ +#![no_std] +#![no_main] + +use core::panic::PanicInfo; + +static mut REQ_BUF: [u8; 256 * 1024] = [0; 256 * 1024]; +static mut RESP_BUF: [u8; 256 * 1024] = [0; 256 * 1024]; + +// Deliberately mutable guest state. Shimmy snapshots this after startup and +// restores it after every request, so each call should observe count == 1. +static mut INVOCATION_COUNT: u32 = 0; + +#[panic_handler] +fn panic(_info: &PanicInfo) -> ! { + loop {} +} + +#[no_mangle] +pub extern "C" fn alloc(_size: i32) -> i32 { + core::ptr::addr_of_mut!(REQ_BUF) as *mut u8 as i32 +} + +#[no_mangle] +pub extern "C" fn evaluate(_req_ptr: i32, req_len: i32) -> i32 { + unsafe { + INVOCATION_COUNT += 1; + } + + let req = unsafe { + // Clamp the host-supplied length so the slice never extends past REQ_BUF. 
+ core::slice::from_raw_parts(core::ptr::addr_of!(REQ_BUF) as *const u8, (req_len.max(0) as usize).min(256 * 1024)) + }; + + let response = json_string_field(req, b"response"); + let answer = json_string_field(req, b"answer"); + let correct_feedback = + json_string_field(req, b"correct_response_feedback").unwrap_or(b"Correct!"); + let incorrect_feedback = + json_string_field(req, b"incorrect_response_feedback").unwrap_or(b"Try again."); + + let is_correct = response.is_some() && answer.is_some() && response == answer; + let feedback = if is_correct { + correct_feedback + } else { + incorrect_feedback + }; + let count = unsafe { INVOCATION_COUNT }; + + write_response(is_correct, feedback, count) +} + +fn json_string_field<'a>(input: &'a [u8], name: &[u8]) -> Option<&'a [u8]> { + let key = find_key(input, name)?; + let mut i = key + name.len() + 2; // 
/// Find the first occurrence of `name` wrapped in double quotes inside `input`.
///
/// Returns the index of the opening quote, or `None` when `input` is too
/// short or contains no such quoted sequence. This is a plain byte search,
/// not a JSON parser: any quoted string equal to `name` matches, whether it
/// is an object key or a value.
fn find_key(input: &[u8], name: &[u8]) -> Option<usize> {
    // Opening quote + name + closing quote.
    let quoted_len = name.len() + 2;
    if input.len() < quoted_len {
        return None;
    }
    input.windows(quoted_len).position(|window| {
        window[0] == b'"'
            && window[quoted_len - 1] == b'"'
            && &window[1..quoted_len - 1] == name
    })
}
+ } + while len > 0 { + len -= 1; + self.push(digits[len]); + } + } + + fn push(&mut self, b: u8) { + unsafe { + let ptr = core::ptr::addr_of_mut!(RESP_BUF) as *mut u8; + *ptr.add(self.pos) = b; + } + self.pos += 1; + } + + fn finish(self) -> i32 { + let len = (self.pos - 4) as u32; + unsafe { + let ptr = core::ptr::addr_of_mut!(RESP_BUF) as *mut u8; + *ptr.add(0) = (len & 0xff) as u8; + *ptr.add(1) = ((len >> 8) & 0xff) as u8; + *ptr.add(2) = ((len >> 16) & 0xff) as u8; + *ptr.add(3) = ((len >> 24) & 0xff) as u8; + ptr as i32 + } + } +} diff --git a/examples/demo-rust-package/Cargo.lock b/examples/demo-rust-package/Cargo.lock new file mode 100644 index 0000000..70fcca9 --- /dev/null +++ b/examples/demo-rust-package/Cargo.lock @@ -0,0 +1,7 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "demo-rust-package" +version = "0.1.0" diff --git a/examples/demo-rust-package/Cargo.toml b/examples/demo-rust-package/Cargo.toml new file mode 100644 index 0000000..c413091 --- /dev/null +++ b/examples/demo-rust-package/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "demo-rust-package" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["cdylib"] + +[profile.release] +panic = "abort" +lto = true +opt-level = "z" diff --git a/examples/demo-rust-package/README.md b/examples/demo-rust-package/README.md new file mode 100644 index 0000000..3af2594 --- /dev/null +++ b/examples/demo-rust-package/README.md @@ -0,0 +1,35 @@ +# Rust Package Evaluator for Shimmy WASM + +This example is intentionally crate-shaped rather than a single source file: + +```text +Cargo.toml +src/lib.rs +src/compare.rs +``` + +It demonstrates the intended boundary for real evaluators: the crate build +recipe emits a `.wasm` artifact, and Shimmy's WASM backend only loads that +pre-built module. 
/// Report whether both `response` and `answer` are present and byte-equal.
///
/// A missing value on either side is never correct, so two absent fields do
/// not compare as a match.
pub fn is_correct(response: Option<&[u8]>, answer: Option<&[u8]>) -> bool {
    match (response, answer) {
        (Some(given), Some(expected)) => given == expected,
        _ => false,
    }
}

/// Choose the feedback text for the outcome.
///
/// Uses the caller-supplied message for the outcome when there is one and
/// falls back to the built-in default otherwise.
pub fn feedback<'a>(
    is_correct: bool,
    correct_feedback: Option<&'a [u8]>,
    incorrect_feedback: Option<&'a [u8]>,
) -> &'a [u8] {
    let default_text: &'static [u8] = if is_correct {
        b"Correct!"
    } else {
        b"Try again."
    };
    let supplied = if is_correct {
        correct_feedback
    } else {
        incorrect_feedback
    };
    match supplied {
        Some(text) => text,
        None => default_text,
    }
}
/// Find the first occurrence of `name` wrapped in double quotes inside `input`.
///
/// Returns the index of the opening quote, or `None` when `input` is too
/// short or contains no such quoted sequence. This is a plain byte search,
/// not a JSON parser: any quoted string equal to `name` matches, whether it
/// is an object key or a value.
fn find_key(input: &[u8], name: &[u8]) -> Option<usize> {
    // Opening quote + name + closing quote.
    let quoted_len = name.len() + 2;
    if input.len() < quoted_len {
        return None;
    }
    input.windows(quoted_len).position(|window| {
        window[0] == b'"'
            && window[quoted_len - 1] == b'"'
            && &window[1..quoted_len - 1] == name
    })
}
count: u32) -> i32 { + let mut w = Writer::new(); + w.bytes(b"{\"command\":\"eval\",\"result\":{\"is_correct\":"); + w.bytes(if is_correct { b"true" } else { b"false" }); + w.bytes(b",\"feedback\":\""); + w.json_string_bytes(feedback); + w.bytes(b"\",\"guest_invocation_count\":"); + w.u32(count); + w.bytes(b",\"snapshot_isolation_ok\":"); + w.bytes(if count == 1 { b"true" } else { b"false" }); + w.bytes(b"}}"); + w.finish() +} + +struct Writer { + pos: usize, +} + +impl Writer { + fn new() -> Self { + Self { pos: 4 } + } + + fn bytes(&mut self, bytes: &[u8]) { + for &b in bytes { + self.push(b); + } + } + + fn json_string_bytes(&mut self, bytes: &[u8]) { + for &b in bytes { + match b { + b'"' => self.bytes(b"\\\""), + b'\\' => self.bytes(b"\\\\"), + _ => self.push(b), + } + } + } + + fn u32(&mut self, mut n: u32) { + if n == 0 { + self.push(b'0'); + return; + } + let mut digits = [0u8; 10]; + let mut len = 0; + while n > 0 { + digits[len] = b'0' + (n % 10) as u8; + n /= 10; + len += 1; + } + while len > 0 { + len -= 1; + self.push(digits[len]); + } + } + + fn push(&mut self, b: u8) { + unsafe { + let ptr = core::ptr::addr_of_mut!(RESP_BUF) as *mut u8; + *ptr.add(self.pos) = b; + } + self.pos += 1; + } + + fn finish(self) -> i32 { + let len = (self.pos - 4) as u32; + unsafe { + let ptr = core::ptr::addr_of_mut!(RESP_BUF) as *mut u8; + *ptr.add(0) = (len & 0xff) as u8; + *ptr.add(1) = ((len >> 8) & 0xff) as u8; + *ptr.add(2) = ((len >> 16) & 0xff) as u8; + *ptr.add(3) = ((len >> 24) & 0xff) as u8; + ptr as i32 + } + } +} diff --git a/examples/demo-stateful/README.md b/examples/demo-stateful/README.md new file mode 100644 index 0000000..b38ef5a --- /dev/null +++ b/examples/demo-stateful/README.md @@ -0,0 +1,32 @@ +# Demo: stateful WASM evaluator + +This is a deliberately tiny Shimmy-WASM evaluator for live demos. + +It mutates a module-global `invocationCount` on every `eval` call. 
In a normal +warm worker, that state would leak across requests (`1`, then `2`, then `3`). +Shimmy-WASM snapshots the module memory after startup and restores it after each +request, so every request observes `guest_invocation_count: 1`. + +Use it via the one-command demo runner: + +```bash +scripts/demo-wasm.sh +``` + +The Go test suite also compiles this example when the local Go toolchain +supports `GOOS=wasip1` and `//go:wasmexport`: + +```bash +go test ./internal/execution/wasm -run TestGoStatefulExample_CompilesAndRunsThroughDispatcher -v +``` + +What the demo shows: + +1. Build Shimmy. +2. Compile this evaluator to `wasm32-wasip1`. +3. Start `shimmy serve` with `FUNCTION_INTERFACE=wasm`. +4. Send two HTTP grading requests. +5. Assert both responses report `guest_invocation_count == 1`. + +That is the visible end-to-end proof: HTTP request → Shimmy → wazero WASM guest +→ response validation → HTTP response, with per-request state reset. diff --git a/examples/demo-stateful/main.go b/examples/demo-stateful/main.go new file mode 100644 index 0000000..80125a4 --- /dev/null +++ b/examples/demo-stateful/main.go @@ -0,0 +1,95 @@ +//go:build wasip1 + +// demo-stateful is a tiny Shimmy-WASM evaluation function for live demos. +// It intentionally mutates module-global state on every request. The host +// should snapshot/restore WASM memory after each call, so the counter reported +// to the next request should still be 1 rather than leaking as 2, 3, ... +package main + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "unsafe" +) + +var reqBuf [256 * 1024]byte +var respBuf [256 * 1024]byte + +// This is deliberately mutable guest state. A non-isolated warm worker would +// leak it across requests; Shimmy-WASM restores the memory snapshot instead. 
+var invocationCount uint32 +var lastResponse [64]byte + +//go:wasmexport alloc +func alloc(size int32) int32 { + _ = size + return int32(uintptr(unsafe.Pointer(&reqBuf[0]))) +} + +//go:wasmexport evaluate +func evaluate(reqPtr int32, reqLen int32) int32 { + _ = reqPtr + + var req struct { + Method string `json:"method"` + Params struct { + Response string `json:"response"` + Answer string `json:"answer"` + Params map[string]any `json:"params"` + } `json:"params"` + } + + if err := json.Unmarshal(reqBuf[:reqLen], &req); err != nil { + writeResp(map[string]any{"error": map[string]any{"message": err.Error()}}) + return int32(uintptr(unsafe.Pointer(&respBuf[0]))) + } + + invocationCount++ + copy(lastResponse[:], req.Params.Response) + + switch req.Method { + case "eval": + correct := req.Params.Response == req.Params.Answer + feedback := "Correct — and the guest counter is still 1, so snapshot/restore worked." + if !correct { + feedback = fmt.Sprintf("Incorrect: got %q, expected %q. Guest counter is still %d.", req.Params.Response, req.Params.Answer, invocationCount) + } + writeResp(map[string]any{ + "command": "eval", + "result": map[string]any{ + "is_correct": correct, + "feedback": feedback, + "guest_invocation_count": invocationCount, + "snapshot_isolation_ok": invocationCount == 1, + }, + }) + case "preview": + writeResp(map[string]any{ + "command": "preview", + "result": map[string]any{ + "preview": map[string]any{"type": "text", "content": req.Params.Response}, + }, + }) + case "healthcheck": + writeResp(map[string]any{ + "command": "healthcheck", + "result": map[string]any{"status": "ok"}, + }) + default: + writeResp(map[string]any{"error": map[string]any{"message": "unknown method: " + req.Method}}) + } + + return int32(uintptr(unsafe.Pointer(&respBuf[0]))) +} + +func writeResp(v map[string]any) { + data, err := json.Marshal(v) + if err != nil { + data = []byte(`{"error":{"message":"marshal failed"}}`) + } + binary.LittleEndian.PutUint32(respBuf[:4], 
@contextlib.contextmanager
def evaluator_context(
    root: str | Path,
    *,
    env: dict[str, str] | None = None,
) -> Iterator[Path]:
    """Yield a throwaway working directory and undo process-level changes on exit.

    While the context is active the process cwd is a fresh ``lf-eval-*``
    temporary directory and ``env`` (if given) is merged into ``os.environ``.
    On exit the previous cwd, the previous contents of ``sys.path`` and the
    previous environment are put back, and the temporary directory is removed.

    This is lifecycle hygiene for in-process evaluators, not a security
    sandbox. ``root`` is accepted for interface symmetry with the other
    helpers but is not used here.
    """

    saved_cwd = Path.cwd()
    saved_sys_path = sys.path.copy()
    saved_environ = os.environ.copy()

    scratch = tempfile.TemporaryDirectory(prefix="lf-eval-")
    with scratch as workdir:
        try:
            os.environ.update(env or {})
            os.chdir(workdir)
            yield Path(workdir)
        finally:
            # Leave the scratch directory before TemporaryDirectory deletes it.
            os.chdir(saved_cwd)
            # Slice-assign so every holder of the sys.path list sees the reset.
            sys.path[:] = saved_sys_path
            os.environ.clear()
            os.environ.update(saved_environ)
+ """ + + if ":" not in entrypoint: + raise EntrypointError(f"entrypoint must be 'module:function', got {entrypoint!r}") + module_name, func_name = entrypoint.split(":", 1) + if not module_name or not func_name: + raise EntrypointError(f"entrypoint must be 'module:function', got {entrypoint!r}") + + root_path = str(Path(root).resolve()) + if root_path in sys.path: + sys.path.remove(root_path) + sys.path.insert(0, root_path) + + _evict_package_modules(module_name, root_path) + module = importlib.import_module(module_name) + try: + fn = getattr(module, func_name) + except AttributeError as exc: + raise EntrypointError(f"entrypoint function {func_name!r} not found in {module_name!r}") from exc + if not callable(fn): + raise EntrypointError(f"entrypoint {entrypoint!r} is not callable") + return fn + + +def call_function( + fn: Callable[..., Any], + *, + method: str, + response: Any, + answer: Any = None, + params: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Invoke an LF evaluator function and normalize its result. + + Eval functions normally accept ``(response, answer, params)``. Some preview + functions in existing repositories accept ``(response, params)`` instead; + this helper detects that shape and preserves compatibility. 
def normalize_result(value: Any) -> Any:
    """Convert common LF/toolkit/Python result shapes to JSON-compatible data.

    Conversion rules, applied in order:

    * ``None`` and ``str``/``int``/``float``/``bool`` are returned unchanged.
    * Lists and tuples become lists of normalized items.
    * Sets and frozensets become lists, sorted when the items are mutually
      comparable so the output does not depend on hash ordering.
    * Dicts get string keys and normalized values; entries whose value is
      ``None`` are dropped.
    * Dataclass *instances* go through ``dataclasses.asdict``.
    * Objects exposing a callable ``model_dump()`` or ``dict()`` are
      converted through it.
    * Objects exposing a callable ``item()`` are unwrapped through it.
    * Any other object is reduced to a dict of its public, non-callable
      attributes. If it has none, it is returned unchanged.
    """

    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, (list, tuple)):
        return [normalize_result(item) for item in value]
    if isinstance(value, (set, frozenset)):
        items = [normalize_result(item) for item in value]
        try:
            return sorted(items)
        except TypeError:
            # Mixed or unorderable items: keep iteration order.
            return items
    if isinstance(value, dict):
        return {str(k): normalize_result(v) for k, v in value.items() if v is not None}
    # is_dataclass() is also true for dataclass *classes*, which asdict() rejects.
    if dataclasses.is_dataclass(value) and not isinstance(value, type):
        return normalize_result(dataclasses.asdict(value))
    for exporter_name in ("model_dump", "dict"):
        exporter = getattr(value, exporter_name, None)
        if callable(exporter):
            return normalize_result(exporter())
    item = getattr(value, "item", None)
    if callable(item):
        try:
            return normalize_result(item())
        except Exception:
            # Best effort: fall through to attribute inspection below.
            pass

    public = {}
    for name in dir(value):
        if name.startswith("_"):
            continue
        try:
            attr = getattr(value, name)
        except Exception:
            # A property that raises must not abort normalization of the rest.
            continue
        if not callable(attr):
            public[name] = attr
    if public:
        return normalize_result(public)
    return value
def _evict_package_modules(module_name: str, root_path: str) -> None:
    """Drop cached modules of ``module_name``'s top-level package from ``sys.modules``.

    A module is evicted when it belongs to that top-level package, has a
    ``__file__``, and either:

    * the package is ``evaluation_function`` (fixtures frequently reuse this
      name, so any prior copy is evicted to keep relative imports from
      resolving to a stale fixture), or
    * its resolved file path is ``root_path`` itself or lies inside it.

    ``root_path`` is compared as given, so callers should pass a resolved
    path. The containment test uses a separator-terminated prefix so that a
    sibling directory such as ``/x/root-other`` is not mistaken for a child
    of ``/x/root``.
    """

    package = module_name.split(".", 1)[0]
    package_prefix = package + "."
    root_prefix = root_path if root_path.endswith(os.sep) else root_path + os.sep

    # Snapshot the items: entries are deleted while scanning.
    for name, module in list(sys.modules.items()):
        if name != package and not name.startswith(package_prefix):
            continue
        file = getattr(module, "__file__", None)
        if not file:
            continue
        if package == "evaluation_function":
            del sys.modules[name]
            continue
        resolved = str(Path(file).resolve())
        if resolved == root_path or resolved.startswith(root_prefix):
            del sys.modules[name]
def handle_message(message: dict[str, Any]) -> dict[str, Any]:
    """Run one Shimmy file-adapter message and return the response envelope.

    ``message`` is the decoded request JSON: ``command`` selects the method
    (``"eval"`` when missing or empty) and ``params`` carries the request
    body. On success the result is ``{"command": ..., "result": ...}``.

    Every failure is returned as an ``{"error": ...}`` envelope rather than
    raised. This covers a message or ``params`` that is not a JSON object,
    missing configuration, and any exception from the evaluator.
    """

    # json.load() can yield a list, string or number; .get() would raise on those.
    if not isinstance(message, dict):
        return error_response("request message must be an object")

    command = str(message.get("command") or "eval")
    params = message.get("params")
    if not isinstance(params, dict):
        return error_response("request params must be an object")

    try:
        root, entrypoint, evaluator_params = resolve_config(command, params)
        call = run_entrypoint(
            root,
            entrypoint,
            method=command,
            response=params.get("response"),
            answer=params.get("answer"),
            params=evaluator_params,
        )
        return {"command": command, "result": call.result}
    except Exception as exc:  # Keep worker failures schema-compatible for Shimmy.
        return error_response(str(exc), exc.__class__.__name__)
class Server:
    """Minimal stand-in that only records the registered handler functions."""

    def __init__(self):
        # Nothing is registered until eval()/preview() are called.
        self.eval_function = self.preview_function = None

    def eval(self, fn):
        """Store ``fn`` as the eval handler and hand it back (decorator-friendly)."""
        self.eval_function = fn
        return self.eval_function

    def preview(self, fn):
        """Store ``fn`` as the preview handler and hand it back (decorator-friendly)."""
        self.preview_function = fn
        return self.preview_function
def main() -> int:
    """Parse CLI arguments, run the evaluator once, and print its result as JSON.

    Returns 0 on success. Exits through ``SystemExit`` with a message when
    ``--params-json`` is not valid JSON or does not decode to an object.
    """
    parser = argparse.ArgumentParser(description="Run a Lambda Feedback evaluator fixture locally")
    parser.add_argument("--root", required=True, help="Evaluator package root")
    parser.add_argument("--entrypoint", required=True, help="module:function entrypoint")
    parser.add_argument("--method", choices=["eval", "preview"], default="eval")
    parser.add_argument("--response", default="")
    parser.add_argument("--answer", default=None)
    parser.add_argument("--params-json", default="{}")
    args = parser.parse_args()

    try:
        params = json.loads(args.params_json)
    except json.JSONDecodeError as exc:
        # Report malformed input like the other argument errors, not as a traceback.
        raise SystemExit(f"--params-json is not valid JSON: {exc}") from exc
    if not isinstance(params, dict):
        raise SystemExit("--params-json must decode to a JSON object")

    fn = load_entrypoint(args.root, args.entrypoint)
    result = call_function(
        fn,
        method=args.method,
        response=args.response,
        answer=args.answer,
        params=params,
    )
    # sort_keys keeps the printed JSON stable across runs.
    print(json.dumps(result, sort_keys=True))
    return 0
"optional-two-arg"}, + ) + + self.assertEqual(result, {"preview": {"response": "draft", "mode": "optional-two-arg"}}) + + def test_preview_keeps_three_argument_signature_with_default_answer(self): + def preview_function(response, answer=None, params=None): + return {"preview": {"response": response, "answer": answer, "mode": (params or {}).get("mode")}} + + result = call_function( + preview_function, + method="preview", + response="draft", + answer="expected", + params={"mode": "three-arg"}, + ) + + self.assertEqual(result, {"preview": {"response": "draft", "answer": "expected", "mode": "three-arg"}}) + + def test_eval_supports_relative_import_fixture(self): + fn = load_entrypoint( + FIXTURE_DIR / "relative-preview", + "evaluation_function.evaluation:evaluation_function", + ) + + result = call_function( + fn, + method="eval", + response="a, b", + answer="b,a", + params={"mode": "set"}, + ) + + self.assertEqual(result["is_correct"], True) + self.assertEqual(result["feedback"], "matched") + + def test_normalize_result_handles_common_object_shapes_and_scalar_items(self): + @dataclass + class DataResult: + is_correct: bool + score: object + + class Scalar: + def item(self): + return 7 + + self.assertEqual( + normalize_result(DataResult(is_correct=True, score=Scalar())), + {"is_correct": True, "score": 7}, + ) + + def test_run_entrypoint_uses_temporary_workspace_and_restores_process_state(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) / "evaluator" + package = root / "evaluation_function" + package.mkdir(parents=True) + (package / "__init__.py").write_text("", encoding="utf-8") + (package / "main.py").write_text( + """ +import os +from pathlib import Path + + +def evaluation_function(response, answer, params): + print("hello from evaluator") + print("warn from evaluator", file=__import__("sys").stderr) + cwd = Path.cwd() + (cwd / "scratch.txt").write_text("temporary", encoding="utf-8") + os.environ["LF_CONTEXT_MUTATION"] = "dirty" + return { + 
"cwd_is_request_workspace": cwd.name.startswith("lf-eval-"), + "mode": params.get("mode"), + "response": response, + } +""", + encoding="utf-8", + ) + before_cwd = Path.cwd() + before_sys_path = list(sys.path) + old_env = os.environ.get("LF_CONTEXT_MUTATION") + + call = run_entrypoint( + root, + "evaluation_function.main:evaluation_function", + method="eval", + response="draft", + answer="expected", + params={"mode": "hygiene"}, + ) + + self.assertEqual( + call.result, + { + "cwd_is_request_workspace": True, + "mode": "hygiene", + "response": "draft", + }, + ) + self.assertEqual(call.stdout.strip(), "hello from evaluator") + self.assertEqual(call.stderr.strip(), "warn from evaluator") + self.assertEqual(Path.cwd(), before_cwd) + self.assertEqual(sys.path, before_sys_path) + self.assertFalse(Path(call.workdir).exists()) + self.assertFalse((root / "scratch.txt").exists()) + self.assertEqual(os.environ.get("LF_CONTEXT_MUTATION"), old_env) + + def test_cli_runner_outputs_normalized_json(self): + runner = ADAPTER_DIR / "run_lf_eval.py" + completed = subprocess.run( + [ + sys.executable, + str(runner), + "--root", + str(FIXTURE_DIR / "boilerplate-python"), + "--entrypoint", + "evaluation_function.main:evaluation_function", + "--method", + "eval", + "--response", + "no", + "--answer", + "yes", + "--params-json", + "{}", + ], + check=True, + text=True, + stdout=subprocess.PIPE, + ) + + self.assertEqual( + json.loads(completed.stdout), + {"is_correct": False, "feedback": "Try again"}, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/lambda-feedback-adapter/test_lf_file_worker.py b/examples/lambda-feedback-adapter/test_lf_file_worker.py new file mode 100644 index 0000000..2360b38 --- /dev/null +++ b/examples/lambda-feedback-adapter/test_lf_file_worker.py @@ -0,0 +1,203 @@ +import json +import os +import subprocess +import sys +import tempfile +import unittest +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[1] +ADAPTER_DIR = 
Path(__file__).resolve().parent +if str(ADAPTER_DIR) not in sys.path: + sys.path.insert(0, str(ADAPTER_DIR)) + +import lf_file_worker + +BOILERPLATE_ROOT = ROOT / "lambda-feedback-fixtures" / "boilerplate-python" +RELATIVE_ROOT = ROOT / "lambda-feedback-fixtures" / "relative-preview" + + +class LFFileWorkerTest(unittest.TestCase): + def with_env(self, updates, removed=()): + class EnvGuard: + def __enter__(inner_self): + keys = set(updates) | set(removed) + inner_self.old_env = {key: os.environ.get(key) for key in keys} + for key in removed: + os.environ.pop(key, None) + for key, value in updates.items(): + os.environ[key] = value + + def __exit__(inner_self, exc_type, exc, tb): + for key, value in inner_self.old_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + return EnvGuard() + + def test_handle_message_wraps_eval_result_for_shimmy_file_adapter(self): + message = { + "command": "eval", + "params": { + "response": "answer", + "answer": "answer", + "params": {"root": str(BOILERPLATE_ROOT), "entrypoint": "evaluation_function.main:evaluation_function"}, + }, + } + + result = lf_file_worker.handle_message(message) + + self.assertEqual( + result, + {"command": "eval", "result": {"is_correct": True, "feedback": "Correct"}}, + ) + + def test_handle_message_supports_preview_entrypoint_and_params(self): + message = { + "command": "preview", + "params": { + "response": "foo", + "params": {"root": str(RELATIVE_ROOT), "entrypoint": "evaluation_function.evaluation:preview_function", "mode": "bar"}, + }, + } + + result = lf_file_worker.handle_message(message) + + self.assertEqual(result, {"command": "preview", "result": {"preview": {"markdown": "Preview: foo / bar"}}}) + + def test_handle_message_returns_schema_compatible_error(self): + result = lf_file_worker.handle_message({"command": "eval", "params": {"response": "x"}}) + + self.assertEqual(result["error"]["message"], "missing evaluator params: entrypoint, root") + 
self.assertNotIn("result", result) + + def test_handle_message_ignores_legacy_short_env_names(self): + with self.with_env( + { + "LF_EVAL_ROOT": str(BOILERPLATE_ROOT), + "LF_EVAL_ENTRYPOINT": "evaluation_function.main:evaluation_function", + }, + removed=("FUNCTION_LF_ROOT", "FUNCTION_LF_ENTRYPOINT", "FUNCTION_LF_PREVIEW_ENTRYPOINT", "FUNCTION_LF_EVAL_ENTRYPOINT"), + ): + result = lf_file_worker.handle_message( + {"command": "eval", "params": {"response": "42", "answer": "42", "params": {}}} + ) + + self.assertEqual(result["error"]["message"], "missing evaluator params: entrypoint, root") + + def test_handle_message_uses_function_lf_env_names_for_package_mode(self): + with self.with_env( + { + "FUNCTION_LF_ROOT": str(BOILERPLATE_ROOT), + "FUNCTION_LF_ENTRYPOINT": "evaluation_function.main:evaluation_function", + }, + removed=("LF_EVAL_ROOT", "LF_EVAL_ENTRYPOINT", "LF_ENTRYPOINT"), + ): + result = lf_file_worker.handle_message( + {"command": "eval", "params": {"response": "42", "answer": "42", "params": {}}} + ) + + self.assertEqual(result, {"command": "eval", "result": {"is_correct": True, "feedback": "Correct"}}) + + def test_handle_message_uses_command_specific_function_lf_preview_entrypoint(self): + with self.with_env( + { + "FUNCTION_LF_ROOT": str(BOILERPLATE_ROOT), + "FUNCTION_LF_ENTRYPOINT": "evaluation_function.main:evaluation_function", + "FUNCTION_LF_PREVIEW_ENTRYPOINT": "evaluation_function.main:preview_function", + }, + removed=("LF_EVAL_ROOT", "LF_EVAL_ENTRYPOINT", "LF_PREVIEW_ENTRYPOINT", "LF_ENTRYPOINT"), + ): + result = lf_file_worker.handle_message( + {"command": "preview", "params": {"response": "draft", "params": {"expected": "answer"}}} + ) + + self.assertEqual( + result, + {"command": "preview", "result": {"preview": {"response": "draft", "expected": "answer"}}}, + ) + + def test_handle_message_runs_evaluator_with_hygiene_context(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) / "evaluator" + package = root / 
"evaluation_function" + package.mkdir(parents=True) + (package / "__init__.py").write_text("", encoding="utf-8") + (package / "main.py").write_text( + """ +import os +from pathlib import Path + + +def evaluation_function(response, answer, params): + print("worker stdout") + Path("worker-scratch.txt").write_text("scratch", encoding="utf-8") + os.environ["LF_WORKER_MUTATION"] = "dirty" + return {"cwd_is_workspace": Path.cwd().name.startswith("lf-eval-"), "response": response} +""", + encoding="utf-8", + ) + before_cwd = Path.cwd() + scratch = before_cwd / "worker-scratch.txt" + if scratch.exists(): + scratch.unlink() + old_env = os.environ.get("LF_WORKER_MUTATION") + + result = lf_file_worker.handle_message( + { + "command": "eval", + "params": { + "response": "draft", + "answer": "expected", + "params": {"root": str(root), "entrypoint": "evaluation_function.main:evaluation_function"}, + }, + } + ) + + self.assertEqual(result, {"command": "eval", "result": {"cwd_is_workspace": True, "response": "draft"}}) + self.assertEqual(Path.cwd(), before_cwd) + self.assertFalse(scratch.exists()) + self.assertEqual(os.environ.get("LF_WORKER_MUTATION"), old_env) + + def test_cli_reads_request_file_and_writes_response_file(self): + with tempfile.TemporaryDirectory() as tmp: + req = Path(tmp) / "request.json" + res = Path(tmp) / "response.json" + req.write_text( + json.dumps( + { + "command": "eval", + "params": { + "response": "a, b", + "answer": "b,a", + "params": {"root": str(RELATIVE_ROOT), "entrypoint": "evaluation_function.evaluation:evaluation_function", "mode": "set"}, + }, + } + ), + encoding="utf-8", + ) + env = os.environ.copy() + env["EVAL_FILE_NAME_REQUEST"] = str(req) + env["EVAL_FILE_NAME_RESPONSE"] = str(res) + + completed = subprocess.run( + [sys.executable, str(ADAPTER_DIR / "lf_file_worker.py")], + cwd=str(ADAPTER_DIR), + env=env, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + self.assertEqual(completed.returncode, 0, completed.stderr) + 
self.assertEqual( + json.loads(res.read_text(encoding="utf-8")), + {"command": "eval", "result": {"is_correct": True, "feedback": "matched"}}, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/lambda-feedback-fixtures/README.md b/examples/lambda-feedback-fixtures/README.md new file mode 100644 index 0000000..8f56b1b --- /dev/null +++ b/examples/lambda-feedback-fixtures/README.md @@ -0,0 +1,39 @@ +# Lambda Feedback compatibility fixtures + +These small fixtures mirror common shapes from real Lambda Feedback evaluator +repositories while staying deterministic and dependency-light. + +- `boilerplate-python`: toolkit-style `create_server()`, `@server.eval`, + `@server.preview`, and `lf_toolkit.evaluation.Result`. +- `relative-preview`: package-relative imports plus a two-argument + `preview_function(response, params)` variant. + +They are compatibility fixtures, not production evaluators. The local adapter +uses the test-only `lf_toolkit` shim under `examples/lambda-feedback-adapter/`. + +## Local adapter and Shimmy file-worker smoke + +Exercise the backend-independent adapter directly: + +```bash +scripts/demo-lambda-feedback-fixtures.sh all +``` + +Exercise the same fixtures through Shimmy's existing `FUNCTION_INTERFACE=file` +path: + +```bash +scripts/demo-lambda-feedback-file-worker.sh all +``` + +The file worker uses only three package-mode envs: `FUNCTION_LF_ROOT`, +`FUNCTION_LF_ENTRYPOINT`, and optional `FUNCTION_LF_PREVIEW_ENTRYPOINT` when +preview lives at a different function. Per-request fixture tests may also pass +`root` and `entrypoint` inside the request `params` object; those keys are +removed before calling the evaluator. + +File-worker calls now run through the adapter's best-effort evaluator hygiene +context: per-request temporary cwd, cwd/`sys.path`/environment restore, evaluator +stdout/stderr capture, and cleanup of modules imported from the evaluator root. +This reduces accidental pollution in demos and CI. 
It is not a security sandbox; +malicious native Python evaluators still require a real isolation boundary. diff --git a/examples/lambda-feedback-fixtures/boilerplate-python/evaluation_function/__init__.py b/examples/lambda-feedback-fixtures/boilerplate-python/evaluation_function/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/lambda-feedback-fixtures/boilerplate-python/evaluation_function/main.py b/examples/lambda-feedback-fixtures/boilerplate-python/evaluation_function/main.py new file mode 100644 index 0000000..0a6af89 --- /dev/null +++ b/examples/lambda-feedback-fixtures/boilerplate-python/evaluation_function/main.py @@ -0,0 +1,20 @@ +from lf_toolkit import create_server, run +from lf_toolkit.evaluation import Result + +server = create_server() + + +@server.eval +def evaluation_function(response, answer, params): + expected = params.get("expected", answer) + is_correct = str(response).strip().lower() == str(expected).strip().lower() + return Result(is_correct=is_correct, feedback="Correct" if is_correct else "Try again") + + +@server.preview +def preview_function(response, answer, params): + return {"preview": {"response": response, "answer": answer, "expected": params.get("expected")}} + + +if __name__ == "__main__": + run(server) diff --git a/examples/lambda-feedback-fixtures/relative-preview/evaluation_function/__init__.py b/examples/lambda-feedback-fixtures/relative-preview/evaluation_function/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/lambda-feedback-fixtures/relative-preview/evaluation_function/evaluation.py b/examples/lambda-feedback-fixtures/relative-preview/evaluation_function/evaluation.py new file mode 100644 index 0000000..481b44e --- /dev/null +++ b/examples/lambda-feedback-fixtures/relative-preview/evaluation_function/evaluation.py @@ -0,0 +1,17 @@ +from lf_toolkit.evaluation import Result +from lf_toolkit.preview import Preview, Result as PreviewResult + +from .parse import 
normalize, split_items + + +def evaluation_function(response, answer, params): + mode = params.get("mode", "exact") + if mode == "set": + is_correct = split_items(response) == split_items(answer) + else: + is_correct = normalize(response) == normalize(answer) + return Result(is_correct=is_correct, feedback="matched" if is_correct else "mismatch") + + +def preview_function(response, params): + return PreviewResult(preview=Preview(markdown=f"Preview: {normalize(response)} / {params.get('mode', 'exact')}")) diff --git a/examples/lambda-feedback-fixtures/relative-preview/evaluation_function/parse.py b/examples/lambda-feedback-fixtures/relative-preview/evaluation_function/parse.py new file mode 100644 index 0000000..9740ea3 --- /dev/null +++ b/examples/lambda-feedback-fixtures/relative-preview/evaluation_function/parse.py @@ -0,0 +1,6 @@ +def normalize(value): + return str(value).strip().lower() + + +def split_items(value): + return {part.strip().lower() for part in str(value).split(',') if part.strip()} diff --git a/go.mod b/go.mod index 10caf84..cee6645 100644 --- a/go.mod +++ b/go.mod @@ -37,6 +37,7 @@ require ( github.com/perimeterx/marshmallow v1.1.5 // indirect github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 // indirect github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible // indirect + github.com/tetratelabs/wazero v1.9.0 github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect github.com/woodsbury/decimal128 v1.3.0 // indirect diff --git a/go.sum b/go.sum index 014f78c..8aeb9f0 100644 --- a/go.sum +++ b/go.sum @@ -110,6 +110,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/supranational/blst v0.3.11 h1:LyU6FolezeWAhvQk0k6O/d49jqgO52MSDDfYgbeoEm4= github.com/supranational/blst v0.3.11/go.mod h1:jZJtfjgudtNl4en1tzwPIV3KjUnQUvG3/j+w+fVonLw= 
+github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I= +github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= diff --git a/internal/execution/dispatcher.go b/internal/execution/dispatcher.go index 300ca3f..016ee05 100644 --- a/internal/execution/dispatcher.go +++ b/internal/execution/dispatcher.go @@ -2,11 +2,16 @@ package execution import ( "context" + "fmt" + "os" + "sort" + "strings" "go.uber.org/zap" "github.com/lambda-feedback/shimmy/internal/execution/dispatcher" "github.com/lambda-feedback/shimmy/internal/execution/supervisor" + "github.com/lambda-feedback/shimmy/internal/execution/wasm" ) type Dispatcher dispatcher.Dispatcher @@ -32,7 +37,11 @@ type Params struct { } func NewDispatcher(params Params) (dispatcher.Dispatcher, error) { - if params.Config.Supervisor.IO.Interface == supervisor.RpcIO { + switch params.Config.Supervisor.IO.Interface { + case supervisor.RpcIO: + if err := requireProcessWorkerCommand(params.Config.Supervisor); err != nil { + return nil, err + } return dispatcher.NewDedicatedDispatcher( dispatcher.DedicatedDispatcherParams{ Config: dispatcher.DedicatedDispatcherConfig{ @@ -42,7 +51,35 @@ func NewDispatcher(params Params) (dispatcher.Dispatcher, error) { Log: params.Log, }, ) - } else { + + case supervisor.WasmIO: + wasmProfile := strings.ToLower(strings.TrimSpace(os.Getenv("FUNCTION_WASM_PROFILE"))) + if wasmProfile == "" { + wasmProfile = "generic" + } + if wasmProfile != "generic" { + validProfiles := []string{"generic"} + sort.Strings(validProfiles) + return nil, fmt.Errorf("unsupported FUNCTION_WASM_PROFILE %q; supported values: %s", wasmProfile, strings.Join(validProfiles, ", ")) + } + + cfg := 
wasm.Config{ + ModulePath: params.Config.Supervisor.StartParams.Cmd, + MaxInstances: params.Config.MaxWorkers, + Timeout: params.Config.Supervisor.SendParams.Timeout, + } + d := wasm.NewDispatcher(cfg, params.Log) + if err := d.Start(params.Context); err != nil { + return nil, err + } + return d, nil + + default: + if params.Config.Supervisor.IO.Interface == supervisor.FileIO { + if err := requireProcessWorkerCommand(params.Config.Supervisor); err != nil { + return nil, err + } + } return dispatcher.NewPooledDispatcher( dispatcher.PooledDispatcherParams{ Config: dispatcher.PooledDispatcherConfig{ @@ -55,3 +92,10 @@ func NewDispatcher(params Params) (dispatcher.Dispatcher, error) { ) } } + +func requireProcessWorkerCommand(cfg supervisor.Config) error { + if strings.TrimSpace(cfg.StartParams.Cmd) == "" { + return fmt.Errorf("FUNCTION_COMMAND is required when FUNCTION_INTERFACE=%q", cfg.IO.Interface) + } + return nil +} diff --git a/internal/execution/dispatcher_test.go b/internal/execution/dispatcher_test.go new file mode 100644 index 0000000..3cbd314 --- /dev/null +++ b/internal/execution/dispatcher_test.go @@ -0,0 +1,42 @@ +package execution + +import ( + "context" + "strings" + "testing" + + "go.uber.org/zap" + + "github.com/lambda-feedback/shimmy/internal/execution/supervisor" +) + +func TestNewDispatcher_RequiresCommandForProcessInterfaces(t *testing.T) { + tests := []struct { + name string + io supervisor.IOInterface + }{ + {name: "rpc", io: supervisor.RpcIO}, + {name: "file", io: supervisor.FileIO}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewDispatcher(Params{ + Context: context.Background(), + Config: Config{ + MaxWorkers: 1, + Supervisor: supervisor.Config{ + IO: supervisor.IOConfig{Interface: tt.io}, + }, + }, + Log: zap.NewNop(), + }) + if err == nil { + t.Fatal("expected missing command error") + } + if got, want := err.Error(), "FUNCTION_COMMAND is required"; !strings.Contains(got, want) { + 
t.Fatalf("expected error to contain %q, got %q", want, got) + } + }) + } +} diff --git a/internal/execution/supervisor/config.go b/internal/execution/supervisor/config.go index 520e367..9d7303d 100644 --- a/internal/execution/supervisor/config.go +++ b/internal/execution/supervisor/config.go @@ -24,7 +24,7 @@ type SendConfig struct { // IOInterface describes the interface used to communicate with the worker. type IOConfig struct { // Interface describes the communication between the supervisor - // and the worker. It can be either "rpc" or "file". + // and the worker. It can be "rpc", "file", or "wasm". // // If "rpc", the supervisor will communicate with the worker over // a specified transport. The worker is expected to handle incoming @@ -35,6 +35,9 @@ type IOConfig struct { // containing the message payload and response are passed as args // to the worker process. // + // If "wasm", Shimmy loads a pre-built WASI module from FUNCTION_COMMAND + // or FUNCTION_WASM_MODULE and calls its internal alloc/evaluate adapter ABI. + // // Default is "rpc". Interface IOInterface `conf:"interface"` diff --git a/internal/execution/supervisor/models.go b/internal/execution/supervisor/models.go index e7776db..8f98bcb 100644 --- a/internal/execution/supervisor/models.go +++ b/internal/execution/supervisor/models.go @@ -16,6 +16,9 @@ const ( // FileIO describes communication w/ processes over files FileIO IOInterface = "file" + + // WasmIO describes in-process execution of a pre-built WASI module. + WasmIO IOInterface = "wasm" ) // IOTransport describes the transport mechanism used to communicate with diff --git a/internal/execution/wasm/adapter.go b/internal/execution/wasm/adapter.go new file mode 100644 index 0000000..3d08774 --- /dev/null +++ b/internal/execution/wasm/adapter.go @@ -0,0 +1,177 @@ +// Package wasm implements a WebAssembly execution backend for shimmy using +// wazero. 
It exposes a [Dispatcher] that manages a pool of pre-compiled WASM +// module instances and dispatches evaluation requests to them. +// +// # Guest ABI +// +// WASM modules loaded by this backend must export two functions: +// +// alloc(size i32) i32 +// Allocate `size` bytes in guest linear memory and return a pointer to +// the start of the allocation. The host will write the JSON-encoded +// request into this region immediately after the call returns. +// +// evaluate(req_ptr i32, req_len i32) i32 +// Process the JSON request at [req_ptr, req_ptr+req_len). Returns a +// pointer P into guest memory where the response is encoded as: +// bytes [P, P+4) — uint32 little-endian response length L +// bytes [P+4, P+4+L) — L bytes of UTF-8 JSON response +// +// The JSON request envelope has the shape: +// +// {"method": "", "params": {…}} +// +// The JSON response is a plain JSON object (map[string]any) that is returned +// verbatim to the caller. +package wasm + +import ( + "context" + "encoding/binary" + "encoding/json" + "fmt" + "time" + + "github.com/tetratelabs/wazero/api" + "go.uber.org/zap" +) + +// requestEnvelope is the JSON structure written into guest memory for each +// evaluation call. +type requestEnvelope struct { + Method string `json:"method"` + Params map[string]any `json:"params"` +} + +// wasmAdapter performs a single evaluate call against a live wazero api.Module. +// It is stateless and safe to call from one goroutine at a time. 
+type wasmAdapter struct { + mod api.Module + log *zap.Logger + allocFn api.Function // cached exported "alloc" function (M-4 fix) + evalFn api.Function // cached exported "evaluate" function (M-4 fix) +} + +func newWasmAdapter(mod api.Module, log *zap.Logger) *wasmAdapter { + return &wasmAdapter{ + mod: mod, + log: log.Named("adapter_wasm"), + allocFn: mod.ExportedFunction("alloc"), + evalFn: mod.ExportedFunction("evaluate"), + } +} + +// send marshals (method, data) into JSON, writes it into the guest's linear +// memory via alloc, calls evaluate, and reads back the length-prefixed +// response. +func (a *wasmAdapter) send( + ctx context.Context, + method string, + data map[string]any, + timeout time.Duration, +) (map[string]any, error) { + if timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + + // 1. Marshal request envelope. + envelope := requestEnvelope{Method: method, Params: data} + + reqBytes, err := json.Marshal(envelope) + if err != nil { + return nil, fmt.Errorf("wasm: marshal request: %w", err) + } + + reqLen := uint64(len(reqBytes)) + + // 2. Allocate guest memory for the request (cached lookup — M-4 fix). + if a.allocFn == nil { + return nil, fmt.Errorf("wasm: guest module does not export 'alloc'") + } + + allocRes, err := a.allocFn.Call(ctx, reqLen) + if err != nil { + return nil, fmt.Errorf("wasm: alloc(%d): %w", reqLen, err) + } + if len(allocRes) != 1 { + return nil, fmt.Errorf("wasm: alloc returned %d values, expected 1", len(allocRes)) + } + + reqPtr := allocRes[0] + if reqPtr == 0 { + return nil, fmt.Errorf("wasm: alloc returned NULL (out of memory)") + } + + // 3. Write request bytes into guest memory. 
+ mem := a.mod.Memory() + if mem == nil { + return nil, fmt.Errorf("wasm: guest module has no linear memory") + } + + if !mem.Write(uint32(reqPtr), reqBytes) { + return nil, fmt.Errorf( + "wasm: failed to write %d bytes at ptr=%d (memory size=%d)", + len(reqBytes), reqPtr, mem.Size(), + ) + } + + // 4. Call evaluate (cached lookup — M-4 fix). + if a.evalFn == nil { + return nil, fmt.Errorf("wasm: guest module does not export 'evaluate'") + } + + a.log.Debug("calling evaluate", + zap.String("method", method), + zap.Uint64("req_ptr", reqPtr), + zap.Uint64("req_len", reqLen), + ) + + evalRes, err := a.evalFn.Call(ctx, reqPtr, reqLen) + if err != nil { + return nil, fmt.Errorf("wasm: evaluate: %w", err) + } + if len(evalRes) != 1 { + return nil, fmt.Errorf("wasm: evaluate returned %d values, expected 1", len(evalRes)) + } + + resPtr := uint32(evalRes[0]) + + // 5. Read the 4-byte little-endian length prefix. + lenBytes, ok := mem.Read(resPtr, 4) + if !ok { + return nil, fmt.Errorf("wasm: failed to read response length at ptr=%d", resPtr) + } + + resLen := binary.LittleEndian.Uint32(lenBytes) + + // 6. Read the response JSON body. + // Validate bounds before reading to catch corrupt/malicious response pointers. + if uint64(resPtr)+4+uint64(resLen) > uint64(mem.Size()) { + return nil, fmt.Errorf( + "wasm: response out of bounds: resPtr=%d resLen=%d memSize=%d", + resPtr, resLen, mem.Size(), + ) + } + resBytes, ok := mem.Read(resPtr+4, resLen) + if !ok { + return nil, fmt.Errorf( + "wasm: failed to read %d response bytes at ptr=%d", + resLen, resPtr+4, + ) + } + + a.log.Debug("received response", + zap.Uint32("res_ptr", resPtr), + zap.Uint32("res_len", resLen), + ) + + // 7. Unmarshal response. 
+ var result map[string]any + if err := json.Unmarshal(resBytes, &result); err != nil { + return nil, fmt.Errorf("wasm: unmarshal response: %w", err) + } + + return result, nil +} diff --git a/internal/execution/wasm/artifact_examples_test.go b/internal/execution/wasm/artifact_examples_test.go new file mode 100644 index 0000000..a25420d --- /dev/null +++ b/internal/execution/wasm/artifact_examples_test.go @@ -0,0 +1,234 @@ +package wasm + +import ( + "context" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func repoRootFromTest(t *testing.T) string { + t.Helper() + _, filename, _, ok := runtime.Caller(0) + require.True(t, ok, "runtime.Caller failed") + return filepath.Clean(filepath.Join(filepath.Dir(filename), "..", "..", "..")) +} + +func buildCppCompareExample(t *testing.T) string { + t.Helper() + zig, err := exec.LookPath("zig") + if err != nil { + t.Skip("zig is required to compile the C++ WASM example") + } + + root := repoRootFromTest(t) + src := filepath.Join(root, "examples", "demo-cpp-compare", "evaluator.cpp") + out := filepath.Join(t.TempDir(), "eval.wasm") + + cmd := exec.Command(zig, "c++", + "-target", "wasm32-freestanding", + "-Oz", + "-nostdlib", + "-fno-exceptions", + "-fno-rtti", + "-Wl,--no-entry", + "-Wl,--export=alloc", + "-Wl,--export=evaluate", + "-Wl,--export-memory", + "-Wl,--initial-memory=2097152", + "-o", out, + src, + ) + output, err := cmd.CombinedOutput() + require.NoError(t, err, "compile C++ WASM example:\n%s", string(output)) + return out +} + +func buildGoStatefulExample(t *testing.T) string { + t.Helper() + goBin, err := exec.LookPath("go") + if err != nil { + t.Skip("go is required to compile the Go WASM example") + } + + root := repoRootFromTest(t) + out := filepath.Join(t.TempDir(), "eval.wasm") + cmd := exec.Command(goBin, "build", "-buildmode=c-shared", "-o", out, "./examples/demo-stateful") + cmd.Dir = 
root + cmd.Env = append(os.Environ(), "GOOS=wasip1", "GOARCH=wasm") + output, err := cmd.CombinedOutput() + if err != nil && strings.Contains(string(output), "requires go1.24 or later") { + t.Skipf("go toolchain does not support //go:wasmexport:\n%s", string(output)) + } + require.NoError(t, err, "compile Go WASM example:\n%s", string(output)) + return out +} + +func buildRustCompareExample(t *testing.T) string { + t.Helper() + rustc, err := exec.LookPath("rustc") + if err != nil { + t.Skip("rustc is required to compile the Rust WASM example") + } + + root := repoRootFromTest(t) + src := filepath.Join(root, "examples", "demo-rust-compare", "evaluator.rs") + require.FileExists(t, src, "Rust WASM example source must exist") + out := filepath.Join(t.TempDir(), "eval.wasm") + cmd := exec.Command(rustc, + "--target", "wasm32-unknown-unknown", + "--crate-type", "cdylib", + "-C", "panic=abort", + "-O", + "-o", out, + src, + ) + output, err := cmd.CombinedOutput() + if err != nil && strings.Contains(string(output), "target may not be installed") { + t.Skipf("rust target wasm32-unknown-unknown is not installed:\n%s", string(output)) + } + require.NoError(t, err, "compile Rust WASM example:\n%s", string(output)) + return out +} + +func buildGoPackageExample(t *testing.T) string { + t.Helper() + goBin, err := exec.LookPath("go") + if err != nil { + t.Skip("go is required to compile the Go package WASM example") + } + + root := repoRootFromTest(t) + packageDir := filepath.Join(root, "examples", "demo-go-package") + require.DirExists(t, packageDir, "Go package WASM example must exist") + out := filepath.Join(t.TempDir(), "eval.wasm") + cmd := exec.Command(goBin, "build", "-buildmode=c-shared", "-o", out, "./cmd/evaluator") + cmd.Dir = packageDir + cmd.Env = append(os.Environ(), "GOOS=wasip1", "GOARCH=wasm") + output, err := cmd.CombinedOutput() + if err != nil && strings.Contains(string(output), "requires go1.24 or later") { + t.Skipf("go toolchain does not support 
//go:wasmexport:\n%s", string(output)) + } + require.NoError(t, err, "compile Go package WASM example:\n%s", string(output)) + return out +} + +func buildRustPackageExample(t *testing.T) string { + t.Helper() + cargo, err := exec.LookPath("cargo") + if err != nil { + t.Skip("cargo is required to compile the Rust package WASM example") + } + + root := repoRootFromTest(t) + packageDir := filepath.Join(root, "examples", "demo-rust-package") + require.DirExists(t, packageDir, "Rust package WASM example must exist") + cmd := exec.Command(cargo, "build", "--target", "wasm32-unknown-unknown", "--release") + cmd.Dir = packageDir + output, err := cmd.CombinedOutput() + if err != nil && strings.Contains(string(output), "target may not be installed") { + t.Skipf("rust target wasm32-unknown-unknown is not installed:\n%s", string(output)) + } + require.NoError(t, err, "compile Rust package WASM example:\n%s", string(output)) + return filepath.Join(packageDir, "target", "wasm32-unknown-unknown", "release", "demo_rust_package.wasm") +} + +func buildCppPackageExample(t *testing.T) string { + t.Helper() + if _, err := exec.LookPath("zig"); err != nil { + t.Skip("zig is required to compile the C++ package WASM example") + } + if _, err := exec.LookPath("make"); err != nil { + t.Skip("make is required to compile the C++ package WASM example") + } + + root := repoRootFromTest(t) + packageDir := filepath.Join(root, "examples", "demo-cpp-package") + require.DirExists(t, packageDir, "C++ package WASM example must exist") + out := filepath.Join(t.TempDir(), "eval.wasm") + cmd := exec.Command("make", "wasm", "OUT="+out) + cmd.Dir = packageDir + output, err := cmd.CombinedOutput() + require.NoError(t, err, "compile C++ package WASM example:\n%s", string(output)) + return out +} + +func assertCompareEvaluatorRunsThroughDispatcher(t *testing.T, modulePath string) { + t.Helper() + + d := NewDispatcher(Config{ + ModulePath: modulePath, + MaxInstances: 1, + Timeout: 5 * time.Second, + }, 
newTestLogger(t)) + require.NoError(t, d.Start(context.Background())) + t.Cleanup(func() { _ = d.Shutdown(context.Background()) }) + + correct, err := d.Send(context.Background(), "eval", map[string]any{ + "response": "42", + "answer": "42", + "params": map[string]any{ + "correct_response_feedback": "Correct!", + "incorrect_response_feedback": "Try again.", + }, + }) + require.NoError(t, err) + + wrong, err := d.Send(context.Background(), "eval", map[string]any{ + "response": "41", + "answer": "42", + "params": map[string]any{ + "correct_response_feedback": "Correct!", + "incorrect_response_feedback": "Try again.", + }, + }) + require.NoError(t, err) + + correctResult, ok := correct["result"].(map[string]any) + require.True(t, ok, "correct response result must be an object: %#v", correct) + wrongResult, ok := wrong["result"].(map[string]any) + require.True(t, ok, "wrong response result must be an object: %#v", wrong) + + assert.Equal(t, "eval", correct["command"]) + assert.Equal(t, true, correctResult["is_correct"]) + assert.NotEmpty(t, correctResult["feedback"]) + assert.EqualValues(t, 1, correctResult["guest_invocation_count"]) + assert.Equal(t, true, correctResult["snapshot_isolation_ok"]) + + assert.Equal(t, "eval", wrong["command"]) + assert.Equal(t, false, wrongResult["is_correct"]) + assert.NotEmpty(t, wrongResult["feedback"]) + assert.EqualValues(t, 1, wrongResult["guest_invocation_count"]) + assert.Equal(t, true, wrongResult["snapshot_isolation_ok"]) +} + +func TestCppCompareExample_CompilesAndRunsThroughDispatcher(t *testing.T) { + assertCompareEvaluatorRunsThroughDispatcher(t, buildCppCompareExample(t)) +} + +func TestGoStatefulExample_CompilesAndRunsThroughDispatcher(t *testing.T) { + assertCompareEvaluatorRunsThroughDispatcher(t, buildGoStatefulExample(t)) +} + +func TestRustCompareExample_CompilesAndRunsThroughDispatcher(t *testing.T) { + assertCompareEvaluatorRunsThroughDispatcher(t, buildRustCompareExample(t)) +} + +func 
TestGoPackageExample_CompilesAndRunsThroughDispatcher(t *testing.T) { + assertCompareEvaluatorRunsThroughDispatcher(t, buildGoPackageExample(t)) +} + +func TestRustPackageExample_CompilesAndRunsThroughDispatcher(t *testing.T) { + assertCompareEvaluatorRunsThroughDispatcher(t, buildRustPackageExample(t)) +} + +func TestCppPackageExample_CompilesAndRunsThroughDispatcher(t *testing.T) { + assertCompareEvaluatorRunsThroughDispatcher(t, buildCppPackageExample(t)) +} diff --git a/internal/execution/wasm/config.go b/internal/execution/wasm/config.go new file mode 100644 index 0000000..952fc93 --- /dev/null +++ b/internal/execution/wasm/config.go @@ -0,0 +1,86 @@ +package wasm + +import ( + "fmt" + "os" + "strconv" + "strings" + "time" +) + +// Config holds the configuration for the opt-in generic WASM execution +// backend. Shimmy consumes an already-built WASI module; source-language +// compilation remains a deployment/build concern outside this package. +type Config struct { + // ModulePath is the path to the .wasm module. It is normally populated from + // FUNCTION_COMMAND for compatibility with the rest of Shimmy, or overridden + // by FUNCTION_WASM_MODULE when set. + ModulePath string + + // MaxInstances is the number of warm module instances in the pool. + MaxInstances int + + // Timeout is the per-request deadline passed to the guest evaluate call. + Timeout time.Duration + + // MaxMemoryPages limits WASM linear memory (1 page = 64 KiB). The default is + // intentionally small and can be raised by FUNCTION_WASM_MAX_MEMORY_PAGES. + MaxMemoryPages uint32 + + // AllowedPaths is a comma-separated allowlist of read-only host directories + // exposed to WASI. Empty means no filesystem access. + AllowedPaths []string + + // AllowedEnv is a comma-separated allowlist of host environment variable + // names exposed to WASI. Empty means no environment variables. 
+ AllowedEnv []string + + // CompileCacheDir enables wazero's on-disk compilation cache when set via + // FUNCTION_WASM_COMPILE_CACHE. + CompileCacheDir string +} + +func (c *Config) applyDefaults() { + if c.Timeout == 0 { + c.Timeout = 30 * time.Second + } + if c.MaxMemoryPages == 0 { + c.MaxMemoryPages = 256 // 16 MiB + } +} + +// applyEnv reads FUNCTION_WASM_* overrides. These settings are intentionally +// limited to generic WASM runtime concerns; Python/reactor/package bundling is +// out of scope for this backend. +func (c *Config) applyEnv() error { + if v := os.Getenv("FUNCTION_WASM_MODULE"); v != "" { + c.ModulePath = v + } + if v := os.Getenv("FUNCTION_WASM_MAX_MEMORY_PAGES"); v != "" { + n, err := strconv.ParseUint(v, 10, 32) + if err != nil { + return fmt.Errorf("FUNCTION_WASM_MAX_MEMORY_PAGES must be an unsigned 32-bit integer: %w", err) + } + c.MaxMemoryPages = uint32(n) + } + if v := os.Getenv("FUNCTION_WASM_ALLOWED_PATHS"); v != "" { + c.AllowedPaths = splitNonEmpty(v, ",") + } + if v := os.Getenv("FUNCTION_WASM_ALLOWED_ENV"); v != "" { + c.AllowedEnv = splitNonEmpty(v, ",") + } + if v := os.Getenv("FUNCTION_WASM_COMPILE_CACHE"); v != "" { + c.CompileCacheDir = v + } + return nil +} + +func splitNonEmpty(s, sep string) []string { + var out []string + for _, p := range strings.Split(s, sep) { + if t := strings.TrimSpace(p); t != "" { + out = append(out, t) + } + } + return out +} diff --git a/internal/execution/wasm/dispatcher.go b/internal/execution/wasm/dispatcher.go new file mode 100644 index 0000000..39ba824 --- /dev/null +++ b/internal/execution/wasm/dispatcher.go @@ -0,0 +1,386 @@ +package wasm + +import ( + "context" + "fmt" + "os" + "runtime" + "sync" + "time" + + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" + "go.uber.org/zap" + + "github.com/lambda-feedback/shimmy/internal/execution/dispatcher" +) + +// ErrDispatcherClosed is returned by Send after the dispatcher has begun (or +// 
completed) Shutdown. Callers should treat it as a terminal error. +var ErrDispatcherClosed = fmt.Errorf("wasm: dispatcher is shut down") + +// Dispatcher implements [dispatcher.Dispatcher] for the WASM execution +// backend. It compiles the .wasm module once at startup, then maintains a pool +// of pre-initialised [wasmSupervisor] instances (one compiled module, N module +// instances). Requests are dispatched by acquiring a supervisor from the pool, +// calling its Send, and returning it to the pool. +type Dispatcher struct { + cfg Config + rt wazero.Runtime + cache wazero.CompilationCache + compiled wazero.CompiledModule + modCfg wazero.ModuleConfig + pool chan *wasmSupervisor + log *zap.Logger + + // mu protects closed and serialises the closed/push transitions so that a + // replacement supervisor cannot land in the pool after Shutdown has begun + // draining it. + mu sync.Mutex + closed bool + // closedCh is closed atomically with closed=true (under mu) by Shutdown. + // Send selects on it to (a) unblock a pool acquire that is racing Shutdown + // and (b) avoid waiting on an empty pool that Shutdown is about to drain. + closedCh chan struct{} + // pending tracks BOTH in-flight Sends (Add in tryBeginSend, Done via Send's + // defer) AND background goroutines spawned during a Send (replacement + // spawns, discard-shutdowns). Shutdown waits on it before draining the + // pool / closing the runtime. + // + // Invariant: every pending.Add is either (a) made under d.mu after + // observing !closed, or (b) made by code that is itself holding a pending + // count (e.g. discardAsync called from inside Send). This keeps Add from + // racing Shutdown's Wait — if closed is already set, branch (a) skips the + // Add and falls back to a synchronous close; in branch (b) Shutdown is + // guaranteed to still be blocked at Wait on the caller's count. 
+ pending sync.WaitGroup +} + +var _ dispatcher.Dispatcher = (*Dispatcher)(nil) + +// NewDispatcher creates a new WASM dispatcher. Compilation and pool +// initialisation happen in Start. +func NewDispatcher(cfg Config, log *zap.Logger) *Dispatcher { + return &Dispatcher{ + cfg: cfg, + log: log.Named("dispatcher_wasm"), + closedCh: make(chan struct{}), + } +} + +// tryBeginSend atomically checks the closed flag and increments pending. It +// returns false if Shutdown has begun (caller must abort with +// ErrDispatcherClosed); on true the caller MUST call pending.Done exactly +// once when finished. Holding a pending count across the entire Send keeps +// Shutdown's Wait blocked while the Send is mid-flight, which is what lets +// discardAsync inside Send safely Add to pending without racing Wait. +func (d *Dispatcher) tryBeginSend() bool { + d.mu.Lock() + defer d.mu.Unlock() + if d.closed { + return false + } + d.pending.Add(1) + return true +} + +// Start reads and compiles the .wasm file, sets up WASI host functions, and +// pre-warms the supervisor pool. +func (d *Dispatcher) Start(ctx context.Context) error { + // Pick up sandbox overrides from FUNCTION_WASM_* env vars (including + // FUNCTION_WASM_MODULE as an alternative to FUNCTION_COMMAND), then apply + // sensible defaults for any fields still at their zero values. + if err := d.cfg.applyEnv(); err != nil { + return fmt.Errorf("wasm: invalid environment configuration: %w", err) + } + d.cfg.applyDefaults() + + if d.cfg.ModulePath == "" { + return fmt.Errorf("wasm: ModulePath must be set (FUNCTION_COMMAND or FUNCTION_WASM_MODULE)") + } + + maxInstances := d.cfg.MaxInstances + if maxInstances <= 0 { + maxInstances = runtime.NumCPU() + } + + d.log.Info("starting wasm dispatcher", + zap.String("module", d.cfg.ModulePath), + zap.Int("max_instances", maxInstances), + zap.Uint32("max_memory_pages", d.cfg.MaxMemoryPages), + zap.Duration("timeout", d.cfg.Timeout), + ) + + // Read the .wasm bytes from disk. 
+ wasmBytes, err := os.ReadFile(d.cfg.ModulePath) + if err != nil { + return fmt.Errorf("wasm: read module file %q: %w", d.cfg.ModulePath, err) + } + + // Build the runtime config with memory limit and context-done interruption. + rtCfg := wazero.NewRuntimeConfig(). + // WithCloseOnContextDone causes wazero to interrupt a running WASM module + // when the call context is cancelled or times out, preventing goroutine leaks. + WithCloseOnContextDone(true) + if d.cfg.MaxMemoryPages > 0 { + rtCfg = rtCfg.WithMemoryLimitPages(d.cfg.MaxMemoryPages) + } + + // Wire in on-disk compilation cache when configured. + if d.cfg.CompileCacheDir != "" { + cache, err := wazero.NewCompilationCacheWithDir(d.cfg.CompileCacheDir) + if err != nil { + d.log.Warn("failed to create wazero compilation cache, continuing without cache", + zap.String("dir", d.cfg.CompileCacheDir), + zap.Error(err)) + } else { + rtCfg = rtCfg.WithCompilationCache(cache) + d.cache = cache + d.log.Info("wazero compilation cache enabled", zap.String("dir", d.cfg.CompileCacheDir)) + } + } + + // Create a single wazero runtime shared by all instances. + rt := wazero.NewRuntimeWithConfig(ctx, rtCfg) + d.rt = rt + + // Instantiate WASI host functions. Most evaluation functions will need at + // least minimal WASI support (e.g. for memory allocation helpers compiled + // from C/Rust/TinyGo). + if _, err := wasi_snapshot_preview1.Instantiate(ctx, rt); err != nil { + _ = rt.Close(ctx) + return fmt.Errorf("wasm: instantiate wasi: %w", err) + } + + // Compile the module once; all instances share the compiled code. + compiled, err := rt.CompileModule(ctx, wasmBytes) + if err != nil { + _ = rt.Close(ctx) + return fmt.Errorf("wasm: compile module: %w", err) + } + d.compiled = compiled + + // Build a locked-down ModuleConfig: no filesystem, no env vars, no + // stdin/stdout/stderr, no args. Only allow nanosleep and wall/mono clocks + // which the Go runtime needs. + modCfg := wazero.NewModuleConfig(). + WithName(""). 
+ WithSysNanosleep(). + WithSysWalltime(). + WithSysNanotime() + + // Filesystem: mount allowed paths read-only; no access by default. + fsCfg := wazero.NewFSConfig() + for _, p := range d.cfg.AllowedPaths { + fsCfg = fsCfg.WithReadOnlyDirMount(p, p) + } + modCfg = modCfg.WithFSConfig(fsCfg) + + // Env vars: expose only explicitly whitelisted variables. + for _, key := range d.cfg.AllowedEnv { + if val, ok := os.LookupEnv(key); ok { + modCfg = modCfg.WithEnv(key, val) + } + } + d.modCfg = modCfg + + // Build the pool. + d.pool = make(chan *wasmSupervisor, maxInstances) + + for i := 0; i < maxInstances; i++ { + sv := newWasmSupervisor(rt, compiled, modCfg, d.cfg.Timeout, d.log) + + if err := sv.Start(ctx); err != nil { + // Clean up already-started supervisors. + _ = drainBufferedPool(ctx, d.pool, d.log) + _ = rt.Close(ctx) + return fmt.Errorf("wasm: start instance %d: %w", i, err) + } + + d.pool <- sv + } + + d.log.Info("wasm dispatcher ready", zap.Int("instances", maxInstances)) + + return nil +} + +// Send acquires a supervisor from the pool, dispatches the request, and +// returns the supervisor to the pool. +func (d *Dispatcher) Send( + ctx context.Context, + method string, + data map[string]any, +) (map[string]any, error) { + if !d.tryBeginSend() { + return nil, ErrDispatcherClosed + } + defer d.pending.Done() + + // Acquire a supervisor, honouring the caller's context AND the shutdown + // signal so we never block forever on a drained pool. + var sv *wasmSupervisor + select { + case sv = <-d.pool: + case <-d.closedCh: + return nil, ErrDispatcherClosed + case <-ctx.Done(): + return nil, fmt.Errorf("wasm: acquire instance: %w", ctx.Err()) + } + + result, err := sv.Send(ctx, method, data) + + // Return the supervisor to the pool only if it is healthy. + // If the snapshot restore failed inside Send, sv.healthy is false and the + // supervisor's state is undefined — discard it and spawn a replacement so + // pool capacity is eventually restored. 
+ if sv.IsHealthy() { + d.returnOrDiscard(sv) + } else { + d.log.Warn("wasm supervisor unhealthy after request — dropping from pool, spawning replacement") + d.discardAsync(sv) + d.spawnReplacementAsync() + } + + if err != nil { + return nil, fmt.Errorf("wasm: send: %w", err) + } + + return result, nil +} + +// returnOrDiscard puts a healthy supervisor back in the pool unless Shutdown +// has begun, in which case the supervisor is closed asynchronously so it does +// not leak past a drained pool. +// +// Must be called from a goroutine that already holds a pending count (i.e. +// from inside Send) so that the Add issued by discardAsync is guaranteed to +// happen before Shutdown's pending.Wait can return. +func (d *Dispatcher) returnOrDiscard(sv *wasmSupervisor) { + d.mu.Lock() + if d.closed { + d.mu.Unlock() + d.discardAsync(sv) + return + } + // Push under the lock so it interleaves correctly with Shutdown's + // closed=true → drainPool sequence: either we push before closed is set + // (drainPool sees the supervisor) or we discard via the branch above. + d.pool <- sv + d.mu.Unlock() +} + +// discardAsync closes a discarded supervisor in the background and tracks it +// via the pending WaitGroup so Shutdown can wait for the close to complete +// before tearing down the runtime. +// +// Must be called from a goroutine that already holds a pending count +// (Send, via tryBeginSend). That invariant keeps Shutdown.pending.Wait +// blocked across this Add, eliminating the Add-after-Wait race. +func (d *Dispatcher) discardAsync(sv *wasmSupervisor) { + d.pending.Add(1) + go func() { + defer d.pending.Done() + _ = sv.Shutdown(context.Background()) + }() +} + +// spawnReplacementAsync kicks off spawnOne in a background goroutine, but only +// if the dispatcher is still open. If Shutdown has begun, no replacement is +// scheduled. Tracked via the pending WaitGroup. 
+func (d *Dispatcher) spawnReplacementAsync() { + d.mu.Lock() + if d.closed { + d.mu.Unlock() + return + } + d.pending.Add(1) + d.mu.Unlock() + go d.spawnOne() +} + +// spawnOne initialises a fresh wasmSupervisor and adds it to the pool. +// Called in a goroutine when an unhealthy supervisor is discarded so that +// pool capacity is eventually restored. Failures are logged but not fatal. +// +// If Shutdown begins while Start is running, the freshly initialised +// supervisor is closed immediately rather than inserted into the drained pool. +func (d *Dispatcher) spawnOne() { + defer d.pending.Done() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + d.log.Info("wasm: initialising replacement supervisor") + sv := newWasmSupervisor(d.rt, d.compiled, d.modCfg, d.cfg.Timeout, d.log) + if err := sv.Start(ctx); err != nil { + d.log.Error("wasm: replacement supervisor init failed", zap.Error(err)) + return + } + + d.mu.Lock() + if d.closed { + d.mu.Unlock() + d.log.Info("wasm: replacement supervisor born during shutdown — closing immediately") + _ = sv.Shutdown(context.Background()) + return + } + d.pool <- sv + d.mu.Unlock() + d.log.Info("wasm: replacement supervisor ready") +} + +// Shutdown closes all module instances and the wazero runtime. Idempotent. +func (d *Dispatcher) Shutdown(ctx context.Context) error { + d.mu.Lock() + if d.closed { + d.mu.Unlock() + return nil + } + d.closed = true + // Close the channel under mu so the closed=true / close(closedCh) pair is + // atomic with respect to tryBeginSend: any Send that observes !closed has + // also pending.Add'd before Shutdown can reach pending.Wait. 
+ close(d.closedCh) + d.mu.Unlock() + + d.log.Debug("shutting down wasm dispatcher") + + // Wait for in-flight Sends AND any background goroutines (replacement + // spawns / discard shutdowns) to finish so that no late-created supervisor + // lands in the pool after the drain below, no module is mid-Close while we + // close the runtime, and no Send is running against the wazero runtime + // when we tear it down. + d.pending.Wait() + + // Non-blocking drain: after pending.Wait, no spawn or returnOrDiscard + // goroutine will push to the pool, so we just close everything currently + // buffered. (drainPool's blocking-for-cap-items semantics would deadlock + // here when spawnOne took the closed-shortcut and never pushed.) + for { + select { + case sv := <-d.pool: + if err := sv.Shutdown(ctx); err != nil { + d.log.Warn("error shutting down pooled supervisor", zap.Error(err)) + } + default: + goto drained + } + } +drained: + + if d.rt != nil { + if err := d.rt.Close(ctx); err != nil { + return fmt.Errorf("wasm: close runtime: %w", err) + } + d.rt = nil + } + if d.cache != nil { + if err := d.cache.Close(ctx); err != nil { + return fmt.Errorf("wasm: close compilation cache: %w", err) + } + d.cache = nil + } + + return nil +} diff --git a/internal/execution/wasm/dispatcher_test.go b/internal/execution/wasm/dispatcher_test.go new file mode 100644 index 0000000..fb71f0a --- /dev/null +++ b/internal/execution/wasm/dispatcher_test.go @@ -0,0 +1,515 @@ +package wasm + +import ( + "context" + "errors" + "os" + "path/filepath" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tetratelabs/wazero" + "go.uber.org/zap" +) + +// echoModulePath returns the absolute path to the pre-compiled echo.wasm test +// fixture. 
The fixture is a minimal guest module that always returns +// {"ok":true} regardless of the request, which lets us test the host-side Go +// code (alloc call, memory write, evaluate call, length-prefix parsing, JSON +// unmarshal) without implementing a full language runtime in WAT. +func echoModulePath(t *testing.T) string { + t.Helper() + // __file__ is not available in Go, but runtime.Caller gives us the source + // file path so we can derive testdata/ relative to the test file. + _, filename, _, ok := runtime.Caller(0) + require.True(t, ok, "runtime.Caller failed") + return filepath.Join(filepath.Dir(filename), "testdata", "echo.wasm") +} + +// newTestLogger returns a zap development logger suitable for unit tests. +func newTestLogger(t *testing.T) *zap.Logger { + t.Helper() + log, err := zap.NewDevelopment() + require.NoError(t, err) + return log +} + +// newEchoDispatcher creates a Dispatcher backed by the echo fixture and starts +// it. The caller is responsible for calling Shutdown. +func newEchoDispatcher(t *testing.T, maxInstances int) *Dispatcher { + t.Helper() + cfg := Config{ + ModulePath: echoModulePath(t), + MaxInstances: maxInstances, + Timeout: 5 * time.Second, + } + d := NewDispatcher(cfg, newTestLogger(t)) + require.NoError(t, d.Start(context.Background()), "dispatcher start") + return d +} + +// TestDispatcher_StartStop verifies that a Dispatcher can be started and shut +// down cleanly without any interaction in between. +func TestDispatcher_StartStop(t *testing.T) { + d := newEchoDispatcher(t, 1) + err := d.Shutdown(context.Background()) + assert.NoError(t, err) +} + +// TestDispatcher_StartStop_MultipleInstances verifies start/stop with the +// default pool size (NumCPU). 
+func TestDispatcher_StartStop_MultipleInstances(t *testing.T) { + d := newEchoDispatcher(t, runtime.NumCPU()) + err := d.Shutdown(context.Background()) + assert.NoError(t, err) +} + +// TestDispatcher_Send_BasicResponse sends a single request and checks that the +// echo module returns {"ok":true}. +func TestDispatcher_Send_BasicResponse(t *testing.T) { + d := newEchoDispatcher(t, 1) + t.Cleanup(func() { _ = d.Shutdown(context.Background()) }) + + result, err := d.Send(context.Background(), "test", map[string]any{"hello": "world"}) + require.NoError(t, err) + require.NotNil(t, result) + + ok, exists := result["ok"] + assert.True(t, exists, "response should contain 'ok' key") + assert.Equal(t, true, ok, "response 'ok' should be true") +} + +// TestDispatcher_Send_EmptyParams verifies that Send works with nil params. +func TestDispatcher_Send_EmptyParams(t *testing.T) { + d := newEchoDispatcher(t, 1) + t.Cleanup(func() { _ = d.Shutdown(context.Background()) }) + + result, err := d.Send(context.Background(), "noop", nil) + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, true, result["ok"]) +} + +// TestDispatcher_Send_Concurrent sends 20 requests, at most 10 in flight at +// once, through a pool of 3 instances and verifies that all succeed. 
+func TestDispatcher_Send_Concurrent(t *testing.T) { + const ( + numWorkers = 10 + numRequests = 20 + poolSize = 3 + ) + + d := newEchoDispatcher(t, poolSize) + t.Cleanup(func() { _ = d.Shutdown(context.Background()) }) + + type result struct { + res map[string]any + err error + } + + results := make([]result, numRequests) + var wg sync.WaitGroup + wg.Add(numRequests) + + sem := make(chan struct{}, numWorkers) + for i := range numRequests { + sem <- struct{}{} + go func(i int) { + defer wg.Done() + defer func() { <-sem }() + res, err := d.Send(context.Background(), "eval", map[string]any{"i": i}) + results[i] = result{res, err} + }(i) + } + + wg.Wait() + + for i, r := range results { + require.NoError(t, r.err, "request %d failed", i) + require.NotNil(t, r.res, "request %d returned nil result", i) + assert.Equal(t, true, r.res["ok"], "request %d: unexpected result", i) + } +} + +// TestDispatcher_Send_AfterShutdown checks that Send after Shutdown returns +// ErrDispatcherClosed immediately, rather than blocking on the drained pool +// until the caller's context expires. +func TestDispatcher_Send_AfterShutdown(t *testing.T) { + d := newEchoDispatcher(t, 1) + require.NoError(t, d.Shutdown(context.Background())) + + _, err := d.Send(context.Background(), "test", nil) + assert.ErrorIs(t, err, ErrDispatcherClosed, "Send after Shutdown must return ErrDispatcherClosed") +} + +// TestDispatcher_Shutdown_Idempotent verifies that calling Shutdown twice does +// not return an error or double-close the runtime. +func TestDispatcher_Shutdown_Idempotent(t *testing.T) { + d := newEchoDispatcher(t, 1) + require.NoError(t, d.Shutdown(context.Background())) + require.NoError(t, d.Shutdown(context.Background()), "second Shutdown must be a no-op") +} + +// TestDispatcher_ReplacementDuringShutdown exercises the race where Send has +// just discarded an unhealthy supervisor and scheduled a replacement spawn +// while Shutdown begins. 
The replacement spawn must NOT insert a supervisor +// into a drained pool, and Shutdown must wait for the spawn goroutine to +// finish before closing the runtime (otherwise the late supervisor would +// reference a torn-down wazero.Runtime). +func TestDispatcher_ReplacementDuringShutdown(t *testing.T) { + d := newEchoDispatcher(t, 1) + + // Consume the only supervisor in the pool to mimic an in-flight Send. + sv := <-d.pool + + // Simulate Send's unhealthy-path bookkeeping: schedule the discard close + // of the bad supervisor and the spawn of a replacement. + d.discardAsync(sv) + d.spawnReplacementAsync() + + // Shutdown races with the spawn. It must wait for pending background work + // (via d.pending.Wait) before draining the pool and closing the runtime. + require.NoError(t, d.Shutdown(context.Background())) + + // After Shutdown the pool must be empty: any replacement that finished + // initialising during the race window was closed by spawnOne's + // closed-guard rather than inserted. + assert.Equal(t, 0, len(d.pool), "drained pool must be empty after Shutdown") + + // Send after Shutdown returns ErrDispatcherClosed promptly. + _, err := d.Send(context.Background(), "test", nil) + assert.ErrorIs(t, err, ErrDispatcherClosed) +} + +// TestDispatcher_Shutdown_WaitsForInFlightSends drives the original race the +// lifecycle patch is meant to fix: many concurrent Sends are issued while +// Shutdown runs partway through. Without the in-flight tracking, Shutdown +// could close the wazero runtime out from under a live Send (use-after-close), +// or returnOrDiscard/discardAsync could call pending.Add after Shutdown's +// pending.Wait already returned. Both surfaces are caught by -race or by an +// outright panic. +// +// Acceptance: every Send either succeeds or returns ErrDispatcherClosed, never +// any other error; Shutdown returns nil; no panic. 
+func TestDispatcher_Shutdown_WaitsForInFlightSends(t *testing.T) { + d := newEchoDispatcher(t, runtime.NumCPU()) + + const numWorkers = 128 + var ( + wg sync.WaitGroup + successes atomic.Int64 + closedExits atomic.Int64 + unexpected atomic.Int64 + ) + wg.Add(numWorkers) + + start := make(chan struct{}) + for i := 0; i < numWorkers; i++ { + go func() { + defer wg.Done() + <-start + for j := 0; j < 5; j++ { + _, err := d.Send(context.Background(), "eval", map[string]any{"j": j}) + switch { + case err == nil: + successes.Add(1) + case errors.Is(err, ErrDispatcherClosed): + closedExits.Add(1) + return // dispatcher is gone; stop hammering + default: + unexpected.Add(1) + t.Errorf("unexpected error: %v", err) + return + } + } + }() + } + + close(start) + // Give some Sends a chance to begin. + time.Sleep(5 * time.Millisecond) + + require.NoError(t, d.Shutdown(context.Background())) + wg.Wait() + + assert.Zero(t, unexpected.Load(), "no Send may return a non-closed error") + // Post-shutdown Send must return ErrDispatcherClosed promptly (not block). + postCtx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _, err := d.Send(postCtx, "eval", nil) + assert.ErrorIs(t, err, ErrDispatcherClosed) + t.Logf("successes=%d closed_exits=%d", successes.Load(), closedExits.Load()) +} + +// TestDispatcher_Shutdown_UnblocksBlockedSend covers the second race called +// out in the patch: a Send that passed tryBeginSend but finds the pool empty +// (all supervisors are in-use or have been drained by a racing Shutdown). +// Without selecting on closedCh, the Send would block on the empty pool until +// the caller's context expired. With the patch it must return +// ErrDispatcherClosed as soon as Shutdown begins. +func TestDispatcher_Shutdown_UnblocksBlockedSend(t *testing.T) { + d := newEchoDispatcher(t, 1) + // Empty the pool so a Send is forced to block on acquire. 
+ sv := <-d.pool + + type sendResult struct { + err error + } + res := make(chan sendResult, 1) + go func() { + _, err := d.Send(context.Background(), "eval", nil) + res <- sendResult{err: err} + }() + + // Let Send reach the empty-pool select. + time.Sleep(50 * time.Millisecond) + + // Put sv back so the dispatcher's drain has something to clean up + // (otherwise Shutdown sees an empty pool, which is also fine). + d.pool <- sv + + require.NoError(t, d.Shutdown(context.Background())) + + select { + case r := <-res: + // Either the Send got the supervisor before Shutdown drained it + // (succeeded), or Shutdown's closedCh fired first. + if r.err != nil { + assert.ErrorIs(t, r.err, ErrDispatcherClosed) + } + case <-time.After(2 * time.Second): + t.Fatal("Send did not return after Shutdown — closedCh select missing") + } +} + +// TestDispatcher_SpawnReplacementAsync_NoopAfterShutdown asserts that calling +// spawnReplacementAsync on a closed dispatcher is a no-op: it must not +// increment pending and must not launch a goroutine that touches the closed +// runtime. +func TestDispatcher_SpawnReplacementAsync_NoopAfterShutdown(t *testing.T) { + d := newEchoDispatcher(t, 1) + require.NoError(t, d.Shutdown(context.Background())) + + // Should return immediately without scheduling work. + d.spawnReplacementAsync() + + // Wait briefly with a deadline — pending.Wait would block forever if the + // no-op guard regressed and a goroutine were leaked with a stale runtime. + done := make(chan struct{}) + go func() { + d.pending.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("pending.Wait did not return — spawn goroutine leaked after Shutdown") + } +} + +// TestDispatcher_MissingModule checks that Start fails when ModulePath does +// not point to a valid file. 
+func TestDispatcher_MissingModule(t *testing.T) { + cfg := Config{ + ModulePath: "/nonexistent/path/module.wasm", + MaxInstances: 1, + } + d := NewDispatcher(cfg, newTestLogger(t)) + err := d.Start(context.Background()) + assert.Error(t, err, "Start with missing module should fail") +} + +// TestSupervisor_MemoryRestored sends two sequential requests through the same +// supervisor and verifies that both succeed with the same response. This +// exercises the snapshot/restore cycle: after the first evaluate the bump +// allocator's heap_top is advanced, but restoreSnapshot rewinds memory so the +// second call starts from the exact same state. +func TestSupervisor_MemoryRestored(t *testing.T) { + // Use a pool of exactly 1 so both sends use the same supervisor instance. + d := newEchoDispatcher(t, 1) + t.Cleanup(func() { _ = d.Shutdown(context.Background()) }) + + ctx := context.Background() + + r1, err := d.Send(ctx, "first", map[string]any{"seq": 1}) + require.NoError(t, err) + require.NotNil(t, r1) + + r2, err := d.Send(ctx, "second", map[string]any{"seq": 2}) + require.NoError(t, err) + require.NotNil(t, r2) + + // Both responses must be identical {"ok":true}. + assert.Equal(t, r1, r2, "responses must be equal, proving memory was restored between calls") + assert.Equal(t, true, r1["ok"]) + assert.Equal(t, true, r2["ok"]) +} + +// TestSupervisor_MemoryRestored_ManyTimes exercises many sequential calls +// through a single-instance pool to ensure the snapshot/restore cycle is +// stable over repeated invocations. 
+func TestSupervisor_MemoryRestored_ManyTimes(t *testing.T) { + d := newEchoDispatcher(t, 1) + t.Cleanup(func() { _ = d.Shutdown(context.Background()) }) + + ctx := context.Background() + const iters = 50 + + for i := range iters { + res, err := d.Send(ctx, "loop", map[string]any{"i": i}) + require.NoError(t, err, "iteration %d", i) + assert.Equal(t, true, res["ok"], "iteration %d", i) + } +} + +// buildMissingImportModule constructs a valid WASM module that imports a host +// function Shimmy does not provide. Compilation succeeds, but instantiation +// fails inside wasmSupervisor.Start. +func buildMissingImportModule() []byte { + section := func(id byte, payload []byte) []byte { + out := []byte{id} + out = append(out, leb128Encode(uint32(len(payload)))...) + out = append(out, payload...) + return out + } + name := func(s string) []byte { + out := leb128Encode(uint32(len(s))) + out = append(out, []byte(s)...) + return out + } + + module := []byte{0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00} + // Type section: one function type () -> (). + module = append(module, section(1, []byte{0x01, 0x60, 0x00, 0x00})...) + // Import section: one function import env.missing with type index 0. + importPayload := []byte{0x01} + importPayload = append(importPayload, name("env")...) + importPayload = append(importPayload, name("missing")...) + importPayload = append(importPayload, 0x00, 0x00) // kind=func, typeidx=0 + module = append(module, section(2, importPayload)...) + return module +} + +// TestDispatcher_StartFailure_DoesNotBlock verifies that a startup failure while +// initialising the warm instance pool returns an error instead of blocking while +// trying to drain a not-yet-full pool. 
+func TestDispatcher_StartFailure_DoesNotBlock(t *testing.T) { + modulePath := filepath.Join(t.TempDir(), "missing-import.wasm") + require.NoError(t, os.WriteFile(modulePath, buildMissingImportModule(), 0o644)) + + d := NewDispatcher(Config{ + ModulePath: modulePath, + MaxInstances: 2, + Timeout: 5 * time.Second, + }, newTestLogger(t)) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err := d.Start(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "start instance") +} + +// TestSupervisor_Start_Idempotent verifies that calling Start twice on the +// same supervisor does not error (the second call is a no-op). +func TestSupervisor_Start_Idempotent(t *testing.T) { + ctx := context.Background() + log := newTestLogger(t) + + wasmBytes := echoWasmBytes(t) + + rt, compiled := compileEchoModule(t, ctx, wasmBytes) + t.Cleanup(func() { _ = rt.Close(ctx) }) + + sv := newWasmSupervisor(rt, compiled, wazero.NewModuleConfig().WithName(""), 5*time.Second, log) + require.NoError(t, sv.Start(ctx)) + require.NoError(t, sv.Start(ctx), "second Start must be a no-op") + require.NoError(t, sv.Shutdown(ctx)) +} + +// TestSupervisor_Send_NotStarted checks that Send before Start returns an +// error. +func TestSupervisor_Send_NotStarted(t *testing.T) { + ctx := context.Background() + log := newTestLogger(t) + + wasmBytes := echoWasmBytes(t) + rt, compiled := compileEchoModule(t, ctx, wasmBytes) + t.Cleanup(func() { _ = rt.Close(ctx) }) + + sv := newWasmSupervisor(rt, compiled, wazero.NewModuleConfig().WithName(""), 5*time.Second, log) + // Do NOT call sv.Start. 
+ + _, err := sv.Send(ctx, "test", nil) + assert.Error(t, err, "Send without Start should return an error") +} + +// TestSupervisor_Send_MemoryGrowDetected is the regression test for +// memory.grow snapshot isolation: if the guest expands linear memory during a +// request, the supervisor must (a) detect the growth, (b) zero the grown tail +// so the next request cannot read leaked guest data, (c) surface +// ErrMemoryGrew, and (d) mark itself unhealthy so the dispatcher discards it +// instead of returning it to the pool. +// +// The echo fixture itself never grows memory, so we simulate a request that +// did by growing the module's memory from host code (between Start and Send) +// and writing a recognisable poison pattern into the new pages. After Send +// runs, restoreSnapshot observes mem.Size() > snapshotSize and must trip the +// defensive path. +func TestSupervisor_Send_MemoryGrowDetected(t *testing.T) { + ctx := context.Background() + log := newTestLogger(t) + + wasmBytes := echoWasmBytes(t) + rt, compiled := compileEchoModule(t, ctx, wasmBytes) + t.Cleanup(func() { _ = rt.Close(ctx) }) + + sv := newWasmSupervisor(rt, compiled, wazero.NewModuleConfig().WithName(""), 5*time.Second, log) + require.NoError(t, sv.Start(ctx)) + t.Cleanup(func() { _ = sv.Shutdown(ctx) }) + + require.True(t, sv.IsHealthy(), "supervisor should be healthy after Start") + + // Capture the snapshot size, then grow memory by 1 page (64 KiB) and + // poison the new pages. This simulates a guest that called memory.grow + // during execution and wrote sensitive data into the new pages. 
+ mem := sv.mod.Memory() + require.NotNil(t, mem) + origSize := mem.Size() + require.Equal(t, origSize, sv.snapshotSize, "snapshotSize must be recorded at Take time") + + prevPages, ok := mem.Grow(1) + require.True(t, ok, "memory.Grow must succeed (echo fixture has no max)") + require.Equal(t, origSize/(64*1024), prevPages) + + grownSize := mem.Size() + require.Greater(t, grownSize, origSize, "memory must have grown") + + poison := make([]byte, grownSize-origSize) + for i := range poison { + poison[i] = 0xAB + } + require.True(t, mem.Write(origSize, poison), "poison tail") + + // Issue a request. The echo guest doesn't itself grow memory, but Send's + // post-call restoreSnapshot will observe the host-injected growth and + // trip the defensive path. + _, err := sv.Send(ctx, "test", map[string]any{"hello": "world"}) + require.Error(t, err, "Send must return the restore error") + assert.ErrorIs(t, err, ErrMemoryGrew, "error must wrap ErrMemoryGrew") + + assert.False(t, sv.IsHealthy(), "supervisor must be marked unhealthy after grow detected") + + // The grown tail must have been zeroed so no leftover guest data remains + // in the (now-unhealthy but still-instantiated) module. + tail, readOK := mem.Read(origSize, grownSize-origSize) + require.True(t, readOK) + expected := make([]byte, grownSize-origSize) + assert.Equal(t, expected, []byte(tail), "tail must be zero-filled, not contain poison bytes") +} diff --git a/internal/execution/wasm/pool.go b/internal/execution/wasm/pool.go new file mode 100644 index 0000000..1237b52 --- /dev/null +++ b/internal/execution/wasm/pool.go @@ -0,0 +1,62 @@ +package wasm + +import ( + "context" + + "go.uber.org/zap" +) + +// poolItem is the interface satisfied by any item that can be shut down when +// draining a pool. +type poolItem interface { + Shutdown(ctx context.Context) error +} + +// drainPool receives up to cap(pool) items from the channel and calls Shutdown +// on each. 
This helper is only used when the caller knows the pool is full. +func drainPool[T poolItem](ctx context.Context, pool chan T, log *zap.Logger) error { + if pool == nil { + return nil + } + + var firstErr error + for i := 0; i < cap(pool); i++ { + select { + case item := <-pool: + if err := item.Shutdown(ctx); err != nil { + log.Error("error shutting down pool item", zap.Error(err)) + if firstErr == nil { + firstErr = err + } + } + case <-ctx.Done(): + log.Warn("drainPool: context cancelled, some items may not be shut down", + zap.Int("remaining", cap(pool)-i)) + return ctx.Err() + } + } + return firstErr +} + +// drainBufferedPool shuts down only items currently buffered in the channel. It +// is safe for startup-failure paths where the pool may be only partially filled. +func drainBufferedPool[T poolItem](ctx context.Context, pool chan T, log *zap.Logger) error { + if pool == nil { + return nil + } + + var firstErr error + for { + select { + case item := <-pool: + if err := item.Shutdown(ctx); err != nil { + log.Error("error shutting down pool item", zap.Error(err)) + if firstErr == nil { + firstErr = err + } + } + default: + return firstErr + } + } +} diff --git a/internal/execution/wasm/robustness_test.go b/internal/execution/wasm/robustness_test.go new file mode 100644 index 0000000..cb59593 --- /dev/null +++ b/internal/execution/wasm/robustness_test.go @@ -0,0 +1,161 @@ +package wasm + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func wasmULEB(v uint32) []byte { + var buf []byte + for { + b := byte(v & 0x7f) + v >>= 7 + if v != 0 { + b |= 0x80 + } + buf = append(buf, b) + if v == 0 { + break + } + } + return buf +} + +func wasmSection(id byte, payload []byte) []byte { + out := []byte{id} + out = append(out, wasmULEB(uint32(len(payload)))...) + out = append(out, payload...) 
+ return out +} + +func wasmName(s string) []byte { + out := wasmULEB(uint32(len(s))) + out = append(out, []byte(s)...) + return out +} + +func malformedABIWasm(allocReturnsValue, evaluateReturnsValue bool) []byte { + module := []byte{0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00} + + // Types: alloc(i32) [-> i32], evaluate(i32, i32) [-> i32]. + var types []byte + types = append(types, 0x02) + types = append(types, 0x60, 0x01, 0x7f) + if allocReturnsValue { + types = append(types, 0x01, 0x7f) + } else { + types = append(types, 0x00) + } + types = append(types, 0x60, 0x02, 0x7f, 0x7f) + if evaluateReturnsValue { + types = append(types, 0x01, 0x7f) + } else { + types = append(types, 0x00) + } + module = append(module, wasmSection(1, types)...) + + // Two functions: alloc uses type 0; evaluate uses type 1. + module = append(module, wasmSection(3, []byte{0x02, 0x00, 0x01})...) + + // One memory page. + module = append(module, wasmSection(5, []byte{0x01, 0x00, 0x01})...) + + // Export memory, alloc, evaluate. + var exports []byte + exports = append(exports, 0x03) + exports = append(exports, wasmName("memory")...) + exports = append(exports, 0x02, 0x00) + exports = append(exports, wasmName("alloc")...) + exports = append(exports, 0x00, 0x00) + exports = append(exports, wasmName("evaluate")...) + exports = append(exports, 0x00, 0x01) + module = append(module, wasmSection(7, exports)...) + + // Code bodies. + var code []byte + code = append(code, 0x02) + allocBody := []byte{0x00} + if allocReturnsValue { + allocBody = append(allocBody, 0x41, 0x08) // i32.const 8 + } + allocBody = append(allocBody, 0x0b) // end + code = append(code, wasmULEB(uint32(len(allocBody)))...) + code = append(code, allocBody...) + + evaluateBody := []byte{0x00} + if evaluateReturnsValue { + evaluateBody = append(evaluateBody, 0x41, 0x08) // i32.const 8 + } + evaluateBody = append(evaluateBody, 0x0b) // end + code = append(code, wasmULEB(uint32(len(evaluateBody)))...) 
+ code = append(code, evaluateBody...) + module = append(module, wasmSection(10, code)...) + + return module +} + +func writeTempWasm(t *testing.T, bytes []byte) string { + t.Helper() + path := filepath.Join(t.TempDir(), "eval.wasm") + require.NoError(t, os.WriteFile(path, bytes, 0o644)) + return path +} + +func TestDispatcher_Send_ReturnsErrorForAllocWithoutReturnValue(t *testing.T) { + path := writeTempWasm(t, malformedABIWasm(false, true)) + d := NewDispatcher(Config{ModulePath: path, MaxInstances: 1, Timeout: time.Second}, newTestLogger(t)) + require.NoError(t, d.Start(context.Background())) + t.Cleanup(func() { _ = d.Shutdown(context.Background()) }) + + _, err := d.Send(context.Background(), "eval", nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "alloc returned 0 values") +} + +func TestDispatcher_Send_ReturnsErrorForEvaluateWithoutReturnValue(t *testing.T) { + path := writeTempWasm(t, malformedABIWasm(true, false)) + d := NewDispatcher(Config{ModulePath: path, MaxInstances: 1, Timeout: time.Second}, newTestLogger(t)) + require.NoError(t, d.Start(context.Background())) + t.Cleanup(func() { _ = d.Shutdown(context.Background()) }) + + _, err := d.Send(context.Background(), "eval", nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "evaluate returned 0 values") +} + +func TestDispatcher_StartRejectsInvalidMaxMemoryPagesEnv(t *testing.T) { + t.Setenv("FUNCTION_WASM_MAX_MEMORY_PAGES", "not-a-number") + d := NewDispatcher(Config{ModulePath: echoModulePath(t), MaxInstances: 1}, newTestLogger(t)) + + err := d.Start(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "FUNCTION_WASM_MAX_MEMORY_PAGES") +} + +func TestDispatcher_ShutdownClosesCompilationCache(t *testing.T) { + cacheDir := t.TempDir() + d := NewDispatcher(Config{ + ModulePath: echoModulePath(t), + MaxInstances: 1, + Timeout: time.Second, + CompileCacheDir: cacheDir, + MaxMemoryPages: 256, + }, newTestLogger(t)) + require.NoError(t, 
d.Start(context.Background())) + require.NotNil(t, d.cache, "dispatcher should retain the compilation cache so Shutdown can close it") + + require.NoError(t, d.Shutdown(context.Background())) + assert.Nil(t, d.cache, "closed compilation cache should be released") + + _, err := os.ReadDir(cacheDir) + if err != nil && strings.Contains(err.Error(), "bad file descriptor") { + t.Fatalf("cache directory should remain readable after cache close: %v", err) + } +} diff --git a/internal/execution/wasm/snapshot.go b/internal/execution/wasm/snapshot.go new file mode 100644 index 0000000..b575091 --- /dev/null +++ b/internal/execution/wasm/snapshot.go @@ -0,0 +1,96 @@ +package wasm + +import ( + "fmt" + + "github.com/tetratelabs/wazero/api" +) + +// SnapshotStrategy abstracts how linear-memory snapshots are taken and +// restored. The default implementation (FullMemcpyStrategy) copies the entire +// memory region on every restore. Future strategies may track dirty pages to +// reduce restore cost for large modules. +// +// Contract (I-4 fix — document ordering and concurrency expectations): +// - Take must be called at least once before Restore. +// - Take may be called multiple times; each call overwrites the previous +// snapshot. +// - Calling Restore without a prior Take is a no-op (returns nil) but +// logically meaningless. +// - Implementations are NOT safe for concurrent calls to Take / Restore. +// The caller (wasmSupervisor) must serialise access. +type SnapshotStrategy interface { + // Take captures the current state of the WASM linear memory. + // It is called once after module initialisation. + Take(mem api.Memory) error + + // Restore writes the captured snapshot back into WASM linear memory. + // It is called after every request so the next request sees a clean state. + Restore(mem api.Memory) error + + // Close releases any resources held by the strategy. It is safe to call on a + // zero-value or never-initialised strategy. 
+ Close() error +} + +// --------------------------------------------------------------------------- +// FullMemcpyStrategy +// --------------------------------------------------------------------------- + +// FullMemcpyStrategy is the always-available baseline: it copies the entire +// linear memory into a []byte on Take and writes it all back on Restore. +// Cost is O(total memory size) regardless of how many pages were actually +// written during the request. +type FullMemcpyStrategy struct { + snapshot []byte +} + +// NewFullMemcpyStrategy returns a ready-to-use FullMemcpyStrategy. +func NewFullMemcpyStrategy() *FullMemcpyStrategy { + return &FullMemcpyStrategy{} +} + +// Take implements SnapshotStrategy. +func (f *FullMemcpyStrategy) Take(mem api.Memory) error { + if mem == nil { + f.snapshot = nil + return nil + } + + size := mem.Size() + if size == 0 { + f.snapshot = nil + return nil + } + + buf, ok := mem.Read(0, size) + if !ok { + return fmt.Errorf("snapshot: could not read %d bytes of linear memory", size) + } + + // Make an owned copy — mem.Read may return a slice backed by the wazero + // memory buffer which could be modified by subsequent guest execution. + f.snapshot = make([]byte, len(buf)) + copy(f.snapshot, buf) + + return nil +} + +// Restore implements SnapshotStrategy. +func (f *FullMemcpyStrategy) Restore(mem api.Memory) error { + if f.snapshot == nil || mem == nil { + return nil + } + + if !mem.Write(0, f.snapshot) { + return fmt.Errorf("snapshot: failed to restore %d bytes", len(f.snapshot)) + } + + return nil +} + +// Close implements SnapshotStrategy. FullMemcpyStrategy holds no OS resources. 
+func (f *FullMemcpyStrategy) Close() error { + f.snapshot = nil + return nil +} diff --git a/internal/execution/wasm/snapshot_test.go b/internal/execution/wasm/snapshot_test.go new file mode 100644 index 0000000..1abe885 --- /dev/null +++ b/internal/execution/wasm/snapshot_test.go @@ -0,0 +1,258 @@ +//go:build !plan9 + +package wasm + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/api" +) + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +// leb128Encode encodes a uint32 as an unsigned LEB128 byte slice. +func leb128Encode(v uint32) []byte { + var buf []byte + for { + b := byte(v & 0x7f) + v >>= 7 + if v != 0 { + b |= 0x80 + } + buf = append(buf, b) + if v == 0 { + break + } + } + return buf +} + +// buildTestMemoryModule constructs a minimal WASM binary that declares exactly +// `pages` pages (64 KiB each) of linear memory. wazero's Module.Memory() +// returns the first memory regardless of whether it is exported, so no export +// section is needed. +// +// Binary layout (WASM spec §5): +// +// \0asm (magic) + version (1) + memory section +// +// This mirrors buildMinimalMemoryModule from snapshot_bench_test.go but +// accepts *testing.T so it can be used in unit tests. +func buildTestMemoryModule(t *testing.T, pages int) []byte { + t.Helper() + + // Memory section payload: count=1, limits type=0x00 (min only), min=pages + pagesLEB := leb128Encode(uint32(pages)) + memPayload := append([]byte{0x01, 0x00}, pagesLEB...) + + // Section: id=5 (memory), size=len(payload), payload + memSec := append([]byte{0x05}, append(leb128Encode(uint32(len(memPayload))), memPayload...)...) 
+ + // Full module: magic + version + memory section + module := []byte{0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00} + module = append(module, memSec...) + return module +} + +// newTestWazeroMemory instantiates a minimal WASM module with the given number +// of 64 KiB pages and returns its api.Memory. The runtime and module are +// closed via t.Cleanup. +func newTestWazeroMemory(t *testing.T, pages int) api.Memory { + t.Helper() + ctx := context.Background() + + wasmBin := buildTestMemoryModule(t, pages) + + rt := wazero.NewRuntime(ctx) + t.Cleanup(func() { _ = rt.Close(ctx) }) + + compiled, err := rt.CompileModule(ctx, wasmBin) + require.NoError(t, err, "compile minimal module") + t.Cleanup(func() { _ = compiled.Close(ctx) }) + + mod, err := rt.InstantiateModule(ctx, compiled, wazero.NewModuleConfig().WithName("")) + require.NoError(t, err, "instantiate minimal module") + t.Cleanup(func() { _ = mod.Close(ctx) }) + + mem := mod.Memory() + require.NotNil(t, mem, "module must have linear memory") + return mem +} + +// --------------------------------------------------------------------------- +// TestFullMemcpyStrategy_TakeRestoreRoundtrip +// --------------------------------------------------------------------------- + +// TestFullMemcpyStrategy_TakeRestoreRoundtrip verifies the core contract: +// after Take, mutating the memory and calling Restore brings it back to the +// snapshotted state. +func TestFullMemcpyStrategy_TakeRestoreRoundtrip(t *testing.T) { + mem := newTestWazeroMemory(t, 1) // 1 page = 64 KiB + + // Fill memory with a known pattern. + size := mem.Size() + pattern := make([]byte, size) + for i := range pattern { + pattern[i] = byte(i % 251) + } + require.True(t, mem.Write(0, pattern), "write initial pattern") + + s := NewFullMemcpyStrategy() + t.Cleanup(func() { require.NoError(t, s.Close()) }) + + // Take snapshot. + require.NoError(t, s.Take(mem)) + + // Overwrite memory with zeros (simulated guest write). 
+ zeros := make([]byte, size) + require.True(t, mem.Write(0, zeros), "overwrite with zeros") + + after, ok := mem.Read(0, size) + require.True(t, ok) + require.Equal(t, zeros, []byte(after), "sanity: memory should be all-zeros now") + + // Restore and verify memory matches original pattern. + require.NoError(t, s.Restore(mem)) + + restored, ok := mem.Read(0, size) + require.True(t, ok) + assert.Equal(t, pattern, []byte(restored), "Restore must return memory to snapshotted state") +} + +// --------------------------------------------------------------------------- +// TestFullMemcpyStrategy_TakeNilMemory +// --------------------------------------------------------------------------- + +// TestFullMemcpyStrategy_TakeNilMemory checks that Take(nil) is safe and +// results in a nil snapshot (no panic, no error). +func TestFullMemcpyStrategy_TakeNilMemory(t *testing.T) { + s := NewFullMemcpyStrategy() + t.Cleanup(func() { require.NoError(t, s.Close()) }) + + require.NoError(t, s.Take(nil)) + assert.Nil(t, s.snapshot, "snapshot should be nil after Take(nil)") + + // A subsequent Restore(nil) must also be a no-op. + require.NoError(t, s.Restore(nil)) +} + +// --------------------------------------------------------------------------- +// TestFullMemcpyStrategy_RestoreBeforeTake +// --------------------------------------------------------------------------- + +// TestFullMemcpyStrategy_RestoreBeforeTake verifies that calling Restore on a +// zero-value / never-initialised strategy is a no-op that does not modify +// memory or return an error. +func TestFullMemcpyStrategy_RestoreBeforeTake(t *testing.T) { + s := NewFullMemcpyStrategy() + t.Cleanup(func() { require.NoError(t, s.Close()) }) + + mem := newTestWazeroMemory(t, 1) + size := mem.Size() + + // Fill with recognisable data. 
+ data := make([]byte, size) + for i := range data { + data[i] = byte(i % 97) + } + require.True(t, mem.Write(0, data), "write initial data") + + // Snapshot the state so we can compare after Restore. + before, ok := mem.Read(0, size) + require.True(t, ok) + beforeCopy := make([]byte, len(before)) + copy(beforeCopy, before) + + // Restore before any Take — must be a no-op (snapshot is nil). + require.NoError(t, s.Restore(mem)) + + after, ok := mem.Read(0, size) + require.True(t, ok) + assert.Equal(t, beforeCopy, []byte(after), "Restore before Take must leave memory unchanged") +} + +// --------------------------------------------------------------------------- +// TestFullMemcpyStrategy_EmptyMemory +// --------------------------------------------------------------------------- + +// TestFullMemcpyStrategy_EmptyMemory checks that a zero-size case in snapshot +// logic produces a nil snapshot (size==0 branch). We test this by calling +// Take with nil (which mirrors the zero-size code path in the implementation: +// both nil and zero-size result in snapshot=nil). +func TestFullMemcpyStrategy_EmptyMemory(t *testing.T) { + s := NewFullMemcpyStrategy() + t.Cleanup(func() { require.NoError(t, s.Close()) }) + + // Take(nil) exercises the "mem == nil" branch which sets snapshot=nil. + require.NoError(t, s.Take(nil)) + assert.Nil(t, s.snapshot, "snapshot must be nil when memory is nil") + + // Restore(nil) must be a no-op. + require.NoError(t, s.Restore(nil)) +} + +// --------------------------------------------------------------------------- +// TestFullMemcpyStrategy_CloseIdempotent +// --------------------------------------------------------------------------- + +// TestFullMemcpyStrategy_CloseIdempotent verifies that Close can be called +// multiple times without panicking or returning an error. 
+func TestFullMemcpyStrategy_CloseIdempotent(t *testing.T) { + s := NewFullMemcpyStrategy() + + mem := newTestWazeroMemory(t, 1) + require.NoError(t, s.Take(mem)) + assert.NotNil(t, s.snapshot, "snapshot should be set after Take") + + // First Close should succeed and clear the snapshot. + require.NoError(t, s.Close()) + assert.Nil(t, s.snapshot, "snapshot should be nil after first Close") + + // Second Close must also be safe. + require.NoError(t, s.Close()) +} + +// --------------------------------------------------------------------------- +// TestFullMemcpyStrategy_SnapshotIsOwnedCopy +// --------------------------------------------------------------------------- + +// TestFullMemcpyStrategy_SnapshotIsOwnedCopy confirms that the snapshot is an +// independent copy of the memory buffer, not an alias into wazero's backing +// store. If Take stored a slice backed by the same underlying array, a +// subsequent guest write would silently corrupt the snapshot. +func TestFullMemcpyStrategy_SnapshotIsOwnedCopy(t *testing.T) { + mem := newTestWazeroMemory(t, 1) + size := mem.Size() + + // Write distinct pattern. + pattern := make([]byte, size) + for i := range pattern { + pattern[i] = byte(i % 199) + } + require.True(t, mem.Write(0, pattern), "write pattern") + + s := NewFullMemcpyStrategy() + t.Cleanup(func() { require.NoError(t, s.Close()) }) + + require.NoError(t, s.Take(mem)) + + // Overwrite memory entirely with 0xFF. + corrupt := make([]byte, size) + for i := range corrupt { + corrupt[i] = 0xFF + } + require.True(t, mem.Write(0, corrupt)) + + // Restore: snapshot must be independent of the wazero buffer. 
+ require.NoError(t, s.Restore(mem)) + + restored, ok := mem.Read(0, size) + require.True(t, ok) + assert.Equal(t, pattern, []byte(restored), "snapshot must be independent copy of original data") +} diff --git a/internal/execution/wasm/supervisor.go b/internal/execution/wasm/supervisor.go new file mode 100644 index 0000000..02dbb80 --- /dev/null +++ b/internal/execution/wasm/supervisor.go @@ -0,0 +1,221 @@ +package wasm + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/api" + "go.uber.org/zap" +) + +// ErrMemoryGrew indicates that the guest expanded linear memory during a +// request beyond the size captured at snapshot time. wazero (and the WASM +// spec) does not allow shrinking linear memory, so the original snapshotted +// state cannot be fully reproduced and the supervisor must be discarded. +var ErrMemoryGrew = errors.New("wasm: linear memory grew beyond snapshotted size") + +// wasmSupervisor manages a single instantiated WASM module. After the module +// is initialised its linear memory is snapshotted; the snapshot is restored +// after every Send so that the next request sees a clean initial state. This +// gives cheap warm-start semantics without re-compiling the module. +type wasmSupervisor struct { + mu sync.Mutex + + runtime wazero.Runtime + compiled wazero.CompiledModule + modCfg wazero.ModuleConfig + + mod api.Module + adapter *wasmAdapter + + // strategy implements snapshot/restore. The generic backend intentionally + // uses the portable full-memory copy strategy; dirty-page optimisation is a + // separate future concern. + strategy SnapshotStrategy + + // healthy is true when the supervisor is in a known-good state and can be + // safely returned to the pool. It is set to false when restoreSnapshot fails, + // indicating the WASM module's memory state is undefined. + healthy bool + + // snapshotSize is the linear-memory size (in bytes) captured at Take time. 
+ // restoreSnapshot compares this against the post-request memory size to + // detect memory.grow during execution — wazero cannot shrink memory, so + // any growth invalidates the snapshot and must mark the supervisor unhealthy. + snapshotSize uint32 + + timeout time.Duration + log *zap.Logger +} + +func newWasmSupervisor( + rt wazero.Runtime, + compiled wazero.CompiledModule, + modCfg wazero.ModuleConfig, + timeout time.Duration, + log *zap.Logger, +) *wasmSupervisor { + return &wasmSupervisor{ + runtime: rt, + compiled: compiled, + modCfg: modCfg, + timeout: timeout, + log: log.Named("supervisor_wasm"), + } +} + +// Start instantiates the compiled module, runs any WASI start function, then +// snapshots linear memory. +func (s *wasmSupervisor) Start(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.mod != nil { + return nil + } + + s.log.Debug("instantiating wasm module") + + // Apply start functions on top of the provided (sandboxed) module config. + instCfg := s.modCfg.WithStartFunctions("_initialize", "_start") + + mod, err := s.runtime.InstantiateModule(ctx, s.compiled, instCfg) + if err != nil { + return fmt.Errorf("wasm: instantiate module: %w", err) + } + + s.mod = mod + s.adapter = newWasmAdapter(mod, s.log) + s.healthy = true + + // Snapshot linear memory so we can restore it before each request. + s.strategy = NewFullMemcpyStrategy() + if err := s.takeSnapshot(); err != nil { + _ = s.strategy.Close() + _ = mod.Close(ctx) + s.mod = nil + return fmt.Errorf("wasm: snapshot memory: %w", err) + } + + memSize := uint32(0) + if m := s.mod.Memory(); m != nil { + memSize = m.Size() + } + s.log.Debug("wasm module ready", + zap.Uint32("snapshot_bytes", memSize), + zap.String("strategy", fmt.Sprintf("%T", s.strategy)), + ) + + return nil +} + +// Send calls the guest's evaluate function, then restores linear memory from +// the snapshot so the next request starts from a clean state. 
+func (s *wasmSupervisor) Send(
+	ctx context.Context,
+	method string,
+	data map[string]any,
+) (map[string]any, error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.mod == nil || s.adapter == nil {
+		return nil, fmt.Errorf("wasm: supervisor not started")
+	}
+
+	result, err := s.adapter.send(ctx, method, data, s.timeout)
+
+	// Restore memory snapshot to keep state clean for the next request.
+	// If restore fails, mark the supervisor unhealthy so the dispatcher
+	// discards it rather than returning it to the pool with undefined state.
+	if restoreErr := s.restoreSnapshot(); restoreErr != nil {
+		s.log.Error("failed to restore memory snapshot — marking supervisor unhealthy", zap.Error(restoreErr))
+		s.healthy = false
+
+		wrapped := fmt.Errorf("wasm: restore snapshot: %w", restoreErr)
+		if err == nil {
+			err = wrapped
+		} else {
+			// The guest call failed as well. Report both failures so that
+			// errors.Is(err, ErrMemoryGrew) still matches instead of the
+			// restore error being silently dropped.
+			err = errors.Join(err, wrapped)
+		}
+	}
+
+	return result, err
+}
+
+// IsHealthy reports whether the supervisor is in a known-good state. It is
+// false before Start, after a failed snapshot restore, and after Shutdown.
+// Safe to call without holding s.mu (acquires the lock internally). (I-3 fix)
+func (s *wasmSupervisor) IsHealthy() bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.healthy
+}
+
+// Shutdown closes the module instance and releases the supervisor's
+// per-instance resources. It is a no-op when the supervisor was never started
+// or has already been shut down.
+//
+// The adapter, the snapshot buffer and the healthy flag are released even when
+// closing the module fails, so a failed close cannot pin the full-memory
+// snapshot or leave the supervisor looking usable. The close error, if any, is
+// returned wrapped; a failure to close the snapshot strategy is only logged.
+func (s *wasmSupervisor) Shutdown(ctx context.Context) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.mod == nil {
+		return nil
+	}
+
+	s.log.Debug("shutting down wasm module instance")
+
+	closeErr := s.mod.Close(ctx)
+
+	// Drop instance state regardless of closeErr: returning early here used
+	// to skip strategy.Close and keep the adapter bound to the old module.
+	// NOTE(review): assumes retrying Close on the same module would not help —
+	// confirm against wazero's api.Closer semantics.
+	s.mod = nil
+	s.adapter = nil
+	s.healthy = false
+
+	if s.strategy != nil {
+		if err := s.strategy.Close(); err != nil {
+			s.log.Warn("failed to close snapshot strategy", zap.Error(err))
+		}
+	}
+
+	if closeErr != nil {
+		return fmt.Errorf("wasm: close module: %w", closeErr)
+	}
+	return nil
+}
+
+// takeSnapshot captures the guest's linear memory via the active strategy and
+// records the memory size so restoreSnapshot can detect post-snapshot growth.
+// Must be called with s.mu held.
+func (s *wasmSupervisor) takeSnapshot() error { + mem := s.mod.Memory() + if mem == nil { + s.snapshotSize = 0 + return nil + } + if err := s.strategy.Take(mem); err != nil { + return err + } + s.snapshotSize = mem.Size() + return nil +} + +// restoreSnapshot restores the guest's linear memory from the last snapshot +// via the active strategy. If the guest grew memory during the request +// (memory.grow), it zero-fills the tail beyond the snapshotted size to prevent +// leaking guest data into the next request and returns ErrMemoryGrew so the +// caller (Send) marks the supervisor unhealthy and discards it. Must be called +// with s.mu held. +func (s *wasmSupervisor) restoreSnapshot() error { + if s.mod == nil { + return nil + } + mem := s.mod.Memory() + if mem == nil { + return nil + } + if err := s.strategy.Restore(mem); err != nil { + return err + } + if cur := mem.Size(); cur > s.snapshotSize { + tail := cur - s.snapshotSize + zeros := make([]byte, tail) + if !mem.Write(s.snapshotSize, zeros) { + return fmt.Errorf("wasm: memory grew by %d bytes; zero-fill failed: %w", tail, ErrMemoryGrew) + } + return fmt.Errorf("wasm: memory grew by %d bytes (tail zero-filled): %w", tail, ErrMemoryGrew) + } + return nil +} diff --git a/internal/execution/wasm/testdata/echo.wasm b/internal/execution/wasm/testdata/echo.wasm new file mode 100644 index 0000000..4e5072d Binary files /dev/null and b/internal/execution/wasm/testdata/echo.wasm differ diff --git a/internal/execution/wasm/testdata/echo.wat b/internal/execution/wasm/testdata/echo.wat new file mode 100644 index 0000000..e49d21b --- /dev/null +++ b/internal/execution/wasm/testdata/echo.wat @@ -0,0 +1,66 @@ +;; echo.wat — minimal guest ABI fixture for wasm package tests. 
+;; +;; Implements: +;; alloc(size i32) i32 — bump allocator; heap pointer stored at mem[0..3] +;; evaluate(req_ptr i32, req_len i32) i32 +;; — ignores input; always returns fixed response {"ok":true} +;; as a length-prefixed blob: [4-byte LE uint32 len][JSON bytes] +;; +;; The compiled binary (echo.wasm) was generated from this source. +;; {"ok":true} is 11 bytes: 7b 22 6f 6b 22 3a 74 72 75 65 7d +;; +;; Design note: the heap pointer is stored IN linear memory (offset 0, 4 bytes) +;; rather than in a WASM global. This means the snapshot/restore mechanism +;; (which copies linear memory) correctly resets the allocator state between +;; requests. If a global were used, snapshot/restore would not reset it and +;; the heap pointer would keep advancing across requests. +(module + (memory (export "memory") 1) + + ;; mem[0..3]: heap pointer (i32, LE), initialized to 4 + ;; (offset 0..3 reserved for the pointer itself, so allocations start at 4) + (data (i32.const 0) "\04\00\00\00") + + ;; alloc(size i32) i32 + (func (export "alloc") (param $size i32) (result i32) + (local $ptr i32) + ;; ptr = i32.load(mem[0]) + (local.set $ptr (i32.load (i32.const 0))) + ;; mem[0] = ptr + size + (i32.store (i32.const 0) (i32.add (local.get $ptr) (local.get $size))) + (local.get $ptr) + ) + + ;; evaluate(req_ptr i32, req_len i32) i32 + ;; Returns pointer P where: + ;; mem[P .. P+4) = little-endian uint32 length (11) + ;; mem[P+4 .. 
P+15) = {"ok":true} + (func (export "evaluate") (param $req_ptr i32) (param $req_len i32) (result i32) + (local $resp_ptr i32) + ;; resp_ptr = i32.load(mem[0]) + (local.set $resp_ptr (i32.load (i32.const 0))) + ;; mem[0] = resp_ptr + 15 (4 bytes length prefix + 11 bytes JSON) + (i32.store (i32.const 0) (i32.add (local.get $resp_ptr) (i32.const 15))) + + ;; Write little-endian length prefix: 11, 0, 0, 0 + (i32.store8 offset=0 (local.get $resp_ptr) (i32.const 11)) + (i32.store8 offset=1 (local.get $resp_ptr) (i32.const 0)) + (i32.store8 offset=2 (local.get $resp_ptr) (i32.const 0)) + (i32.store8 offset=3 (local.get $resp_ptr) (i32.const 0)) + + ;; Write {"ok":true} + (i32.store8 offset=4 (local.get $resp_ptr) (i32.const 0x7b)) ;; { + (i32.store8 offset=5 (local.get $resp_ptr) (i32.const 0x22)) ;; " + (i32.store8 offset=6 (local.get $resp_ptr) (i32.const 0x6f)) ;; o + (i32.store8 offset=7 (local.get $resp_ptr) (i32.const 0x6b)) ;; k + (i32.store8 offset=8 (local.get $resp_ptr) (i32.const 0x22)) ;; " + (i32.store8 offset=9 (local.get $resp_ptr) (i32.const 0x3a)) ;; : + (i32.store8 offset=10 (local.get $resp_ptr) (i32.const 0x74)) ;; t + (i32.store8 offset=11 (local.get $resp_ptr) (i32.const 0x72)) ;; r + (i32.store8 offset=12 (local.get $resp_ptr) (i32.const 0x75)) ;; u + (i32.store8 offset=13 (local.get $resp_ptr) (i32.const 0x65)) ;; e + (i32.store8 offset=14 (local.get $resp_ptr) (i32.const 0x7d)) ;; } + + (local.get $resp_ptr) + ) +) diff --git a/internal/execution/wasm/testhelpers_test.go b/internal/execution/wasm/testhelpers_test.go new file mode 100644 index 0000000..73161e5 --- /dev/null +++ b/internal/execution/wasm/testhelpers_test.go @@ -0,0 +1,46 @@ +package wasm + +import ( + "context" + "os" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" +) + +// echoWasmBytes reads the pre-compiled echo fixture from testdata/echo.wasm. 
+// The fixture is a minimal WASM module that: +// - exports a bump-allocator alloc(size i32) i32 +// - exports evaluate(req_ptr i32, req_len i32) i32 that always returns the +// fixed JSON {"ok":true} as a 4-byte LE length-prefixed blob +// +// The WAT source is kept alongside the binary at testdata/echo.wat for +// reference. The binary was generated using a pure-Go WASM assembler so that +// the test suite requires no external toolchain. +func echoWasmBytes(t *testing.T) []byte { + t.Helper() + path := echoModulePath(t) + b, err := os.ReadFile(path) + require.NoError(t, err, "read echo.wasm fixture") + return b +} + +// compileEchoModule creates a wazero runtime, wires up WASI host functions, +// and compiles the echo WASM bytes into a CompiledModule. The runtime must be +// closed by the caller. +func compileEchoModule(t *testing.T, ctx context.Context, wasmBytes []byte) (wazero.Runtime, wazero.CompiledModule) { + t.Helper() + + rt := wazero.NewRuntime(ctx) + _, err := wasi_snapshot_preview1.Instantiate(ctx, rt) + require.NoError(t, err, "instantiate WASI") + + compiled, err := rt.CompileModule(ctx, wasmBytes) + require.NoError(t, err, "compile echo module") + + t.Cleanup(func() { _ = compiled.Close(ctx) }) + + return rt, compiled +} diff --git a/scripts/benchmark-wasm-e2e.py b/scripts/benchmark-wasm-e2e.py new file mode 100755 index 0000000..4862fe7 --- /dev/null +++ b/scripts/benchmark-wasm-e2e.py @@ -0,0 +1,299 @@ +#!/usr/bin/env python3 +"""End-to-end benchmark for Shimmy's WASM execution path. + +The benchmark builds the demo WASM evaluator, starts a real Shimmy HTTP server +with FUNCTION_INTERFACE=wasm, and measures POST requests through the public +runtime endpoint. It intentionally covers several payload shapes instead of only +a tiny happy-path request. 
#!/usr/bin/env python3
"""End-to-end benchmark for Shimmy's WASM execution path.

The benchmark builds the demo WASM evaluator, starts a real Shimmy HTTP server
with FUNCTION_INTERFACE=wasm, and measures POST requests through the public
runtime endpoint. It intentionally covers several payload shapes instead of only
a tiny happy-path request.
"""

from __future__ import annotations

import argparse
import json
import os
import socket
import statistics
import subprocess
import sys
import time
import urllib.error
import urllib.request
from dataclasses import dataclass
from pathlib import Path
from typing import Any

ROOT = Path(__file__).resolve().parents[1]
BIN = ROOT / "bin" / "shimmy-wasm-e2e-bench"
DEMO_DIR = ROOT / "examples" / "demo-stateful"
WASM = DEMO_DIR / "eval.wasm"
LOG = ROOT / ".benchmark-wasm-e2e-server.log"


@dataclass(frozen=True)
class PayloadCase:
    """One benchmark scenario: a request body plus what a valid reply looks like."""

    name: str  # unique identifier, selectable with --payload
    command: str  # value sent in the "Command" request header
    body: dict[str, Any]  # JSON request body
    description: str  # human-readable summary, copied into the results
    expected_correct: bool | None = None  # expected "is_correct"; None skips the check


def payload_cases() -> list[PayloadCase]:
    """Return the fixed list of benchmark payloads, in execution order."""
    large_answer = "x" * (32 * 1024)
    medium_preview = "preview-line\n" * 64
    # 29 non-matching candidates followed by the one that matches the response.
    cases = [
        {"answer": f"candidate-{i:02d}", "feedback": f"case {i:02d} feedback"}
        for i in range(29)
    ] + [{"answer": "target-case", "feedback": "matched final case"}]

    return [
        PayloadCase(
            name="eval-short-correct",
            command="eval",
            body={"response": "42", "answer": "42", "params": {}},
            description="small correct eval request",
            expected_correct=True,
        ),
        PayloadCase(
            name="eval-short-incorrect",
            command="eval",
            body={"response": "41", "answer": "42", "params": {}},
            description="small incorrect eval request",
            expected_correct=False,
        ),
        PayloadCase(
            name="eval-large-response",
            command="eval",
            body={"response": large_answer, "answer": large_answer, "params": {}},
            description="large response/answer strings through HTTP + WASM memory",
            expected_correct=True,
        ),
        PayloadCase(
            name="eval-cases-heavy",
            command="eval",
            body={"response": "target-case", "answer": "canonical", "params": {"cases": cases}},
            description="incorrect eval plus host-side case matching that re-enters the evaluator",
            expected_correct=False,
        ),
        PayloadCase(
            name="preview-medium",
            command="preview",
            body={"response": medium_preview, "params": {"mode": "markdown"}},
            description="preview command with medium multiline content",
            expected_correct=None,
        ),
    ]


def run(cmd: list[str], *, cwd: Path = ROOT, env: dict[str, str] | None = None) -> None:
    """Echo *cmd* to stderr and run it, raising CalledProcessError on failure."""
    print("$", " ".join(cmd), file=sys.stderr)
    subprocess.run(cmd, cwd=str(cwd), env=env, check=True)


def choose_port() -> int:
    """Return a TCP port on 127.0.0.1 that was free when probed.

    The probe socket is closed before returning, so another process can still
    grab the port before the server binds it.
    """
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return int(s.getsockname()[1])


def build_artifacts() -> None:
    """Build the Shimmy binary and the demo evaluator's eval.wasm."""
    BIN.parent.mkdir(parents=True, exist_ok=True)
    run(["go", "build", "-trimpath", "-buildvcs=false", "-o", str(BIN), "."])
    env = os.environ.copy()
    env.update({"GOOS": "wasip1", "GOARCH": "wasm"})
    run(["go", "build", "-buildmode=c-shared", "-o", str(WASM), "."], cwd=DEMO_DIR, env=env)


def start_server(port: int) -> subprocess.Popen[str]:
    """Start Shimmy in WASM mode on 127.0.0.1:*port*, logging to ``LOG``.

    The returned process carries the open log file as ``_shimmy_log_file``;
    ``stop_server`` closes it.
    """
    LOG.unlink(missing_ok=True)
    env = os.environ.copy()
    env.update(
        {
            "LOG_LEVEL": "error",
            "FUNCTION_INTERFACE": "wasm",
            "FUNCTION_WASM_MODULE": str(WASM),
            "FUNCTION_MAX_PROCS": "1",
            "FUNCTION_TIMEOUT": "10s",
        }
    )
    log_file = LOG.open("w", encoding="utf-8")
    try:
        proc = subprocess.Popen(
            [str(BIN), "serve", "--host", "127.0.0.1", "--port", str(port)],
            cwd=str(ROOT),
            env=env,
            stdout=log_file,
            stderr=subprocess.STDOUT,
            text=True,
        )
    except BaseException:
        # Nobody else owns the handle yet, so close it before re-raising
        # (e.g. when the binary is missing).
        log_file.close()
        raise
    # Keep the file object alive via the process object for the server lifetime.
    proc._shimmy_log_file = log_file  # type: ignore[attr-defined]
    return proc


def stop_server(proc: subprocess.Popen[str]) -> None:
    """Terminate the server (escalating to kill after 3 s) and close its log file."""
    try:
        if proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=3)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait(timeout=3)
    finally:
        # Close the log even if the process could not be reaped in time.
        log_file = getattr(proc, "_shimmy_log_file", None)
        if log_file is not None:
            log_file.close()


def wait_ready(base_url: str, proc: subprocess.Popen[str]) -> None:
    """Poll ``{base_url}/health`` until it answers 200, for at most 15 seconds.

    Raises:
        RuntimeError: if the server process exits or never becomes ready; the
            message includes the server log.
    """
    # Monotonic clock: a wall-clock adjustment must not stretch or cut the wait.
    deadline = time.monotonic() + 15
    last_error: Exception | None = None
    while time.monotonic() < deadline:
        if proc.poll() is not None:
            raise RuntimeError(f"server exited early with code {proc.returncode}; log: {LOG.read_text(encoding='utf-8', errors='replace')}")
        try:
            with urllib.request.urlopen(f"{base_url}/health", timeout=0.5) as resp:
                if resp.status == 200:
                    return
        except Exception as exc:  # noqa: BLE001 - readiness retry loop
            last_error = exc
        time.sleep(0.1)
    raise RuntimeError(f"server did not become ready: {last_error}; log: {LOG.read_text(encoding='utf-8', errors='replace')}")


def post_json(base_url: str, case: PayloadCase) -> dict[str, Any]:
    """POST the case body to the runtime endpoint and return the decoded JSON reply.

    Raises:
        RuntimeError: on an HTTP error status; the message carries the response body.
    """
    body = json.dumps(case.body, separators=(",", ":")).encode("utf-8")
    req = urllib.request.Request(
        f"{base_url}/",
        data=body,
        method="POST",
        headers={"Content-Type": "application/json", "Command": case.command},
    )
    try:
        with urllib.request.urlopen(req, timeout=20) as resp:
            raw = resp.read()
    except urllib.error.HTTPError as exc:
        detail = exc.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"HTTP {exc.code} for {case.name}: {detail}") from exc
    return json.loads(raw)


def validate_response(case: PayloadCase, response: dict[str, Any]) -> None:
    """Check that *response* is a well-formed, expected reply for *case*.

    Raises:
        AssertionError: if the reply is not a JSON object, reports an error,
            lacks an object ``result``, or fails the command-specific checks.
    """
    # A non-object body (list, string, ...) is reported as a validation failure
    # rather than surfacing as AttributeError further down.
    if not isinstance(response, dict):
        raise AssertionError(f"{case.name} response is not a JSON object: {response!r}")
    if "error" in response:
        raise AssertionError(f"{case.name} returned error: {response['error']}")
    result = response.get("result")
    if not isinstance(result, dict):
        raise AssertionError(f"{case.name} response missing object result: {response}")
    if case.command == "eval":
        if result.get("guest_invocation_count") != 1 or result.get("snapshot_isolation_ok") is not True:
            raise AssertionError(f"{case.name} did not preserve WASM snapshot isolation: {result}")
        if case.expected_correct is not None and result.get("is_correct") is not case.expected_correct:
            raise AssertionError(f"{case.name} expected is_correct={case.expected_correct}: {result}")
    elif case.command == "preview":
        preview = result.get("preview")
        if not isinstance(preview, dict) or "content" not in preview:
            raise AssertionError(f"{case.name} response missing preview content: {response}")


def percentile(values: list[float], pct: float) -> float:
    """Return the *pct* quantile (0.0-1.0) of *values*, linearly interpolated.

    Returns 0.0 for an empty list.
    """
    if not values:
        return 0.0
    ordered = sorted(values)
    index = (len(ordered) - 1) * pct
    lo = int(index)
    hi = min(lo + 1, len(ordered) - 1)
    frac = index - lo
    return ordered[lo] * (1 - frac) + ordered[hi] * frac


def bench_case(base_url: str, case: PayloadCase, iterations: int, warmup: int) -> dict[str, Any]:
    """Run *warmup* untimed and *iterations* timed requests for one payload.

    Every response is validated. Only the round trip of ``post_json`` is timed;
    validation happens outside the timing window.

    Returns:
        A result row with the case metadata, request/response sizes in bytes
        (compact JSON) and min/mean/p50/p95/max latency in milliseconds.

    Raises:
        ValueError: if *iterations* is not positive.
    """
    if iterations <= 0:
        raise ValueError("iterations must be positive")

    for _ in range(warmup):
        validate_response(case, post_json(base_url, case))

    timings_ms: list[float] = []
    response: dict[str, Any] = {}
    for _ in range(iterations):
        start = time.perf_counter_ns()
        response = post_json(base_url, case)
        elapsed_ms = (time.perf_counter_ns() - start) / 1_000_000
        validate_response(case, response)
        timings_ms.append(elapsed_ms)

    request_bytes = len(json.dumps(case.body, separators=(",", ":")).encode("utf-8"))
    # Size the last measured (and validated) response instead of sending an
    # extra, unvalidated request just to count its bytes.
    response_bytes = len(json.dumps(response, separators=(",", ":")).encode("utf-8"))
    return {
        "name": case.name,
        "command": case.command,
        "description": case.description,
        "iterations": iterations,
        "request_bytes": request_bytes,
        "response_bytes": response_bytes,
        "min_ms": min(timings_ms),
        "mean_ms": statistics.fmean(timings_ms),
        "p50_ms": percentile(timings_ms, 0.50),
        "p95_ms": percentile(timings_ms, 0.95),
        "max_ms": max(timings_ms),
    }


def print_table(results: list[dict[str, Any]]) -> None:
    """Print a fixed-width summary table of *results* to stdout."""
    print("\nWASM end-to-end benchmark (HTTP -> Shimmy -> WASM -> HTTP)")
    # The header uses the same column widths as the rows so the two always line up.
    print(
        f"{'payload':<23} {'cmd':<8} {'req B':>6} "
        f"{'mean ms':>9} {'p50 ms':>8} {'p95 ms':>8} {'max ms':>8}"
    )
    print("-" * 78)
    for row in results:
        print(
            f"{row['name']:<23} {row['command']:<8} {row['request_bytes']:>6} "
            f"{row['mean_ms']:>9.2f} {row['p50_ms']:>8.2f} {row['p95_ms']:>8.2f} {row['max_ms']:>8.2f}"
        )


def main(argv: list[str] | None = None) -> int:
    """Parse arguments, run the benchmark against a fresh server, report results.

    Returns 0 on success; failures propagate as exceptions or argparse exits.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--iterations", type=int, default=25, help="measured requests per payload")
    parser.add_argument("--warmup", type=int, default=3, help="warmup requests per payload")
    parser.add_argument("--payload", action="append", help="payload name to run; may be repeated")
    parser.add_argument("--json-output", type=Path, help="optional path for machine-readable results")
    parser.add_argument("--skip-build", action="store_true", help="reuse existing binary and eval.wasm")
    args = parser.parse_args(argv)

    if args.iterations <= 0:
        parser.error("--iterations must be positive")
    if args.warmup < 0:
        parser.error("--warmup must be non-negative")

    cases = payload_cases()
    if args.payload:
        wanted = set(args.payload)
        known = {case.name for case in cases}
        unknown = sorted(wanted - known)
        if unknown:
            parser.error(f"unknown payload(s): {', '.join(unknown)}; known: {', '.join(sorted(known))}")
        cases = [case for case in cases if case.name in wanted]

    if not args.skip_build:
        build_artifacts()

    port = choose_port()
    base_url = f"http://127.0.0.1:{port}"
    proc = start_server(port)
    try:
        wait_ready(base_url, proc)
        results = [bench_case(base_url, case, args.iterations, args.warmup) for case in cases]
    finally:
        # Always stop the server, including when readiness or a case fails.
        stop_server(proc)

    print_table(results)
    output = {
        "base_url": base_url,
        "iterations": args.iterations,
        "warmup": args.warmup,
        "wasm": str(WASM),
        "results": results,
    }
    if args.json_output:
        args.json_output.parent.mkdir(parents=True, exist_ok=True)
        args.json_output.write_text(json.dumps(output, indent=2, sort_keys=True) + "\n", encoding="utf-8")
        print(f"\nWrote JSON results to {args.json_output}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
#!/usr/bin/env bash
# Demo: build Shimmy and a C++ evaluator compiled to WebAssembly, start the
# server in WASM mode, send two eval requests, and verify that guest global
# state is reset between them.
#
# Environment:
#   PORT  optional port to listen on; a free one is picked when unset.
set -euo pipefail

ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
HOST="127.0.0.1"

# Check required tools before anything uses them: port selection below already
# runs python3, so checking afterwards would abort under `set -e` with a raw
# "command not found" instead of this message.
for cmd in go curl python3 zig file; do
  if ! command -v "${cmd}" >/dev/null 2>&1; then
    echo "error: ${cmd} is required" >&2
    exit 1
  fi
done

PORT="${PORT:-}"
if [[ -z "${PORT}" ]]; then
  # Ask the OS for a currently free port.
  PORT="$(python3 - <<'PY'
import socket
s = socket.socket()
s.bind(('127.0.0.1', 0))
print(s.getsockname()[1])
s.close()
PY
)"
fi
BASE_URL="http://${HOST}:${PORT}"
BIN="${ROOT}/bin/shimmy-demo-cpp"
DEMO_DIR="${ROOT}/examples/demo-cpp-compare"
WASM="${DEMO_DIR}/eval.wasm"
LOG="${ROOT}/.demo-cpp-wasm-server.log"

if [[ ! -f "${DEMO_DIR}/evaluator.cpp" ]]; then
  echo "error: missing ${DEMO_DIR}/evaluator.cpp" >&2
  exit 1
fi

echo "==> Building shimmy demo binary"
(cd "${ROOT}" && go build -trimpath -buildvcs=false -o "${BIN}" .)

echo "==> Building C++ evaluator: examples/demo-cpp-compare -> eval.wasm"
zig c++ \
  -target wasm32-freestanding \
  -Oz \
  -nostdlib \
  -fno-exceptions \
  -fno-rtti \
  -Wl,--no-entry \
  -Wl,--export=alloc \
  -Wl,--export=evaluate \
  -Wl,--export-memory \
  -Wl,--initial-memory=2097152 \
  -o "${WASM}" \
  "${DEMO_DIR}/evaluator.cpp"

file "${WASM}"

rm -f "${LOG}"
server_pid=""

# cleanup stops the server on exit: SIGTERM first, SIGKILL after ~2 seconds.
cleanup() {
  if [[ -n "${server_pid}" ]] && kill -0 "${server_pid}" 2>/dev/null; then
    kill "${server_pid}" 2>/dev/null || true
    for _ in {1..20}; do
      kill -0 "${server_pid}" 2>/dev/null || return 0
      sleep 0.1
    done
    kill -KILL "${server_pid}" 2>/dev/null || true
    wait "${server_pid}" 2>/dev/null || true
  fi
}
trap cleanup EXIT

echo "==> Starting shimmy on ${BASE_URL}"
(
  cd "${ROOT}"
  exec env \
    LOG_LEVEL=error \
    FUNCTION_INTERFACE=wasm \
    FUNCTION_WASM_MODULE="${WASM}" \
    FUNCTION_MAX_PROCS=1 \
    FUNCTION_TIMEOUT=5s \
    "${BIN}" serve --host "${HOST}" --port "${PORT}"
) >"${LOG}" 2>&1 &
server_pid="$!"

# Wait up to ~12 seconds for /health, bailing out early if the server died.
for _ in {1..60}; do
  if ! kill -0 "${server_pid}" 2>/dev/null; then
    echo "server exited early; log follows:" >&2
    cat "${LOG}" >&2 || true
    exit 1
  fi
  if curl -fsS "${BASE_URL}/health" >/dev/null 2>&1; then
    break
  fi
  sleep 0.2
done

if ! curl -fsS "${BASE_URL}/health" >/dev/null 2>&1; then
  echo "server did not become ready; log follows:" >&2
  cat "${LOG}" >&2 || true
  exit 1
fi

# request_eval RESPONSE ANSWER: POST one eval request and print the raw reply.
# The arguments are spliced into the JSON body unescaped, so pass only values
# that need no JSON escaping.
request_eval() {
  local response="$1"
  local answer="$2"
  curl -fsS \
    -X POST "${BASE_URL}/" \
    -H 'Content-Type: application/json' \
    -H 'Command: eval' \
    --data "{\"response\":\"${response}\",\"answer\":\"${answer}\",\"params\":{\"correct_response_feedback\":\"Correct!\",\"incorrect_response_feedback\":\"Try again.\"}}"
}

echo "==> Request #1: correct answer"
resp1="$(request_eval 42 42)"
echo "${resp1}" | python3 -m json.tool

echo "==> Request #2: wrong answer; C++ guest global state should still reset"
resp2="$(request_eval 41 42)"
echo "${resp2}" | python3 -m json.tool

# Verify both replies; the responses are handed over through the environment.
RESP1="${resp1}" RESP2="${resp2}" python3 - <<'PY'
import json, os, sys
r1 = json.loads(os.environ['RESP1'])['result']
r2 = json.loads(os.environ['RESP2'])['result']
checks = [
    (r1.get('is_correct') is True, 'request #1 should be correct'),
    (r1.get('feedback') == 'Correct!', 'request #1 should use correct feedback'),
    (r2.get('is_correct') is False, 'request #2 should be incorrect'),
    (r2.get('feedback') == 'Try again.', 'request #2 should use incorrect feedback'),
    (r1.get('guest_invocation_count') == 1, 'request #1 should see guest_invocation_count == 1'),
    (r2.get('guest_invocation_count') == 1, 'request #2 should still see guest_invocation_count == 1'),
    (r1.get('snapshot_isolation_ok') is True, 'request #1 snapshot flag should be true'),
    (r2.get('snapshot_isolation_ok') is True, 'request #2 snapshot flag should be true'),
]
failed = [msg for ok, msg in checks if not ok]
if failed:
    print('DEMO FAILED:', *failed, sep='\n- ', file=sys.stderr)
    sys.exit(1)
print('\n✅ C++ WASM demo passed: Shimmy ran a C++ evaluator compiled to WebAssembly and restored guest state after each request.')
PY

printf '\nServer log: %s\n' "${LOG}"
b/scripts/demo-lambda-feedback-file-worker.sh new file mode 100755 index 0000000..9922de4 --- /dev/null +++ b/scripts/demo-lambda-feedback-file-worker.sh @@ -0,0 +1,85 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +ADAPTER_DIR="$ROOT_DIR/examples/lambda-feedback-adapter" +FIXTURE_ROOT="$ROOT_DIR/examples/lambda-feedback-fixtures/boilerplate-python" +WORKER="$ADAPTER_DIR/lf_file_worker.py" +PYTHON_BIN="${PYTHON_BIN:-$(command -v python3)}" + +usage() { + cat <<'USAGE' +Usage: scripts/demo-lambda-feedback-file-worker.sh [direct|http|all] + +Runs a Lambda Feedback-compatible Python evaluator through Shimmy's existing +file worker interface. No production runtime server is hidden in the toolkit shim: +the worker reads Shimmy file-IO JSON, calls module:function, and writes Shimmy's +schema-compatible response envelope. +USAGE +} + +run_direct() { + local tmp req res + tmp="$(mktemp -d)" + trap 'rm -rf "$tmp"' RETURN + req="$tmp/request.json" + res="$tmp/response.json" + cat >"$req" </dev/null + cat "$res" +} + +run_http() { + local port server_pid response preview + if [[ -n "${PORT:-}" ]]; then + port="$PORT" + else + port="$($PYTHON_BIN -c 'import socket; s=socket.socket(); s.bind(("127.0.0.1", 0)); print(s.getsockname()[1]); s.close()')" + fi + ( + cd "$ROOT_DIR" + exec go run . \ + --log-level error \ + --interface file \ + --command "$PYTHON_BIN" \ + --arg "$WORKER" \ + --env "FUNCTION_LF_ROOT=$FIXTURE_ROOT" \ + --env "FUNCTION_LF_ENTRYPOINT=evaluation_function.main:evaluation_function" \ + --env "FUNCTION_LF_PREVIEW_ENTRYPOINT=evaluation_function.main:preview_function" \ + serve \ + --host 127.0.0.1 \ + --port "$port" + ) & + server_pid=$! 
+ trap 'kill "$server_pid" 2>/dev/null || true' RETURN + + for _ in $(seq 1 80); do + if curl -fsS "http://127.0.0.1:$port/health" >/dev/null 2>&1; then + break + fi + sleep 0.25 + done + + response="$(curl -fsS \ + -H 'content-type: application/json' \ + -d '{"response":"42","answer":"42","params":{}}' \ + "http://127.0.0.1:$port/")" + preview="$(curl -fsS \ + -H 'content-type: application/json' \ + -H 'command: preview' \ + -d '{"response":"draft","params":{"expected":"42"}}' \ + "http://127.0.0.1:$port/")" + + printf '%s\n%s\n' "$response" "$preview" +} + +mode="${1:-all}" +case "$mode" in + direct) run_direct ;; + http) run_http ;; + all) run_direct; run_http ;; + -h|--help|help) usage ;; + *) usage >&2; exit 2 ;; +esac diff --git a/scripts/demo-lambda-feedback-fixtures.sh b/scripts/demo-lambda-feedback-fixtures.sh new file mode 100755 index 0000000..d5e5011 --- /dev/null +++ b/scripts/demo-lambda-feedback-fixtures.sh @@ -0,0 +1,68 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +RUNNER="${ROOT}/examples/lambda-feedback-adapter/run_lf_eval.py" +FIXTURES="${ROOT}/examples/lambda-feedback-fixtures" +PYTHON="${PYTHON:-python3}" + +usage() { + cat <<'EOF' +Usage: scripts/demo-lambda-feedback-fixtures.sh [list|local|all] + +Runs small Lambda Feedback compatibility fixtures through the backend-independent +local adapter. These demos validate evaluator package shapes before wiring the +same adapter into Pyodide or a Python reactor runtime. 
+EOF +} + +case "${1:-all}" in + list) + find "${FIXTURES}" -mindepth 1 -maxdepth 1 -type d -print | sort | sed "s#${FIXTURES}/##" + ;; + local|all) + echo "==> Boilerplate eval fixture" + "${PYTHON}" "${RUNNER}" \ + --root "${FIXTURES}/boilerplate-python" \ + --entrypoint evaluation_function.main:evaluation_function \ + --method eval \ + --response " 42 " \ + --answer "42" \ + --params-json '{}' + + echo "==> Boilerplate preview fixture" + "${PYTHON}" "${RUNNER}" \ + --root "${FIXTURES}/boilerplate-python" \ + --entrypoint evaluation_function.main:preview_function \ + --method preview \ + --response "draft" \ + --answer "answer" \ + --params-json '{"expected":"answer"}' + + echo "==> Relative import + two-argument preview fixture" + "${PYTHON}" "${RUNNER}" \ + --root "${FIXTURES}/relative-preview" \ + --entrypoint evaluation_function.evaluation:preview_function \ + --method preview \ + --response " Foo " \ + --params-json '{"mode":"set"}' + + echo "==> Relative import set-compare eval fixture" + "${PYTHON}" "${RUNNER}" \ + --root "${FIXTURES}/relative-preview" \ + --entrypoint evaluation_function.evaluation:evaluation_function \ + --method eval \ + --response "a,b" \ + --answer "b, a" \ + --params-json '{"mode":"set"}' + + echo "✅ Lambda Feedback local compatibility fixtures passed" + ;; + -h|--help|help) + usage + ;; + *) + usage >&2 + exit 2 + ;; +esac diff --git a/scripts/demo-wasm.sh b/scripts/demo-wasm.sh new file mode 100755 index 0000000..5ef0a99 --- /dev/null +++ b/scripts/demo-wasm.sh @@ -0,0 +1,119 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." 
#!/usr/bin/env bash
# Demo: build Shimmy and the stateful Go demo evaluator as WebAssembly, start
# the server in WASM mode, send two eval requests, and verify that guest global
# state is reset between them.
#
# Environment:
#   PORT  optional port to listen on; a free one is picked when unset.
set -euo pipefail

ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
HOST="127.0.0.1"

# Check required tools before anything uses them: port selection below already
# runs python3, so checking afterwards would abort under `set -e` with a raw
# "command not found" instead of this message.
for cmd in go curl python3; do
  if ! command -v "${cmd}" >/dev/null 2>&1; then
    echo "error: ${cmd} is required" >&2
    exit 1
  fi
done

PORT="${PORT:-}"
if [[ -z "${PORT}" ]]; then
  # Ask the OS for a currently free port.
  PORT="$(python3 - <<'PY'
import socket
s = socket.socket()
s.bind(('127.0.0.1', 0))
print(s.getsockname()[1])
s.close()
PY
)"
fi
BASE_URL="http://${HOST}:${PORT}"
BIN="${ROOT}/bin/shimmy-demo"
DEMO_DIR="${ROOT}/examples/demo-stateful"
WASM="${DEMO_DIR}/eval.wasm"
LOG="${ROOT}/.demo-wasm-server.log"

echo "==> Building shimmy demo binary"
(cd "${ROOT}" && go build -trimpath -buildvcs=false -o "${BIN}" .)

echo "==> Building demo evaluator: examples/demo-stateful -> eval.wasm"
(cd "${DEMO_DIR}" && GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o "${WASM}" .)

rm -f "${LOG}"
server_pid=""

# cleanup stops the server on exit: SIGTERM first, SIGKILL after ~2 seconds.
cleanup() {
  if [[ -n "${server_pid}" ]] && kill -0 "${server_pid}" 2>/dev/null; then
    kill "${server_pid}" 2>/dev/null || true
    for _ in {1..20}; do
      kill -0 "${server_pid}" 2>/dev/null || return 0
      sleep 0.1
    done
    kill -KILL "${server_pid}" 2>/dev/null || true
    wait "${server_pid}" 2>/dev/null || true
  fi
}
trap cleanup EXIT

echo "==> Starting shimmy on ${BASE_URL}"
(
  cd "${ROOT}"
  exec env \
    LOG_LEVEL=error \
    FUNCTION_INTERFACE=wasm \
    FUNCTION_WASM_MODULE="${WASM}" \
    FUNCTION_MAX_PROCS=1 \
    FUNCTION_TIMEOUT=5s \
    "${BIN}" serve --host "${HOST}" --port "${PORT}"
) >"${LOG}" 2>&1 &
server_pid="$!"

# Wait up to ~12 seconds for /health, bailing out early if the server died.
for _ in {1..60}; do
  if ! kill -0 "${server_pid}" 2>/dev/null; then
    echo "server exited early; log follows:" >&2
    cat "${LOG}" >&2 || true
    exit 1
  fi
  if curl -fsS "${BASE_URL}/health" >/dev/null 2>&1; then
    break
  fi
  sleep 0.2
done

if ! curl -fsS "${BASE_URL}/health" >/dev/null 2>&1; then
  echo "server did not become ready; log follows:" >&2
  cat "${LOG}" >&2 || true
  exit 1
fi

# request_eval RESPONSE ANSWER: POST one eval request and print the raw reply.
# The arguments are spliced into the JSON body unescaped, so pass only values
# that need no JSON escaping.
request_eval() {
  local response="$1"
  local answer="$2"
  curl -fsS \
    -X POST "${BASE_URL}/" \
    -H 'Content-Type: application/json' \
    -H 'Command: eval' \
    --data "{\"response\":\"${response}\",\"answer\":\"${answer}\",\"params\":{}}"
}

echo "==> Request #1: correct answer"
resp1="$(request_eval 42 42)"
echo "${resp1}" | python3 -m json.tool

echo "==> Request #2: wrong answer; guest global state should still be reset"
resp2="$(request_eval 41 42)"
echo "${resp2}" | python3 -m json.tool

# Verify both replies; the responses are handed over through the environment.
RESP1="${resp1}" RESP2="${resp2}" python3 - <<'PY'
import json, os, sys
r1 = json.loads(os.environ['RESP1'])['result']
r2 = json.loads(os.environ['RESP2'])['result']
checks = [
    (r1.get('is_correct') is True, 'request #1 should be correct'),
    (r2.get('is_correct') is False, 'request #2 should be incorrect'),
    (r1.get('guest_invocation_count') == 1, 'request #1 should see guest_invocation_count == 1'),
    (r2.get('guest_invocation_count') == 1, 'request #2 should still see guest_invocation_count == 1'),
    (r1.get('snapshot_isolation_ok') is True, 'request #1 snapshot flag should be true'),
    (r2.get('snapshot_isolation_ok') is True, 'request #2 snapshot flag should be true'),
]
failed = [msg for ok, msg in checks if not ok]
if failed:
    print('DEMO FAILED:', *failed, sep='\n- ', file=sys.stderr)
    sys.exit(1)
print('\n✅ Demo passed: two HTTP requests hit the same warm WASM evaluator, but guest global state was reset after each request.')
PY

printf '\nServer log: %s\n' "${LOG}"