diff --git a/Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp b/Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp
index 06930b9..a1971eb 100644
--- a/Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp
+++ b/Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp
@@ -966,6 +966,121 @@ SampleResult CotabbyInferenceEngine::sampleNext(int32_t sequence_id) {
     return result;
 }
 
+// ---------------------------------------------------------------------------
+// Constrained generation primitives
+//
+// These let a Swift caller run the select-then-commit loop manually instead of
+// using sampleNext: read the logits row, classify/choose a token under its own
+// constraints, then commit the choice with acceptToken. The vocab queries are
+// pure reads on the loaded vocab; getNextTokenLogits copies the live logits row
+// (still resident from the last decode); acceptToken mirrors the KV-advancing
+// decode used elsewhere so the shared context produces fresh logits afterward.
+// ---------------------------------------------------------------------------
+
+// Number of tokens in the loaded vocabulary; 0 when no vocab is loaded.
+int CotabbyInferenceEngine::getVocabSize() const {
+    // No vocab means no model loaded; report 0 so callers don't size buffers off a stale value.
+    if (!impl_ || !impl_->vocab) return 0;
+    return llama_vocab_n_tokens(impl_->vocab);
+}
+
+// Whether the loaded vocab flags `token` as end-of-generation; false when no vocab is loaded.
+bool CotabbyInferenceEngine::isEndOfGenerationToken(int32_t token) const {
+    if (!impl_ || !impl_->vocab) return false;
+    return llama_vocab_is_eog(impl_->vocab, token);
+}
+
+// End-of-sequence token id as reported by the loaded vocab.
+int32_t CotabbyInferenceEngine::endOfSequenceToken() const {
+    // -1 is not a valid token id, so it doubles as the "no model" sentinel.
+    if (!impl_ || !impl_->vocab) return -1;
+    return llama_vocab_eos(impl_->vocab);
+}
+
+// Copies the logits row of the most recent decode into `out` and returns the number of floats
+// written (the vocab size), or 0 when nothing was written.
+int CotabbyInferenceEngine::getNextTokenLogits(int32_t sequence_id,
+                                               float* out, int out_capacity) const {
+    if (!impl_ || !impl_->vocab || !impl_->shared_ctx || !out) return 0;
+
+    const SequenceState* seq = impl_->findSequence(sequence_id);
+    if (!seq) return 0;
+
+    // Refuse to write past the caller's buffer; the full vocab-size row is all-or-nothing.
+    const int32_t n = llama_vocab_n_tokens(impl_->vocab);
+    if (n <= 0 || out_capacity < n) return 0;
+
+    // Take the same lock acceptToken decodes under so the copy never reads a logits buffer that
+    // an in-flight decode is rewriting. decltype keeps this independent of the mutex's type.
+    std::lock_guard<decltype(impl_->decode_mutex)> lock(impl_->decode_mutex);
+
+    // -1 is the most recent decode's logits row, which is what the caller wants to inspect
+    // before choosing the next token. Null means no live logits (e.g. nothing decoded yet).
+    // NOTE(review): the row is looked up on the shared context, not by sequence_id, so it is
+    // only this sequence's distribution if this sequence decoded last -- confirm with callers.
+    const float* logits = llama_get_logits_ith(impl_->shared_ctx, -1);
+    if (!logits) return 0;
+
+    std::memcpy(out, logits, static_cast<std::size_t>(n) * sizeof(float));
+    return n;
+}
+
+// Commits an externally chosen token: decodes it at the sequence's current KV position and, once
+// the decode has succeeded, records it in the sampler and advances the position counter.
+EngineStatus CotabbyInferenceEngine::acceptToken(int32_t sequence_id, int32_t token) {
+    if (!impl_ || !impl_->model || !impl_->shared_ctx) return EngineStatus::not_loaded;
+
+    SequenceState* seq = impl_->findSequence(sequence_id);
+    if (!seq) return EngineStatus::error;
+
+    if (seq->cancelled.load(std::memory_order_acquire)) {
+        return EngineStatus::cancelled;
+    }
+
+    // The token was picked outside the engine, so range-check it before it reaches the decoder
+    // or the sampler.
+    if (impl_->vocab) {
+        const int32_t n_vocab = llama_vocab_n_tokens(impl_->vocab);
+        if (token < 0 || token >= n_vocab) return EngineStatus::error;
+    }
+
+    // Serialize with the decoder thread so this manual decode never races an in-flight batch.
+    std::lock_guard<decltype(impl_->decode_mutex)> lock(impl_->decode_mutex);
+
+    // Single-token feedback decode, mirroring the KV-advancing step in sampleNext/processBatch:
+    // pos = current KV position, request logits so a fresh row is ready for getNextTokenLogits.
+    llama_batch batch = llama_batch_init(1, 0, 1);
+    // Without a sequence-id slot the token cannot be attributed to this sequence; fail instead of
+    // decoding with an unset id.
+    if (!batch.seq_id || !batch.seq_id[0]) {
+        llama_batch_free(batch);
+        return EngineStatus::error;
+    }
+    batch.n_tokens = 1;
+    batch.token[0] = token;
+    batch.pos[0] = static_cast<llama_pos>(seq->kv_position_count);
+    batch.n_seq_id[0] = 1;
+    batch.seq_id[0][0] = seq->seq_id;
+    batch.logits[0] = 1;
+
+    const int status = llama_decode(impl_->shared_ctx, batch);
+    llama_batch_free(batch);
+
+    if (status != 0) {
+        return EngineStatus::error;
+    }
+
+    // Feed the chosen token to the sampler only after the decode succeeded, so repetition/penalty
+    // state matches what sampleNext would have produced and a failed (possibly retried) commit
+    // never leaves the token recorded in the sampler without a matching KV advance.
+    if (seq->sampler) {
+        llama_sampler_accept(seq->sampler, token);
+    }
+
+    seq->kv_position_count++;
+    return EngineStatus::ok;
+}
+
 // ---------------------------------------------------------------------------
 // KV cache management
 // ---------------------------------------------------------------------------
diff --git a/Sources/CotabbyInferenceEngine/include/CotabbyInferenceEngine.h b/Sources/CotabbyInferenceEngine/include/CotabbyInferenceEngine.h
index 151d689..78ddc29 100644
--- a/Sources/CotabbyInferenceEngine/include/CotabbyInferenceEngine.h
+++ b/Sources/CotabbyInferenceEngine/include/CotabbyInferenceEngine.h
@@ -105,6 +105,40 @@ class CotabbyInferenceEngine {
     // Sampling
     SampleResult sampleNext(int32_t sequence_id);
 
+    // Constrained generation primitives
+    //
+    // These decouple token *selection* from the engine: a Swift caller can read the
+    // raw next-token logits, classify candidate tokens, pick one under its own
+    // constraints (grammar, vocabulary subset, etc.), and then commit it via
+    // `acceptToken` to advance the sequence. This is the manual alternative to
+    // `sampleNext`, which both selects and commits in one step.
+
+    // Total vocabulary size, i.e. the number of float logits a row from
+    // `getNextTokenLogits` contains. 0 when no model is loaded.
+    int getVocabSize() const;
+    // Whether `token` is an end-of-generation marker (EOS or any other model
+    // stop token). Callers use this to terminate manual generation loops.
+    bool isEndOfGenerationToken(int32_t token) const;
+    // The model's end-of-sequence token id, or -1 when no model is loaded.
+    int32_t endOfSequenceToken() const;
+    // Copies the next-token logits row for `sequence_id` (the distribution
+    // produced by the most recent decode) into `out`. `out_capacity` must be at
+    // least `getVocabSize()`; exactly that many floats are written. Returns the
+    // number of floats written, or 0 on any error (null buffer, unknown
+    // sequence, too-small capacity, or no live logits). The row is read from
+    // the shared context under the decode lock, so call this right after the
+    // decode that produced the logits for `sequence_id`.
+    int getNextTokenLogits(int32_t sequence_id, float* out, int out_capacity) const;
+    // Commits `token` as the chosen next token for `sequence_id`: feedback-decodes
+    // it to advance the KV cache by one position, leaving fresh logits at the new
+    // position for the next `getNextTokenLogits` call, and then feeds it to the
+    // sequence's sampler (so penalty/repetition state stays consistent). This is
+    // the commit half of the manual select-then-commit loop. Returns `not_loaded`
+    // without a model, `cancelled` for a cancelled sequence, and `error` for an
+    // unknown sequence, an out-of-vocabulary token, or a failed decode; on any
+    // non-`ok` status the sampler state and the position count are left untouched.
+    EngineStatus acceptToken(int32_t sequence_id, int32_t token);
+
     // KV cache management
     bool trimKV(int32_t sequence_id, int keep_positions);
     int getKVPositionCount(int32_t sequence_id) const;