diff --git a/src/llm.hpp b/src/llm.hpp index ab673b208..cec8b1dcc 100644 --- a/src/llm.hpp +++ b/src/llm.hpp @@ -1403,7 +1403,8 @@ namespace LLM { out_layers, return_all_hidden_states); }; - return take_or_empty(GGMLRunner::compute(get_graph, n_threads, true)); + return restore_trailing_singleton_dims(GGMLRunner::compute(get_graph, n_threads, true), + input_ids.dim() + 1); } int64_t get_num_image_tokens(int64_t t, int64_t h, int64_t w) {