Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,94 @@ SampleResult CotabbyInferenceEngine::sampleNext(int32_t sequence_id) {
return result;
}

// ---------------------------------------------------------------------------
// Constrained generation primitives
//
// These let a Swift caller run the select-then-commit loop manually instead of
// using sampleNext: read the logits row, classify/choose a token under its own
// constraints, then commit the choice with acceptToken. The vocab queries are
// pure reads on the loaded vocab; getNextTokenLogits copies the live logits row
// (still resident from the last decode); acceptToken mirrors the KV-advancing
// decode used elsewhere so the shared context produces fresh logits afterward.
// ---------------------------------------------------------------------------

// Number of tokens in the loaded vocabulary (the length of one logits row).
// Reports 0 when no engine state or vocab is present, so a caller never sizes
// a buffer from a value belonging to a model that is no longer loaded.
int CotabbyInferenceEngine::getVocabSize() const {
    const bool vocab_ready = impl_ && impl_->vocab;
    return vocab_ready ? llama_vocab_n_tokens(impl_->vocab) : 0;
}

// True when `token` is classified as end-of-generation by the loaded vocab.
// With no engine state or vocab the answer is always false.
bool CotabbyInferenceEngine::isEndOfGenerationToken(int32_t token) const {
    if (impl_ && impl_->vocab) {
        return llama_vocab_is_eog(impl_->vocab, token);
    }
    return false;
}

// End-of-sequence token id of the loaded vocab.
// Falls back to -1 (never a valid token id) when there is no engine state or vocab.
int32_t CotabbyInferenceEngine::endOfSequenceToken() const {
    if (impl_ && impl_->vocab) {
        return llama_vocab_eos(impl_->vocab);
    }
    return -1;
}

// Copies the shared context's most recent logits row into `out`.
//
// Parameters:
//   sequence_id  - must name a known sequence; it is only checked for existence.
//   out          - destination buffer; must be non-null.
//   out_capacity - number of floats `out` can hold; must be >= the vocab size.
//
// Returns the number of floats written (the vocab size), or 0 when nothing was
// written: engine/vocab/context missing, null buffer, unknown sequence, buffer
// too small, or no live logits.
//
// NOTE(review): the row is read with index -1 from the *shared* context, so it
// is the last output of whichever decode ran most recently, not a row looked up
// by `sequence_id`. If another sequence is decoded between this sequence's
// commit and this call, the row may belong to that other sequence - confirm
// callers serialize their select/commit loop accordingly.
int CotabbyInferenceEngine::getNextTokenLogits(int32_t sequence_id,
                                               float* out, int out_capacity) const {
    if (!impl_ || !impl_->vocab || !impl_->shared_ctx || !out) return 0;

    if (!impl_->findSequence(sequence_id)) return 0;

    // Refuse to write past the caller's buffer; the full vocab-size row is all-or-nothing.
    const int32_t n = llama_vocab_n_tokens(impl_->vocab);
    if (n <= 0 || out_capacity < n) return 0;

    // Hold the decode mutex across the pointer fetch AND the copy: acceptToken and
    // the decoder thread run llama_decode under this mutex, and an in-flight decode
    // may be rewriting the context's output buffer while we read it.
    std::lock_guard<std::mutex> lock(impl_->decode_mutex);

    // -1 selects the last logits row produced by the most recent decode.
    // Null means there are no live logits (e.g. nothing decoded yet).
    const float* logits = llama_get_logits_ith(impl_->shared_ctx, -1);
    if (!logits) return 0;

    std::memcpy(out, logits, static_cast<size_t>(n) * sizeof(float));
    return n;
}

// Commits an externally chosen `token` as the next token of `sequence_id`.
//
// Runs a single-token decode at the sequence's current KV position with logits
// requested, so a fresh row is available to getNextTokenLogits afterwards. Only
// once that decode has succeeded is the token fed to the sequence's sampler and
// the KV position advanced, so a failed commit leaves sampler state and KV
// position exactly as they were.
//
// Returns:
//   not_loaded - no engine state, model, or shared context.
//   cancelled  - the sequence's cancel flag is set (checked before and after
//                waiting for the decode mutex).
//   error      - unknown sequence, token id outside the vocabulary, batch
//                allocation failure, or a non-zero llama_decode status.
//   ok         - token decoded; sampler updated and KV position incremented.
EngineStatus CotabbyInferenceEngine::acceptToken(int32_t sequence_id, int32_t token) {
    if (!impl_ || !impl_->model || !impl_->shared_ctx) return EngineStatus::not_loaded;

    SequenceState* seq = impl_->findSequence(sequence_id);
    if (!seq) return EngineStatus::error;

    if (seq->cancelled.load(std::memory_order_acquire)) {
        return EngineStatus::cancelled;
    }

    // The token comes from outside the engine; reject ids that cannot index the
    // vocabulary before they reach the decoder or the sampler.
    if (token < 0 || (impl_->vocab && token >= llama_vocab_n_tokens(impl_->vocab))) {
        return EngineStatus::error;
    }

    // Serialize with the decoder thread so this manual decode never races an in-flight batch.
    // Sampler and KV-position updates below also happen under this lock.
    std::lock_guard<std::mutex> lock(impl_->decode_mutex);

    // Cancellation may have been requested while we were waiting for the lock.
    if (seq->cancelled.load(std::memory_order_acquire)) {
        return EngineStatus::cancelled;
    }

    // Single-token feedback decode: pos = current KV position, logits requested.
    llama_batch batch = llama_batch_init(1, 0, 1);
    const bool batch_ok = batch.token && batch.pos && batch.n_seq_id &&
                          batch.seq_id && batch.seq_id[0] && batch.logits;
    if (!batch_ok) {
        // Never decode a half-initialized batch; report the failure instead.
        llama_batch_free(batch);
        return EngineStatus::error;
    }

    batch.n_tokens = 1;
    batch.token[0] = token;
    batch.pos[0] = static_cast<llama_pos>(seq->kv_position_count);
    batch.n_seq_id[0] = 1;
    batch.seq_id[0][0] = seq->seq_id;
    batch.logits[0] = 1;

    const int status = llama_decode(impl_->shared_ctx, batch);
    llama_batch_free(batch);

    if (status != 0) {
        return EngineStatus::error;
    }

    // Feed the chosen token to the sampler only after the decode succeeded, so
    // repetition/penalty state never runs ahead of the KV cache.
    if (seq->sampler) {
        llama_sampler_accept(seq->sampler, token);
    }

    seq->kv_position_count++;
    return EngineStatus::ok;
}

// ---------------------------------------------------------------------------
// KV cache management
// ---------------------------------------------------------------------------
Expand Down
30 changes: 30 additions & 0 deletions Sources/CotabbyInferenceEngine/include/CotabbyInferenceEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,36 @@ class CotabbyInferenceEngine {
// Sampling
// Selects the next token for `sequence_id` and commits it in one step. The
// constrained-generation primitives declared after it are the manual
// alternative, where the caller selects and the engine only commits.
SampleResult sampleNext(int32_t sequence_id);

// Constrained generation primitives
//
// These decouple token *selection* from the engine: a Swift caller can read the
// raw next-token logits, classify candidate tokens, pick one under its own
// constraints (grammar, vocabulary subset, etc.), and then commit it via
// `acceptToken` to advance the sequence. This is the manual alternative to
// `sampleNext`, which both selects and commits in one step.

// Total vocabulary size, i.e. the number of float logits a row from
// `getNextTokenLogits` contains. 0 when no model is loaded, so callers
// should treat 0 as "do not allocate a logits buffer yet".
int getVocabSize() const;
// Whether `token` is an end-of-generation marker (EOS or any other model
// stop token). Callers use this to terminate manual generation loops.
// Always false when no model is loaded.
bool isEndOfGenerationToken(int32_t token) const;
// The model's end-of-sequence token id, or -1 when no model is loaded.
// NOTE(review): the id comes straight from llama_vocab_eos; a vocab without an
// EOS token may also yield -1, so -1 may not uniquely mean "no model" - confirm.
int32_t endOfSequenceToken() const;
// Copies the next-token logits row (the distribution produced by the most
// recent decode) into `out`. `out_capacity` must be at least
// `getVocabSize()`; exactly that many floats are written. Returns the
// number of floats written, or 0 on any error (no model loaded, null buffer,
// unknown sequence, too-small capacity, or no live logits).
// NOTE(review): `sequence_id` is only validated for existence; the row itself
// is the last row of the shared context's most recent decode, so it reflects
// `sequence_id` only if that sequence was the last one decoded - confirm
// callers do not interleave other sequences between commit and read.
int getNextTokenLogits(int32_t sequence_id, float* out, int out_capacity) const;
// Commits `token` as the chosen next token for `sequence_id`: feeds it to the
// sequence's sampler (so penalty/repetition state stays consistent) and
// feedback-decodes it to advance the KV cache by one position, leaving fresh
// logits at the new position for the next `getNextTokenLogits` call. This is
// the commit half of the manual select-then-commit loop.
// Returns `not_loaded` when no model/context is loaded, `cancelled` when the
// sequence has been cancelled, `error` for an unknown sequence or a failed
// decode, and `ok` otherwise. The KV position advances only on `ok`.
EngineStatus acceptToken(int32_t sequence_id, int32_t token);

// KV cache management
// NOTE(review): implementation not visible in this diff; presumably trims the
// KV cache of `sequence_id` down to its first `keep_positions` positions and
// returns whether that succeeded - confirm against the definition.
bool trimKV(int32_t sequence_id, int keep_positions);
// NOTE(review): implementation not visible in this diff; presumably reports the
// per-sequence KV position counter that `acceptToken` increments on each
// successful commit - confirm, including the value for an unknown sequence.
int getKVPositionCount(int32_t sequence_id) const;
Expand Down
Loading