From b112437592e3c82eb1903536166b960544f7879b Mon Sep 17 00:00:00 2001 From: JanSkn Date: Sat, 6 Dec 2025 11:54:11 +0100 Subject: [PATCH 1/5] improve index building --- .../benchmark_numpy.py | 0 .../scripts => .deprecated}/dataloader.py | 0 .dockerignore | 1 + .gitignore | 1 + justfile | 10 +- src/backend/bindings/cpp_utils/__init__.py | 2 + src/backend/bindings/cpp_utils/__init__.pyi | 2 + src/backend/bindings/cpp_utils/_core.pyi | 10 +- src/backend/bindings/utils.cpp | 130 +- src/backend/pyproject.toml | 2 - .../index_builder/CMakeLists.txt | 13 +- .../index_builder/include/robin_hood.h | 2544 +++++++++++++++++ .../index_builder/index_builder.cpp | 1053 +++---- .../index_builder/merge_partial_indices.cpp | 341 +++ .../{msmarco.tsv => msmarco-docs.tsv} | 0 .../index_builder/test_data/msmarco.tsv.gz | Bin 219 -> 0 bytes .../search_engine/scripts/build-index.sh | 4 +- src/backend/uv.lock | 137 - .../test_index_builder/test_index_builder.py | 28 + 19 files changed, 3400 insertions(+), 878 deletions(-) rename {src/backend/search_engine/scripts => .deprecated}/benchmark_numpy.py (100%) rename {src/backend/search_engine/scripts => .deprecated}/dataloader.py (100%) create mode 100644 src/backend/search_engine/index_builder/include/robin_hood.h create mode 100644 src/backend/search_engine/index_builder/merge_partial_indices.cpp rename src/backend/search_engine/index_builder/test_data/{msmarco.tsv => msmarco-docs.tsv} (100%) delete mode 100644 src/backend/search_engine/index_builder/test_data/msmarco.tsv.gz diff --git a/src/backend/search_engine/scripts/benchmark_numpy.py b/.deprecated/benchmark_numpy.py similarity index 100% rename from src/backend/search_engine/scripts/benchmark_numpy.py rename to .deprecated/benchmark_numpy.py diff --git a/src/backend/search_engine/scripts/dataloader.py b/.deprecated/dataloader.py similarity index 100% rename from src/backend/search_engine/scripts/dataloader.py rename to .deprecated/dataloader.py diff --git a/.dockerignore b/.dockerignore index 4d1b5e3..ac62239 100644 --- a/.dockerignore +++ b/.dockerignore @@ -61,4 +61,5 @@ dist/ !src/backend/search_engine/index_builder/test_data/*.gz /src/backend/search_engine/index_builder/build/ +/src/backend/search_engine/index_builder/data/ /src/backend/search_engine/index/bin/ diff --git a/.gitignore b/.gitignore index 8bf195b..992c487 100644 --- a/.gitignore +++ b/.gitignore @@ -237,4 +237,5 @@ dist-ssr !src/backend/search_engine/index_builder/test_data/*.tsv !src/backend/search_engine/index_builder/test_data/*.gz +/src/backend/search_engine/index_builder/data/ /src/backend/search_engine/index/bin/ \ No newline at end of file diff --git a/justfile b/justfile index c61bfd1..caf8abf 100644 --- a/justfile +++ b/justfile @@ -18,10 +18,16 @@ local *uvicorn-args: chmod +x local.sh && \ ./local.sh {{uvicorn-args}} -build-index memory-limit="1024": +build-index memory-limit="1024" max-docs="-1": cd src/backend/search_engine/scripts/ && \ chmod +x build-index.sh && \ - ./build-index.sh {{memory-limit}} + ./build-index.sh {{memory-limit}} {{max-docs}} + +remove-index-files: + rm -rf src/backend/search_engine/index/bin + rm -rf src/backend/search_engine/index_builder/data/docstore + rm -rf src/backend/search_engine/index_builder/data/index + rm -rf src/backend/search_engine/index_builder/data/partial_indices # from installed package # caution, will override existing stubs diff --git a/src/backend/bindings/cpp_utils/__init__.py b/src/backend/bindings/cpp_utils/__init__.py index fb018e7..cb69ba1 100644 --- a/src/backend/bindings/cpp_utils/__init__.py +++ b/src/backend/bindings/cpp_utils/__init__.py @@ -2,6 +2,7 @@ from ._core import ( DocInfo, + Metadata, DocStore, InvertedIndex, PostingList, @@ -13,6 +14,7 @@ __all__ = [ "InvertedIndex", "PostingList", + "Metadata", "DocStore", "DocInfo", "normalize_search_query", diff --git a/src/backend/bindings/cpp_utils/__init__.pyi b/src/backend/bindings/cpp_utils/__init__.pyi index 1027700..41fc450 100644 --- a/src/backend/bindings/cpp_utils/__init__.pyi +++ b/src/backend/bindings/cpp_utils/__init__.pyi @@ -2,6 +2,7 @@ from __future__ import annotations from ._core import ( DocInfo, + Metadata, DocStore, InvertedIndex, PostingList, @@ -13,6 +14,7 @@ from ._core import ( __all__: list[str] = [ "InvertedIndex", "PostingList", + "Metadata", "DocStore", "DocInfo", "normalize_search_query", diff --git a/src/backend/bindings/cpp_utils/_core.pyi b/src/backend/bindings/cpp_utils/_core.pyi index bf6cd81..15822d2 100644 --- a/src/backend/bindings/cpp_utils/_core.pyi +++ b/src/backend/bindings/cpp_utils/_core.pyi @@ -3,7 +3,7 @@ CPP utils for search engine """ from __future__ import annotations import typing -__all__: list[str] = ['DocInfo', 'DocStore', 'IndexAccessor', 'InvertedIndex', 'PostingList', 'normalize_search_query', 'positional_intersect', 'find_docs'] +__all__: list[str] = ['DocInfo', 'DocStore', 'Metadata', 'IndexAccessor', 'InvertedIndex', 'PostingList', 'normalize_search_query', 'positional_intersect', 'find_docs'] class DocInfo: @typing.overload def __init__(self) -> None: @@ -17,6 +17,14 @@ class DocInfo: @property def url(self) -> str: ... +class Metadata: + @property + def num_docs(self) -> int: ... + @property + def avg_doc_length(self) -> float: ... + @property + def doc_lengths(self) -> dict[int, int]: ... + def get_doc_length(self, doc_id: int) -> int: ... class DocStore: def get(self, doc_id: int) -> DocInfo | None: ... diff --git a/src/backend/bindings/utils.cpp b/src/backend/bindings/utils.cpp index c655b88..a6618e8 100644 --- a/src/backend/bindings/utils.cpp +++ b/src/backend/bindings/utils.cpp @@ -11,7 +11,6 @@ namespace py = pybind11; -// NOTE: Snowball stemmer instance is not thread-safe struct SnowballStemmer { struct sb_stemmer* stemmer; SnowballStemmer() { @@ -78,8 +77,36 @@ std::vector normalize_search_query(const std::string& text) { return tokens; } +struct Metadata { + uint32_t num_docs = 0; + double avg_doc_length = 0.0; + std::unordered_map doc_lengths; + + void load(const std::string& path) { + std::ifstream in(path, std::ios::binary); + if (!in.is_open()) throw std::runtime_error("Cannot open metadata file"); + + in.read(reinterpret_cast(&num_docs), sizeof(num_docs)); + in.read(reinterpret_cast(&avg_doc_length), sizeof(avg_doc_length)); + + while (in.peek() != EOF) { + uint32_t doc_id, length; + if (!in.read(reinterpret_cast(&doc_id), sizeof(doc_id))) break; + if (!in.read(reinterpret_cast(&length), sizeof(length))) break; + doc_lengths[doc_id] = length; + } + } + + uint32_t get_doc_length(uint32_t doc_id) const { + auto it = doc_lengths.find(doc_id); + if (it == doc_lengths.end()) return 0; + return it->second; + } +}; + struct PostingList { std::vector postings; + uint32_t doc_frequency; std::unordered_map term_frequencies; std::unordered_map> positions; std::unordered_map skip_pointers; @@ -102,38 +129,27 @@ struct PostingList { } }; -PostingList read_posting_list(std::ifstream& in, uint64_t offset, bool with_skip_pointers = false) { +PostingList read_posting_list(std::ifstream& in, uint64_t offset, uint32_t docFreq) { PostingList pl; + pl.doc_frequency = docFreq; in.seekg(offset); - uint32_t count_docs; - in.read(reinterpret_cast(&count_docs), sizeof(count_docs)); - pl.postings.resize(count_docs); + pl.postings.resize(docFreq); - for (uint32_t i = 0; i < count_docs; i++) { - uint32_t doc_id, tf, pos_count; + for (uint32_t i = 0; i < docFreq; i++) { + uint32_t doc_id, pos_count; in.read(reinterpret_cast(&doc_id), sizeof(doc_id)); - in.read(reinterpret_cast(&tf), sizeof(tf)); in.read(reinterpret_cast(&pos_count), sizeof(pos_count)); pl.postings[i] = doc_id; - pl.term_frequencies[doc_id] = tf; + pl.term_frequencies[doc_id] = pos_count; std::vector positions(pos_count); in.read(reinterpret_cast(positions.data()), pos_count * sizeof(uint32_t)); pl.positions[doc_id] = std::move(positions); } - if (with_skip_pointers) { - uint32_t skip_count; - in.read(reinterpret_cast(&skip_count), sizeof(skip_count)); - for (uint32_t i = 0; i < skip_count; i++) { - uint32_t from_idx, to_idx; - in.read(reinterpret_cast(&from_idx), sizeof(from_idx)); - in.read(reinterpret_cast(&to_idx), sizeof(to_idx)); - pl.skip_pointers[from_idx] = to_idx; - } - } + pl.build_skip_pointers(); return pl; } @@ -150,49 +166,54 @@ struct DocInfo { class DocStore { private: - // files for disk access - mutable std::ifstream data_in; - mutable std::ifstream offset_in; + std::unordered_map offsets; + std::ifstream data_in; uint32_t total_docs; public: - DocStore() : total_docs(0) {} - - void open(const std::string& filename_base) { - data_in.open(filename_base + ".docstore", std::ios::binary); - offset_in.open(filename_base + ".docstore_offsets", std::ios::binary); + void open(const std::string& dir_name) { + data_in.open(dir_name + "/docstore.bin", std::ios::binary); + std::ifstream off(dir_name + "/docstore_offsets.bin", std::ios::binary); - if (!data_in || !offset_in) { - throw std::runtime_error("Could not open docstore files: " + filename_base); - } + if (!data_in || !off) + throw std::runtime_error("Could not open docstore"); - // first is number of total docs + // docCount at the beginning data_in.read(reinterpret_cast(&total_docs), sizeof(total_docs)); + + while (true) { + uint32_t id; + uint64_t off64; + + if (!off.read(reinterpret_cast(&id), sizeof(id))) break; + if (!off.read(reinterpret_cast(&off64), sizeof(off64))) break; + + offsets[id] = off64; + } } std::optional get(uint32_t doc_id) { - if (doc_id >= total_docs) return std::nullopt; + auto it = offsets.find(doc_id); + if (it == offsets.end()) return std::nullopt; - // offset from offset file - uint64_t doc_offset; - offset_in.seekg(doc_id * sizeof(uint64_t)); - if (!offset_in.read(reinterpret_cast(&doc_offset), sizeof(doc_offset))) return std::nullopt; - - data_in.seekg(doc_offset); + uint64_t offset = it->second; + data_in.seekg(offset); uint32_t url_len; - if (!data_in.read(reinterpret_cast(&url_len), sizeof(url_len))) return std::nullopt; + data_in.read(reinterpret_cast(&url_len), sizeof(url_len)); + std::string url(url_len, '\0'); - if (!data_in.read(&url[0], url_len)) return std::nullopt; + data_in.read(url.data(), url_len); uint32_t title_len; - if (!data_in.read(reinterpret_cast(&title_len), sizeof(title_len))) return std::nullopt; + data_in.read(reinterpret_cast(&title_len), sizeof(title_len)); + std::string title(title_len, '\0'); - if (!data_in.read(&title[0], title_len)) return std::nullopt; + data_in.read(title.data(), title_len); return DocInfo{url, title}; } - + uint32_t size() const { return total_docs; } }; @@ -210,16 +231,18 @@ class IndexAccessor { class InvertedIndex { private: std::unordered_map term_to_offset; + std::unordered_map term_to_docfreq; std::ifstream postings_file; public: + Metadata metadata; DocStore doc_store; IndexAccessor index; InvertedIndex(const std::string& base_path) : index(this) { - std::ifstream index_file(base_path + "/inverted_index.index", std::ios::binary); + std::ifstream index_file(base_path + "/index.bin", std::ios::binary); while (true) { uint32_t term_len; if (!index_file.read(reinterpret_cast(&term_len), sizeof(term_len))) break; @@ -230,13 +253,18 @@ class InvertedIndex { uint64_t offset; if (!index_file.read(reinterpret_cast(&offset), sizeof(offset))) break; + uint32_t docFreq; + index_file.read(reinterpret_cast(&docFreq), sizeof(docFreq)); + term_to_offset[term] = offset; + term_to_docfreq[term] = docFreq; } - postings_file.open(base_path + "/inverted_index.postinglists", std::ios::binary); + postings_file.open(base_path + "/postinglists.bin", std::ios::binary); if (!postings_file.is_open()) throw std::runtime_error("Cannot open postinglists"); - doc_store.open(base_path + "/inverted_index"); + metadata.load(base_path + "/metadata.bin"); + doc_store.open(base_path); } friend class IndexAccessor; @@ -245,7 +273,8 @@ class InvertedIndex { std::optional IndexAccessor::get(const std::string& term) { auto it = parent->term_to_offset.find(term); if (it == parent->term_to_offset.end()) return std::nullopt; - PostingList pl = read_posting_list(parent->postings_file, it->second, true); + uint32_t docFreq = parent->term_to_docfreq.at(term); + PostingList pl = read_posting_list(parent->postings_file, it->second, docFreq); return pl; } @@ -517,6 +546,12 @@ PYBIND11_MODULE(_core, m) { .def_readonly("skip_pointers", &PostingList::skip_pointers) .def("build_skip_pointers", &PostingList::build_skip_pointers); + py::class_(m, "Metadata") + .def_readonly("num_docs", &Metadata::num_docs) + .def_readonly("avg_doc_length", &Metadata::avg_doc_length) + .def_readonly("doc_lengths", &Metadata::doc_lengths) + .def("get_doc_length", &Metadata::get_doc_length, py::arg("doc_id")); + py::class_(m, "DocStore") .def("get", &DocStore::get, py::arg("doc_id")); @@ -526,5 +561,6 @@ PYBIND11_MODULE(_core, m) { py::class_(m, "InvertedIndex") .def(py::init()) .def_readonly("index", &InvertedIndex::index) + .def_readonly("metadata", &InvertedIndex::metadata) .def_readonly("doc_store", &InvertedIndex::doc_store); } \ No newline at end of file diff --git a/src/backend/pyproject.toml b/src/backend/pyproject.toml index 72ddf5d..78b0e92 100644 --- a/src/backend/pyproject.toml +++ b/src/backend/pyproject.toml @@ -9,8 +9,6 @@ dependencies = [ "pydantic>=2.12.3", "requests>=2.32.5", "tqdm>=4.67.1", - "typer>=0.20.0", - "pytest-cov>=7.0.0", ] [dependency-groups] diff --git a/src/backend/search_engine/index_builder/CMakeLists.txt b/src/backend/search_engine/index_builder/CMakeLists.txt index 19f310a..b2c02a6 100644 --- a/src/backend/search_engine/index_builder/CMakeLists.txt +++ b/src/backend/search_engine/index_builder/CMakeLists.txt @@ -4,14 +4,10 @@ project(Indexer_Builder) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) -find_package(ZLIB REQUIRED) - find_path(STEMMER_INCLUDE_DIR libstemmer.h) find_library(STEMMER_LIBRARY stemmer) -set(SOURCES index_builder.cpp) - -add_executable(index_builder ${SOURCES}) +add_executable(index_builder index_builder.cpp) if (STEMMER_INCLUDE_DIR AND STEMMER_LIBRARY) target_include_directories(index_builder PRIVATE ${STEMMER_INCLUDE_DIR}) @@ -20,8 +16,13 @@ else() message(WARNING "libstemmer not found: STEMMER_INCLUDE_DIR=${STEMMER_INCLUDE_DIR}, STEMMER_LIBRARY=${STEMMER_LIBRARY}") endif() -target_link_libraries(index_builder PRIVATE ZLIB::ZLIB) +add_executable(merge_partial_indices merge_partial_indices.cpp) +if (STEMMER_INCLUDE_DIR AND STEMMER_LIBRARY) + target_include_directories(merge_partial_indices PRIVATE ${STEMMER_INCLUDE_DIR}) + target_link_libraries(merge_partial_indices PRIVATE ${STEMMER_LIBRARY}) +endif() if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") target_link_libraries(index_builder PRIVATE stdc++fs) + target_link_libraries(merge_partial_indices PRIVATE stdc++fs) endif() diff --git a/src/backend/search_engine/index_builder/include/robin_hood.h b/src/backend/search_engine/index_builder/include/robin_hood.h new file mode 100644 index 0000000..b4e0fbc --- /dev/null +++ b/src/backend/search_engine/index_builder/include/robin_hood.h @@ -0,0 +1,2544 @@ +// ______ _____ ______ _________ +// ______________ ___ /_ ___(_)_______ ___ /_ ______ ______ ______ / +// __ ___/_ __ \__ __ \__ / __ __ \ __ __ \_ __ \_ __ \_ __ / +// _ / / /_/ /_ /_/ /_ / _ / / / _ / / // /_/ // /_/ // /_/ / +// /_/ \____/ /_.___/ /_/ /_/ /_/ ________/_/ /_/ \____/ \____/ \__,_/ +// _/_____/ +// +// Fast & memory efficient hashtable based on robin hood hashing for C++11/14/17/20 +// https://github.com/martinus/robin-hood-hashing +// +// Licensed under the MIT License . +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2021 Martin Ankerl +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#ifndef ROBIN_HOOD_H_INCLUDED +#define ROBIN_HOOD_H_INCLUDED + +// see https://semver.org/ +#define ROBIN_HOOD_VERSION_MAJOR 3 // for incompatible API changes +#define ROBIN_HOOD_VERSION_MINOR 11 // for adding functionality in a backwards-compatible manner +#define ROBIN_HOOD_VERSION_PATCH 5 // for backwards-compatible bug fixes + +#include +#include +#include +#include +#include +#include // only to support hash of smart pointers +#include +#include +#include +#include +#if __cplusplus >= 201703L +# include +#endif + +// #define ROBIN_HOOD_LOG_ENABLED +#ifdef ROBIN_HOOD_LOG_ENABLED +# include +# define ROBIN_HOOD_LOG(...) \ + std::cout << __FUNCTION__ << "@" << __LINE__ << ": " << __VA_ARGS__ << std::endl; +#else +# define ROBIN_HOOD_LOG(x) +#endif + +// #define ROBIN_HOOD_TRACE_ENABLED +#ifdef ROBIN_HOOD_TRACE_ENABLED +# include +# define ROBIN_HOOD_TRACE(...) \ + std::cout << __FUNCTION__ << "@" << __LINE__ << ": " << __VA_ARGS__ << std::endl; +#else +# define ROBIN_HOOD_TRACE(x) +#endif + +// #define ROBIN_HOOD_COUNT_ENABLED +#ifdef ROBIN_HOOD_COUNT_ENABLED +# include +# define ROBIN_HOOD_COUNT(x) ++counts().x; +namespace robin_hood { +struct Counts { + uint64_t shiftUp{}; + uint64_t shiftDown{}; +}; +inline std::ostream& operator<<(std::ostream& os, Counts const& c) { + return os << c.shiftUp << " shiftUp" << std::endl << c.shiftDown << " shiftDown" << std::endl; +} + +static Counts& counts() { + static Counts counts{}; + return counts; +} +} // namespace robin_hood +#else +# define ROBIN_HOOD_COUNT(x) +#endif + +// all non-argument macros should use this facility. See +// https://www.fluentcpp.com/2019/05/28/better-macros-better-flags/ +#define ROBIN_HOOD(x) ROBIN_HOOD_PRIVATE_DEFINITION_##x() + +// mark unused members with this macro +#define ROBIN_HOOD_UNUSED(identifier) + +// bitness +#if SIZE_MAX == UINT32_MAX +# define ROBIN_HOOD_PRIVATE_DEFINITION_BITNESS() 32 +#elif SIZE_MAX == UINT64_MAX +# define ROBIN_HOOD_PRIVATE_DEFINITION_BITNESS() 64 +#else +# error Unsupported bitness +#endif + +// endianess +#ifdef _MSC_VER +# define ROBIN_HOOD_PRIVATE_DEFINITION_LITTLE_ENDIAN() 1 +# define ROBIN_HOOD_PRIVATE_DEFINITION_BIG_ENDIAN() 0 +#else +# define ROBIN_HOOD_PRIVATE_DEFINITION_LITTLE_ENDIAN() \ + (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) +# define ROBIN_HOOD_PRIVATE_DEFINITION_BIG_ENDIAN() (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__) +#endif + +// inline +#ifdef _MSC_VER +# define ROBIN_HOOD_PRIVATE_DEFINITION_NOINLINE() __declspec(noinline) +#else +# define ROBIN_HOOD_PRIVATE_DEFINITION_NOINLINE() __attribute__((noinline)) +#endif + +// exceptions +#if !defined(__cpp_exceptions) && !defined(__EXCEPTIONS) && !defined(_CPPUNWIND) +# define ROBIN_HOOD_PRIVATE_DEFINITION_HAS_EXCEPTIONS() 0 +#else +# define ROBIN_HOOD_PRIVATE_DEFINITION_HAS_EXCEPTIONS() 1 +#endif + +// count leading/trailing bits +#if !defined(ROBIN_HOOD_DISABLE_INTRINSICS) +# ifdef _MSC_VER +# if ROBIN_HOOD(BITNESS) == 32 +# define ROBIN_HOOD_PRIVATE_DEFINITION_BITSCANFORWARD() _BitScanForward +# else +# define ROBIN_HOOD_PRIVATE_DEFINITION_BITSCANFORWARD() _BitScanForward64 +# endif +# include +# pragma intrinsic(ROBIN_HOOD(BITSCANFORWARD)) +# define ROBIN_HOOD_COUNT_TRAILING_ZEROES(x) \ + [](size_t mask) noexcept -> int { \ + unsigned long index; \ + return ROBIN_HOOD(BITSCANFORWARD)(&index, mask) ? static_cast(index) \ + : ROBIN_HOOD(BITNESS); \ + }(x) +# else +# if ROBIN_HOOD(BITNESS) == 32 +# define ROBIN_HOOD_PRIVATE_DEFINITION_CTZ() __builtin_ctzl +# define ROBIN_HOOD_PRIVATE_DEFINITION_CLZ() __builtin_clzl +# else +# define ROBIN_HOOD_PRIVATE_DEFINITION_CTZ() __builtin_ctzll +# define ROBIN_HOOD_PRIVATE_DEFINITION_CLZ() __builtin_clzll +# endif +# define ROBIN_HOOD_COUNT_LEADING_ZEROES(x) ((x) ? ROBIN_HOOD(CLZ)(x) : ROBIN_HOOD(BITNESS)) +# define ROBIN_HOOD_COUNT_TRAILING_ZEROES(x) ((x) ? ROBIN_HOOD(CTZ)(x) : ROBIN_HOOD(BITNESS)) +# endif +#endif + +// fallthrough +#ifndef __has_cpp_attribute // For backwards compatibility +# define __has_cpp_attribute(x) 0 +#endif +#if __has_cpp_attribute(clang::fallthrough) +# define ROBIN_HOOD_PRIVATE_DEFINITION_FALLTHROUGH() [[clang::fallthrough]] +#elif __has_cpp_attribute(gnu::fallthrough) +# define ROBIN_HOOD_PRIVATE_DEFINITION_FALLTHROUGH() [[gnu::fallthrough]] +#else +# define ROBIN_HOOD_PRIVATE_DEFINITION_FALLTHROUGH() +#endif + +// likely/unlikely +#ifdef _MSC_VER +# define ROBIN_HOOD_LIKELY(condition) condition +# define ROBIN_HOOD_UNLIKELY(condition) condition +#else +# define ROBIN_HOOD_LIKELY(condition) __builtin_expect(condition, 1) +# define ROBIN_HOOD_UNLIKELY(condition) __builtin_expect(condition, 0) +#endif + +// detect if native wchar_t type is availiable in MSVC +#ifdef _MSC_VER +# ifdef _NATIVE_WCHAR_T_DEFINED +# define ROBIN_HOOD_PRIVATE_DEFINITION_HAS_NATIVE_WCHART() 1 +# else +# define ROBIN_HOOD_PRIVATE_DEFINITION_HAS_NATIVE_WCHART() 0 +# endif +#else +# define ROBIN_HOOD_PRIVATE_DEFINITION_HAS_NATIVE_WCHART() 1 +#endif + +// detect if MSVC supports the pair(std::piecewise_construct_t,...) consructor being constexpr +#ifdef _MSC_VER +# if _MSC_VER <= 1900 +# define ROBIN_HOOD_PRIVATE_DEFINITION_BROKEN_CONSTEXPR() 1 +# else +# define ROBIN_HOOD_PRIVATE_DEFINITION_BROKEN_CONSTEXPR() 0 +# endif +#else +# define ROBIN_HOOD_PRIVATE_DEFINITION_BROKEN_CONSTEXPR() 0 +#endif + +// workaround missing "is_trivially_copyable" in g++ < 5.0 +// See https://stackoverflow.com/a/31798726/48181 +#if defined(__GNUC__) && __GNUC__ < 5 && !defined(__clang__) +# define ROBIN_HOOD_IS_TRIVIALLY_COPYABLE(...) __has_trivial_copy(__VA_ARGS__) +#else +# define ROBIN_HOOD_IS_TRIVIALLY_COPYABLE(...) std::is_trivially_copyable<__VA_ARGS__>::value +#endif + +// helpers for C++ versions, see https://gcc.gnu.org/onlinedocs/cpp/Standard-Predefined-Macros.html +#define ROBIN_HOOD_PRIVATE_DEFINITION_CXX() __cplusplus +#define ROBIN_HOOD_PRIVATE_DEFINITION_CXX98() 199711L +#define ROBIN_HOOD_PRIVATE_DEFINITION_CXX11() 201103L +#define ROBIN_HOOD_PRIVATE_DEFINITION_CXX14() 201402L +#define ROBIN_HOOD_PRIVATE_DEFINITION_CXX17() 201703L + +#if ROBIN_HOOD(CXX) >= ROBIN_HOOD(CXX17) +# define ROBIN_HOOD_PRIVATE_DEFINITION_NODISCARD() [[nodiscard]] +#else +# define ROBIN_HOOD_PRIVATE_DEFINITION_NODISCARD() +#endif + +namespace robin_hood { + +#if ROBIN_HOOD(CXX) >= ROBIN_HOOD(CXX14) +# define ROBIN_HOOD_STD std +#else + +// c++11 compatibility layer +namespace ROBIN_HOOD_STD { +template +struct alignment_of + : std::integral_constant::type)> {}; + +template +class integer_sequence { +public: + using value_type = T; + static_assert(std::is_integral::value, "not integral type"); + static constexpr std::size_t size() noexcept { + return sizeof...(Ints); + } +}; +template +using index_sequence = integer_sequence; + +namespace detail_ { +template +struct IntSeqImpl { + using TValue = T; + static_assert(std::is_integral::value, "not integral type"); + static_assert(Begin >= 0 && Begin < End, "unexpected argument (Begin<0 || Begin<=End)"); + + template + struct IntSeqCombiner; + + template + struct IntSeqCombiner, integer_sequence> { + using TResult = integer_sequence; + }; + + using TResult = + typename IntSeqCombiner::TResult, + typename IntSeqImpl::TResult>::TResult; +}; + +template +struct IntSeqImpl { + using TValue = T; + static_assert(std::is_integral::value, "not integral type"); + static_assert(Begin >= 0, "unexpected argument (Begin<0)"); + using TResult = integer_sequence; +}; + +template +struct IntSeqImpl { + using TValue = T; + static_assert(std::is_integral::value, "not integral type"); + static_assert(Begin >= 0, "unexpected argument (Begin<0)"); + using TResult = integer_sequence; +}; +} // namespace detail_ + +template +using make_integer_sequence = typename detail_::IntSeqImpl::TResult; + +template +using make_index_sequence = make_integer_sequence; + +template +using index_sequence_for = make_index_sequence; + +} // namespace ROBIN_HOOD_STD + +#endif + +namespace detail { + +// make sure we static_cast to the correct type for hash_int +#if ROBIN_HOOD(BITNESS) == 64 +using SizeT = uint64_t; +#else +using SizeT = uint32_t; +#endif + +template +T rotr(T x, unsigned k) { + return (x >> k) | (x << (8U * sizeof(T) - k)); +} + +// This cast gets rid of warnings like "cast from 'uint8_t*' {aka 'unsigned char*'} to +// 'uint64_t*' {aka 'long unsigned int*'} increases required alignment of target type". Use with +// care! +template +inline T reinterpret_cast_no_cast_align_warning(void* ptr) noexcept { + return reinterpret_cast(ptr); +} + +template +inline T reinterpret_cast_no_cast_align_warning(void const* ptr) noexcept { + return reinterpret_cast(ptr); +} + +// make sure this is not inlined as it is slow and dramatically enlarges code, thus making other +// inlinings more difficult. Throws are also generally the slow path. +template +[[noreturn]] ROBIN_HOOD(NOINLINE) +#if ROBIN_HOOD(HAS_EXCEPTIONS) + void doThrow(Args&&... args) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-array-to-pointer-decay) + throw E(std::forward(args)...); +} +#else + void doThrow(Args&&... ROBIN_HOOD_UNUSED(args) /*unused*/) { + abort(); +} +#endif + +template +T* assertNotNull(T* t, Args&&... args) { + if (ROBIN_HOOD_UNLIKELY(nullptr == t)) { + doThrow(std::forward(args)...); + } + return t; +} + +template +inline T unaligned_load(void const* ptr) noexcept { + // using memcpy so we don't get into unaligned load problems. + // compiler should optimize this very well anyways. + T t; + std::memcpy(&t, ptr, sizeof(T)); + return t; +} + +// Allocates bulks of memory for objects of type T. This deallocates the memory in the destructor, +// and keeps a linked list of the allocated memory around. Overhead per allocation is the size of a +// pointer. +template +class BulkPoolAllocator { +public: + BulkPoolAllocator() noexcept = default; + + // does not copy anything, just creates a new allocator. + BulkPoolAllocator(const BulkPoolAllocator& ROBIN_HOOD_UNUSED(o) /*unused*/) noexcept + : mHead(nullptr) + , mListForFree(nullptr) {} + + BulkPoolAllocator(BulkPoolAllocator&& o) noexcept + : mHead(o.mHead) + , mListForFree(o.mListForFree) { + o.mListForFree = nullptr; + o.mHead = nullptr; + } + + BulkPoolAllocator& operator=(BulkPoolAllocator&& o) noexcept { + reset(); + mHead = o.mHead; + mListForFree = o.mListForFree; + o.mListForFree = nullptr; + o.mHead = nullptr; + return *this; + } + + BulkPoolAllocator& + // NOLINTNEXTLINE(bugprone-unhandled-self-assignment,cert-oop54-cpp) + operator=(const BulkPoolAllocator& ROBIN_HOOD_UNUSED(o) /*unused*/) noexcept { + // does not do anything + return *this; + } + + ~BulkPoolAllocator() noexcept { + reset(); + } + + // Deallocates all allocated memory. + void reset() noexcept { + while (mListForFree) { + T* tmp = *mListForFree; + ROBIN_HOOD_LOG("std::free") + std::free(mListForFree); + mListForFree = reinterpret_cast_no_cast_align_warning(tmp); + } + mHead = nullptr; + } + + // allocates, but does NOT initialize. Use in-place new constructor, e.g. + // T* obj = pool.allocate(); + // ::new (static_cast(obj)) T(); + T* allocate() { + T* tmp = mHead; + if (!tmp) { + tmp = performAllocation(); + } + + mHead = *reinterpret_cast_no_cast_align_warning(tmp); + return tmp; + } + + // does not actually deallocate but puts it in store. + // make sure you have already called the destructor! e.g. with + // obj->~T(); + // pool.deallocate(obj); + void deallocate(T* obj) noexcept { + *reinterpret_cast_no_cast_align_warning(obj) = mHead; + mHead = obj; + } + + // Adds an already allocated block of memory to the allocator. This allocator is from now on + // responsible for freeing the data (with free()). If the provided data is not large enough to + // make use of, it is immediately freed. Otherwise it is reused and freed in the destructor. + void addOrFree(void* ptr, const size_t numBytes) noexcept { + // calculate number of available elements in ptr + if (numBytes < ALIGNMENT + ALIGNED_SIZE) { + // not enough data for at least one element. Free and return. + ROBIN_HOOD_LOG("std::free") + std::free(ptr); + } else { + ROBIN_HOOD_LOG("add to buffer") + add(ptr, numBytes); + } + } + + void swap(BulkPoolAllocator& other) noexcept { + using std::swap; + swap(mHead, other.mHead); + swap(mListForFree, other.mListForFree); + } + +private: + // iterates the list of allocated memory to calculate how many to alloc next. + // Recalculating this each time saves us a size_t member. + // This ignores the fact that memory blocks might have been added manually with addOrFree. In + // practice, this should not matter much. + ROBIN_HOOD(NODISCARD) size_t calcNumElementsToAlloc() const noexcept { + auto tmp = mListForFree; + size_t numAllocs = MinNumAllocs; + + while (numAllocs * 2 <= MaxNumAllocs && tmp) { + auto x = reinterpret_cast(tmp); + tmp = *x; + numAllocs *= 2; + } + + return numAllocs; + } + + // WARNING: Underflow if numBytes < ALIGNMENT! This is guarded in addOrFree(). + void add(void* ptr, const size_t numBytes) noexcept { + const size_t numElements = (numBytes - ALIGNMENT) / ALIGNED_SIZE; + + auto data = reinterpret_cast(ptr); + + // link free list + auto x = reinterpret_cast(data); + *x = mListForFree; + mListForFree = data; + + // create linked list for newly allocated data + auto* const headT = + reinterpret_cast_no_cast_align_warning(reinterpret_cast(ptr) + ALIGNMENT); + + auto* const head = reinterpret_cast(headT); + + // Visual Studio compiler automatically unrolls this loop, which is pretty cool + for (size_t i = 0; i < numElements; ++i) { + *reinterpret_cast_no_cast_align_warning(head + i * ALIGNED_SIZE) = + head + (i + 1) * ALIGNED_SIZE; + } + + // last one points to 0 + *reinterpret_cast_no_cast_align_warning(head + (numElements - 1) * ALIGNED_SIZE) = + mHead; + mHead = headT; + } + + // Called when no memory is available (mHead == 0). + // Don't inline this slow path. + ROBIN_HOOD(NOINLINE) T* performAllocation() { + size_t const numElementsToAlloc = calcNumElementsToAlloc(); + + // alloc new memory: [prev |T, T, ... T] + size_t const bytes = ALIGNMENT + ALIGNED_SIZE * numElementsToAlloc; + ROBIN_HOOD_LOG("std::malloc " << bytes << " = " << ALIGNMENT << " + " << ALIGNED_SIZE + << " * " << numElementsToAlloc) + add(assertNotNull(std::malloc(bytes)), bytes); + return mHead; + } + + // enforce byte alignment of the T's +#if ROBIN_HOOD(CXX) >= ROBIN_HOOD(CXX14) + static constexpr size_t ALIGNMENT = + (std::max)(std::alignment_of::value, std::alignment_of::value); +#else + static const size_t ALIGNMENT = + (ROBIN_HOOD_STD::alignment_of::value > ROBIN_HOOD_STD::alignment_of::value) + ? ROBIN_HOOD_STD::alignment_of::value + : +ROBIN_HOOD_STD::alignment_of::value; // the + is for walkarround +#endif + + static constexpr size_t ALIGNED_SIZE = ((sizeof(T) - 1) / ALIGNMENT + 1) * ALIGNMENT; + + static_assert(MinNumAllocs >= 1, "MinNumAllocs"); + static_assert(MaxNumAllocs >= MinNumAllocs, "MaxNumAllocs"); + static_assert(ALIGNED_SIZE >= sizeof(T*), "ALIGNED_SIZE"); + static_assert(0 == (ALIGNED_SIZE % sizeof(T*)), "ALIGNED_SIZE mod"); + static_assert(ALIGNMENT >= sizeof(T*), "ALIGNMENT"); + + T* mHead{nullptr}; + T** mListForFree{nullptr}; +}; + +template +struct NodeAllocator; + +// dummy allocator that does nothing +template +struct NodeAllocator { + + // we are not using the data, so just free it. + void addOrFree(void* ptr, size_t ROBIN_HOOD_UNUSED(numBytes) /*unused*/) noexcept { + ROBIN_HOOD_LOG("std::free") + std::free(ptr); + } +}; + +template +struct NodeAllocator : public BulkPoolAllocator {}; + +// c++14 doesn't have is_nothrow_swappable, and clang++ 6.0.1 doesn't like it either, so I'm making +// my own here. +namespace swappable { +#if ROBIN_HOOD(CXX) < ROBIN_HOOD(CXX17) +using std::swap; +template +struct nothrow { + static const bool value = noexcept(swap(std::declval(), std::declval())); +}; +#else +template +struct nothrow { + static const bool value = std::is_nothrow_swappable::value; +}; +#endif +} // namespace swappable + +} // namespace detail + +struct is_transparent_tag {}; + +// A custom pair implementation is used in the map because std::pair is not is_trivially_copyable, +// which means it would not be allowed to be used in std::memcpy. This struct is copyable, which is +// also tested. +template +struct pair { + using first_type = T1; + using second_type = T2; + + template ::value && + std::is_default_constructible::value>::type> + constexpr pair() noexcept(noexcept(U1()) && noexcept(U2())) + : first() + , second() {} + + // pair constructors are explicit so we don't accidentally call this ctor when we don't have to. + explicit constexpr pair(std::pair const& o) noexcept( + noexcept(T1(std::declval())) && noexcept(T2(std::declval()))) + : first(o.first) + , second(o.second) {} + + // pair constructors are explicit so we don't accidentally call this ctor when we don't have to. + explicit constexpr pair(std::pair&& o) noexcept(noexcept( + T1(std::move(std::declval()))) && noexcept(T2(std::move(std::declval())))) + : first(std::move(o.first)) + , second(std::move(o.second)) {} + + constexpr pair(T1&& a, T2&& b) noexcept(noexcept( + T1(std::move(std::declval()))) && noexcept(T2(std::move(std::declval())))) + : first(std::move(a)) + , second(std::move(b)) {} + + template + constexpr pair(U1&& a, U2&& b) noexcept(noexcept(T1(std::forward( + std::declval()))) && noexcept(T2(std::forward(std::declval())))) + : first(std::forward(a)) + , second(std::forward(b)) {} + + template + // MSVC 2015 produces error "C2476: ‘constexpr’ constructor does not initialize all members" + // if this constructor is constexpr +#if !ROBIN_HOOD(BROKEN_CONSTEXPR) + constexpr +#endif + pair(std::piecewise_construct_t /*unused*/, std::tuple a, + std::tuple + b) noexcept(noexcept(pair(std::declval&>(), + std::declval&>(), + ROBIN_HOOD_STD::index_sequence_for(), + ROBIN_HOOD_STD::index_sequence_for()))) + : pair(a, b, ROBIN_HOOD_STD::index_sequence_for(), + ROBIN_HOOD_STD::index_sequence_for()) { + } + + // constructor called from the std::piecewise_construct_t ctor + template + pair(std::tuple& a, std::tuple& b, ROBIN_HOOD_STD::index_sequence /*unused*/, ROBIN_HOOD_STD::index_sequence /*unused*/) noexcept( + noexcept(T1(std::forward(std::get( + std::declval&>()))...)) && noexcept(T2(std:: + forward(std::get( + std::declval&>()))...))) + : first(std::forward(std::get(a))...) + , second(std::forward(std::get(b))...) { + // make visual studio compiler happy about warning about unused a & b. + // Visual studio's pair implementation disables warning 4100. + (void)a; + (void)b; + } + + void swap(pair& o) noexcept((detail::swappable::nothrow::value) && + (detail::swappable::nothrow::value)) { + using std::swap; + swap(first, o.first); + swap(second, o.second); + } + + T1 first; // NOLINT(misc-non-private-member-variables-in-classes) + T2 second; // NOLINT(misc-non-private-member-variables-in-classes) +}; + +template +inline void swap(pair& a, pair& b) noexcept( + noexcept(std::declval&>().swap(std::declval&>()))) { + a.swap(b); +} + +template +inline constexpr bool operator==(pair const& x, pair const& y) { + return (x.first == y.first) && (x.second == y.second); +} +template +inline constexpr bool operator!=(pair const& x, pair const& y) { + return !(x == y); +} +template +inline constexpr bool operator<(pair const& x, pair const& y) noexcept(noexcept( + std::declval() < std::declval()) && noexcept(std::declval() < + std::declval())) { + return x.first < y.first || (!(y.first < x.first) && x.second < y.second); +} +template +inline constexpr bool operator>(pair const& x, pair const& y) { + return y < x; +} +template +inline constexpr bool operator<=(pair const& x, pair const& y) { + return !(x > y); +} +template +inline constexpr bool operator>=(pair const& x, pair const& y) { + return !(x < y); +} + +inline size_t hash_bytes(void const* ptr, size_t len) noexcept { + static constexpr uint64_t m = UINT64_C(0xc6a4a7935bd1e995); + static constexpr uint64_t seed = UINT64_C(0xe17a1465); + static constexpr unsigned int r = 47; + + auto const* const data64 = static_cast(ptr); + uint64_t h = seed ^ (len * m); + + size_t const n_blocks = len / 8; + for (size_t i = 0; i < n_blocks; ++i) { + auto k = detail::unaligned_load(data64 + i); + + k *= m; + k ^= k >> r; + k *= m; + + h ^= k; + h *= m; + } + + auto const* const data8 = reinterpret_cast(data64 + n_blocks); + switch (len & 7U) { + case 7: + h ^= static_cast(data8[6]) << 48U; + ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH + case 6: + h ^= static_cast(data8[5]) << 40U; + ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH + case 5: + h ^= static_cast(data8[4]) << 32U; + ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH + case 4: + h ^= static_cast(data8[3]) << 24U; + ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH + case 3: + h ^= static_cast(data8[2]) << 16U; + ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH + case 2: + h ^= static_cast(data8[1]) << 8U; + ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH + case 1: + h ^= static_cast(data8[0]); + h *= m; + ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH + default: + break; + } + + h ^= h >> r; + + // not doing the final step here, because this will be done by keyToIdx anyways + // h *= m; + // h ^= h >> r; + return static_cast(h); +} + +inline size_t hash_int(uint64_t x) noexcept { + // tried lots of different hashes, let's stick with murmurhash3. It's simple, fast, well tested, + // and doesn't need any special 128bit operations. + x ^= x >> 33U; + x *= UINT64_C(0xff51afd7ed558ccd); + x ^= x >> 33U; + + // not doing the final step here, because this will be done by keyToIdx anyways + // x *= UINT64_C(0xc4ceb9fe1a85ec53); + // x ^= x >> 33U; + return static_cast(x); +} + +// A thin wrapper around std::hash, performing an additional simple mixing step of the result. +template +struct hash : public std::hash { + size_t operator()(T const& obj) const + noexcept(noexcept(std::declval>().operator()(std::declval()))) { + // call base hash + auto result = std::hash::operator()(obj); + // return mixed of that, to be save against identity has + return hash_int(static_cast(result)); + } +}; + +template +struct hash> { + size_t operator()(std::basic_string const& str) const noexcept { + return hash_bytes(str.data(), sizeof(CharT) * str.size()); + } +}; + +#if ROBIN_HOOD(CXX) >= ROBIN_HOOD(CXX17) +template +struct hash> { + size_t operator()(std::basic_string_view const& sv) const noexcept { + return hash_bytes(sv.data(), sizeof(CharT) * sv.size()); + } +}; +#endif + +template +struct hash { + size_t operator()(T* ptr) const noexcept { + return hash_int(reinterpret_cast(ptr)); + } +}; + +template +struct hash> { + size_t operator()(std::unique_ptr const& ptr) const noexcept { + return hash_int(reinterpret_cast(ptr.get())); + } +}; + +template +struct hash> { + size_t operator()(std::shared_ptr const& ptr) const noexcept { + return hash_int(reinterpret_cast(ptr.get())); + } +}; + +template +struct hash::value>::type> { + size_t operator()(Enum e) const noexcept { + using Underlying = typename std::underlying_type::type; + return hash{}(static_cast(e)); + } +}; + +#define ROBIN_HOOD_HASH_INT(T) \ + template <> \ + struct hash { \ + size_t operator()(T const& obj) const noexcept { \ + return hash_int(static_cast(obj)); \ + } \ + } + +#if defined(__GNUC__) && !defined(__clang__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wuseless-cast" +#endif +// see https://en.cppreference.com/w/cpp/utility/hash +ROBIN_HOOD_HASH_INT(bool); +ROBIN_HOOD_HASH_INT(char); +ROBIN_HOOD_HASH_INT(signed char); +ROBIN_HOOD_HASH_INT(unsigned char); +ROBIN_HOOD_HASH_INT(char16_t); +ROBIN_HOOD_HASH_INT(char32_t); +#if ROBIN_HOOD(HAS_NATIVE_WCHART) +ROBIN_HOOD_HASH_INT(wchar_t); +#endif +ROBIN_HOOD_HASH_INT(short); +ROBIN_HOOD_HASH_INT(unsigned short); +ROBIN_HOOD_HASH_INT(int); +ROBIN_HOOD_HASH_INT(unsigned int); +ROBIN_HOOD_HASH_INT(long); +ROBIN_HOOD_HASH_INT(long long); +ROBIN_HOOD_HASH_INT(unsigned long); +ROBIN_HOOD_HASH_INT(unsigned long long); +#if defined(__GNUC__) && !defined(__clang__) +# pragma GCC diagnostic pop +#endif +namespace detail { + +template +struct void_type { + using type = void; +}; + +template +struct has_is_transparent : public std::false_type {}; + +template +struct has_is_transparent::type> + : public std::true_type {}; + +// using wrapper classes for hash and key_equal prevents the diamond problem when the same type +// is used. see https://stackoverflow.com/a/28771920/48181 +template +struct WrapHash : public T { + WrapHash() = default; + explicit WrapHash(T const& o) noexcept(noexcept(T(std::declval()))) + : T(o) {} +}; + +template +struct WrapKeyEqual : public T { + WrapKeyEqual() = default; + explicit WrapKeyEqual(T const& o) noexcept(noexcept(T(std::declval()))) + : T(o) {} +}; + +// A highly optimized hashmap implementation, using the Robin Hood algorithm. +// +// In most cases, this map should be usable as a drop-in replacement for std::unordered_map, but +// be about 2x faster in most cases and require much less allocations. +// +// This implementation uses the following memory layout: +// +// [Node, Node, ... Node | info, info, ... infoSentinel ] +// +// * Node: either a DataNode that directly has the std::pair as member, +// or a DataNode with a pointer to std::pair. Which DataNode representation to use +// depends on how fast the swap() operation is. Heuristically, this is automatically choosen +// based on sizeof(). there are always 2^n Nodes. +// +// * info: Each Node in the map has a corresponding info byte, so there are 2^n info bytes. +// Each byte is initialized to 0, meaning the corresponding Node is empty. Set to 1 means the +// corresponding node contains data. Set to 2 means the corresponding Node is filled, but it +// actually belongs to the previous position and was pushed out because that place is already +// taken. +// +// * infoSentinel: Sentinel byte set to 1, so that iterator's ++ can stop at end() without the +// need for a idx variable. +// +// According to STL, order of templates has effect on throughput. That's why I've moved the +// boolean to the front. +// https://www.reddit.com/r/cpp/comments/ahp6iu/compile_time_binary_size_reductions_and_cs_future/eeguck4/ +template +class Table + : public WrapHash, + public WrapKeyEqual, + detail::NodeAllocator< + typename std::conditional< + std::is_void::value, Key, + robin_hood::pair::type, T>>::type, + 4, 16384, IsFlat> { +public: + static constexpr bool is_flat = IsFlat; + static constexpr bool is_map = !std::is_void::value; + static constexpr bool is_set = !is_map; + static constexpr bool is_transparent = + has_is_transparent::value && has_is_transparent::value; + + using key_type = Key; + using mapped_type = T; + using value_type = typename std::conditional< + is_set, Key, + robin_hood::pair::type, T>>::type; + using size_type = size_t; + using hasher = Hash; + using key_equal = KeyEqual; + using Self = Table; + +private: + static_assert(MaxLoadFactor100 > 10 && MaxLoadFactor100 < 100, + "MaxLoadFactor100 needs to be >10 && < 100"); + + using WHash = WrapHash; + using WKeyEqual = WrapKeyEqual; + + // configuration defaults + + // make sure we have 8 elements, needed to quickly rehash mInfo + static constexpr size_t InitialNumElements = sizeof(uint64_t); + static constexpr uint32_t InitialInfoNumBits = 5; + static constexpr uint8_t InitialInfoInc = 1U << InitialInfoNumBits; + static constexpr size_t InfoMask = InitialInfoInc - 1U; + static constexpr uint8_t InitialInfoHashShift = 0; + using DataPool = detail::NodeAllocator; + + // type needs to be wider than uint8_t. + using InfoType = uint32_t; + + // DataNode //////////////////////////////////////////////////////// + + // Primary template for the data node. We have special implementations for small and big + // objects. For large objects it is assumed that swap() is fairly slow, so we allocate these + // on the heap so swap merely swaps a pointer. + template + class DataNode {}; + + // Small: just allocate on the stack. + template + class DataNode final { + public: + template + explicit DataNode(M& ROBIN_HOOD_UNUSED(map) /*unused*/, Args&&... args) noexcept( + noexcept(value_type(std::forward(args)...))) + : mData(std::forward(args)...) {} + + DataNode(M& ROBIN_HOOD_UNUSED(map) /*unused*/, DataNode&& n) noexcept( + std::is_nothrow_move_constructible::value) + : mData(std::move(n.mData)) {} + + // doesn't do anything + void destroy(M& ROBIN_HOOD_UNUSED(map) /*unused*/) noexcept {} + void destroyDoNotDeallocate() noexcept {} + + value_type const* operator->() const noexcept { + return &mData; + } + value_type* operator->() noexcept { + return &mData; + } + + const value_type& operator*() const noexcept { + return mData; + } + + value_type& operator*() noexcept { + return mData; + } + + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type getFirst() noexcept { + return mData.first; + } + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type getFirst() noexcept { + return mData; + } + + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type + getFirst() const noexcept { + return mData.first; + } + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type getFirst() const noexcept { + return mData; + } + + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type getSecond() noexcept { + return mData.second; + } + + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type getSecond() const noexcept { + return mData.second; + } + + void swap(DataNode& o) noexcept( + noexcept(std::declval().swap(std::declval()))) { + mData.swap(o.mData); + } + + private: + value_type mData; + }; + + // big object: allocate on heap. + template + class DataNode { + public: + template + explicit DataNode(M& map, Args&&... args) + : mData(map.allocate()) { + ::new (static_cast(mData)) value_type(std::forward(args)...); + } + + DataNode(M& ROBIN_HOOD_UNUSED(map) /*unused*/, DataNode&& n) noexcept + : mData(std::move(n.mData)) {} + + void destroy(M& map) noexcept { + // don't deallocate, just put it into list of datapool. + mData->~value_type(); + map.deallocate(mData); + } + + void destroyDoNotDeallocate() noexcept { + mData->~value_type(); + } + + value_type const* operator->() const noexcept { + return mData; + } + + value_type* operator->() noexcept { + return mData; + } + + const value_type& operator*() const { + return *mData; + } + + value_type& operator*() { + return *mData; + } + + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type getFirst() noexcept { + return mData->first; + } + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type getFirst() noexcept { + return *mData; + } + + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type + getFirst() const noexcept { + return mData->first; + } + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type getFirst() const noexcept { + return *mData; + } + + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type getSecond() noexcept { + return mData->second; + } + + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type getSecond() const noexcept { + return mData->second; + } + + void swap(DataNode& o) noexcept { + using std::swap; + swap(mData, o.mData); + } + + private: + value_type* mData; + }; + + using Node = DataNode; + + // helpers for insertKeyPrepareEmptySpot: extract first entry (only const required) + ROBIN_HOOD(NODISCARD) key_type const& getFirstConst(Node const& n) const noexcept { + return n.getFirst(); + } + + // in case we have void mapped_type, we are not using a pair, thus we just route k through. + // No need to disable this because it's just not used if not applicable. + ROBIN_HOOD(NODISCARD) key_type const& getFirstConst(key_type const& k) const noexcept { + return k; + } + + // in case we have non-void mapped_type, we have a standard robin_hood::pair + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::value, key_type const&>::type + getFirstConst(value_type const& vt) const noexcept { + return vt.first; + } + + // Cloner ////////////////////////////////////////////////////////// + + template + struct Cloner; + + // fast path: Just copy data, without allocating anything. + template + struct Cloner { + void operator()(M const& source, M& target) const { + auto const* const src = reinterpret_cast(source.mKeyVals); + auto* tgt = reinterpret_cast(target.mKeyVals); + auto const numElementsWithBuffer = target.calcNumElementsWithBuffer(target.mMask + 1); + std::copy(src, src + target.calcNumBytesTotal(numElementsWithBuffer), tgt); + } + }; + + template + struct Cloner { + void operator()(M const& s, M& t) const { + auto const numElementsWithBuffer = t.calcNumElementsWithBuffer(t.mMask + 1); + std::copy(s.mInfo, s.mInfo + t.calcNumBytesInfo(numElementsWithBuffer), t.mInfo); + + for (size_t i = 0; i < numElementsWithBuffer; ++i) { + if (t.mInfo[i]) { + ::new (static_cast(t.mKeyVals + i)) Node(t, *s.mKeyVals[i]); + } + } + } + }; + + // Destroyer /////////////////////////////////////////////////////// + + template + struct Destroyer {}; + + template + struct Destroyer { + void nodes(M& m) const noexcept { + m.mNumElements = 0; + } + + void nodesDoNotDeallocate(M& m) const noexcept { + m.mNumElements = 0; + } + }; + + template + struct Destroyer { + void nodes(M& m) const noexcept { + m.mNumElements = 0; + // clear also resets mInfo to 0, that's sometimes not necessary. + auto const numElementsWithBuffer = m.calcNumElementsWithBuffer(m.mMask + 1); + + for (size_t idx = 0; idx < numElementsWithBuffer; ++idx) { + if (0 != m.mInfo[idx]) { + Node& n = m.mKeyVals[idx]; + n.destroy(m); + n.~Node(); + } + } + } + + void nodesDoNotDeallocate(M& m) const noexcept { + m.mNumElements = 0; + // clear also resets mInfo to 0, that's sometimes not necessary. + auto const numElementsWithBuffer = m.calcNumElementsWithBuffer(m.mMask + 1); + for (size_t idx = 0; idx < numElementsWithBuffer; ++idx) { + if (0 != m.mInfo[idx]) { + Node& n = m.mKeyVals[idx]; + n.destroyDoNotDeallocate(); + n.~Node(); + } + } + } + }; + + // Iter //////////////////////////////////////////////////////////// + + struct fast_forward_tag {}; + + // generic iterator for both const_iterator and iterator. + template + // NOLINTNEXTLINE(hicpp-special-member-functions,cppcoreguidelines-special-member-functions) + class Iter { + private: + using NodePtr = typename std::conditional::type; + + public: + using difference_type = std::ptrdiff_t; + using value_type = typename Self::value_type; + using reference = typename std::conditional::type; + using pointer = typename std::conditional::type; + using iterator_category = std::forward_iterator_tag; + + // default constructed iterator can be compared to itself, but WON'T return true when + // compared to end(). + Iter() = default; + + // Rule of zero: nothing specified. The conversion constructor is only enabled for + // iterator to const_iterator, so it doesn't accidentally work as a copy ctor. + + // Conversion constructor from iterator to const_iterator. + template ::type> + // NOLINTNEXTLINE(hicpp-explicit-conversions) + Iter(Iter const& other) noexcept + : mKeyVals(other.mKeyVals) + , mInfo(other.mInfo) {} + + Iter(NodePtr valPtr, uint8_t const* infoPtr) noexcept + : mKeyVals(valPtr) + , mInfo(infoPtr) {} + + Iter(NodePtr valPtr, uint8_t const* infoPtr, + fast_forward_tag ROBIN_HOOD_UNUSED(tag) /*unused*/) noexcept + : mKeyVals(valPtr) + , mInfo(infoPtr) { + fastForward(); + } + + template ::type> + Iter& operator=(Iter const& other) noexcept { + mKeyVals = other.mKeyVals; + mInfo = other.mInfo; + return *this; + } + + // prefix increment. Undefined behavior if we are at end()! + Iter& operator++() noexcept { + mInfo++; + mKeyVals++; + fastForward(); + return *this; + } + + Iter operator++(int) noexcept { + Iter tmp = *this; + ++(*this); + return tmp; + } + + reference operator*() const { + return **mKeyVals; + } + + pointer operator->() const { + return &**mKeyVals; + } + + template + bool operator==(Iter const& o) const noexcept { + return mKeyVals == o.mKeyVals; + } + + template + bool operator!=(Iter const& o) const noexcept { + return mKeyVals != o.mKeyVals; + } + + private: + // fast forward to the next non-free info byte + // I've tried a few variants that don't depend on intrinsics, but unfortunately they are + // quite a bit slower than this one. So I've reverted that change again. See map_benchmark. + void fastForward() noexcept { + size_t n = 0; + while (0U == (n = detail::unaligned_load(mInfo))) { + mInfo += sizeof(size_t); + mKeyVals += sizeof(size_t); + } +#if defined(ROBIN_HOOD_DISABLE_INTRINSICS) + // we know for certain that within the next 8 bytes we'll find a non-zero one. + if (ROBIN_HOOD_UNLIKELY(0U == detail::unaligned_load(mInfo))) { + mInfo += 4; + mKeyVals += 4; + } + if (ROBIN_HOOD_UNLIKELY(0U == detail::unaligned_load(mInfo))) { + mInfo += 2; + mKeyVals += 2; + } + if (ROBIN_HOOD_UNLIKELY(0U == *mInfo)) { + mInfo += 1; + mKeyVals += 1; + } +#else +# if ROBIN_HOOD(LITTLE_ENDIAN) + auto inc = ROBIN_HOOD_COUNT_TRAILING_ZEROES(n) / 8; +# else + auto inc = ROBIN_HOOD_COUNT_LEADING_ZEROES(n) / 8; +# endif + mInfo += inc; + mKeyVals += inc; +#endif + } + + friend class Table; + NodePtr mKeyVals{nullptr}; + uint8_t const* mInfo{nullptr}; + }; + + //////////////////////////////////////////////////////////////////// + + // highly performance relevant code. + // Lower bits are used for indexing into the array (2^n size) + // The upper 1-5 bits need to be a reasonable good hash, to save comparisons. + template + void keyToIdx(HashKey&& key, size_t* idx, InfoType* info) const { + // In addition to whatever hash is used, add another mul & shift so we get better hashing. + // This serves as a bad hash prevention, if the given data is + // badly mixed. + auto h = static_cast(WHash::operator()(key)); + + h *= mHashMultiplier; + h ^= h >> 33U; + + // the lower InitialInfoNumBits are reserved for info. + *info = mInfoInc + static_cast((h & InfoMask) >> mInfoHashShift); + *idx = (static_cast(h) >> InitialInfoNumBits) & mMask; + } + + // forwards the index by one, wrapping around at the end + void next(InfoType* info, size_t* idx) const noexcept { + *idx = *idx + 1; + *info += mInfoInc; + } + + void nextWhileLess(InfoType* info, size_t* idx) const noexcept { + // unrolling this by hand did not bring any speedups. + while (*info < mInfo[*idx]) { + next(info, idx); + } + } + + // Shift everything up by one element. Tries to move stuff around. + void + shiftUp(size_t startIdx, + size_t const insertion_idx) noexcept(std::is_nothrow_move_assignable::value) { + auto idx = startIdx; + ::new (static_cast(mKeyVals + idx)) Node(std::move(mKeyVals[idx - 1])); + while (--idx != insertion_idx) { + mKeyVals[idx] = std::move(mKeyVals[idx - 1]); + } + + idx = startIdx; + while (idx != insertion_idx) { + ROBIN_HOOD_COUNT(shiftUp) + mInfo[idx] = static_cast(mInfo[idx - 1] + mInfoInc); + if (ROBIN_HOOD_UNLIKELY(mInfo[idx] + mInfoInc > 0xFF)) { + mMaxNumElementsAllowed = 0; + } + --idx; + } + } + + void shiftDown(size_t idx) noexcept(std::is_nothrow_move_assignable::value) { + // until we find one that is either empty or has zero offset. + // TODO(martinus) we don't need to move everything, just the last one for the same + // bucket. + mKeyVals[idx].destroy(*this); + + // until we find one that is either empty or has zero offset. + while (mInfo[idx + 1] >= 2 * mInfoInc) { + ROBIN_HOOD_COUNT(shiftDown) + mInfo[idx] = static_cast(mInfo[idx + 1] - mInfoInc); + mKeyVals[idx] = std::move(mKeyVals[idx + 1]); + ++idx; + } + + mInfo[idx] = 0; + // don't destroy, we've moved it + // mKeyVals[idx].destroy(*this); + mKeyVals[idx].~Node(); + } + + // copy of find(), except that it returns iterator instead of const_iterator. + template + ROBIN_HOOD(NODISCARD) + size_t findIdx(Other const& key) const { + size_t idx{}; + InfoType info{}; + keyToIdx(key, &idx, &info); + + do { + // unrolling this twice gives a bit of a speedup. More unrolling did not help. + if (info == mInfo[idx] && + ROBIN_HOOD_LIKELY(WKeyEqual::operator()(key, mKeyVals[idx].getFirst()))) { + return idx; + } + next(&info, &idx); + if (info == mInfo[idx] && + ROBIN_HOOD_LIKELY(WKeyEqual::operator()(key, mKeyVals[idx].getFirst()))) { + return idx; + } + next(&info, &idx); + } while (info <= mInfo[idx]); + + // nothing found! + return mMask == 0 ? 0 + : static_cast(std::distance( + mKeyVals, reinterpret_cast_no_cast_align_warning(mInfo))); + } + + void cloneData(const Table& o) { + Cloner()(o, *this); + } + + // inserts a keyval that is guaranteed to be new, e.g. when the hashmap is resized. + // @return True on success, false if something went wrong + void insert_move(Node&& keyval) { + // we don't retry, fail if overflowing + // don't need to check max num elements + if (0 == mMaxNumElementsAllowed && !try_increase_info()) { + throwOverflowError(); + } + + size_t idx{}; + InfoType info{}; + keyToIdx(keyval.getFirst(), &idx, &info); + + // skip forward. Use <= because we are certain that the element is not there. + while (info <= mInfo[idx]) { + idx = idx + 1; + info += mInfoInc; + } + + // key not found, so we are now exactly where we want to insert it. + auto const insertion_idx = idx; + auto const insertion_info = static_cast(info); + if (ROBIN_HOOD_UNLIKELY(insertion_info + mInfoInc > 0xFF)) { + mMaxNumElementsAllowed = 0; + } + + // find an empty spot + while (0 != mInfo[idx]) { + next(&info, &idx); + } + + auto& l = mKeyVals[insertion_idx]; + if (idx == insertion_idx) { + ::new (static_cast(&l)) Node(std::move(keyval)); + } else { + shiftUp(idx, insertion_idx); + l = std::move(keyval); + } + + // put at empty spot + mInfo[insertion_idx] = insertion_info; + + ++mNumElements; + } + +public: + using iterator = Iter; + using const_iterator = Iter; + + Table() noexcept(noexcept(Hash()) && noexcept(KeyEqual())) + : WHash() + , WKeyEqual() { + ROBIN_HOOD_TRACE(this) + } + + // Creates an empty hash map. Nothing is allocated yet, this happens at the first insert. + // This tremendously speeds up ctor & dtor of a map that never receives an element. The + // penalty is payed at the first insert, and not before. Lookup of this empty map works + // because everybody points to DummyInfoByte::b. parameter bucket_count is dictated by the + // standard, but we can ignore it. + explicit Table( + size_t ROBIN_HOOD_UNUSED(bucket_count) /*unused*/, const Hash& h = Hash{}, + const KeyEqual& equal = KeyEqual{}) noexcept(noexcept(Hash(h)) && noexcept(KeyEqual(equal))) + : WHash(h) + , WKeyEqual(equal) { + ROBIN_HOOD_TRACE(this) + } + + template + Table(Iter first, Iter last, size_t ROBIN_HOOD_UNUSED(bucket_count) /*unused*/ = 0, + const Hash& h = Hash{}, const KeyEqual& equal = KeyEqual{}) + : WHash(h) + , WKeyEqual(equal) { + ROBIN_HOOD_TRACE(this) + insert(first, last); + } + + Table(std::initializer_list initlist, + size_t ROBIN_HOOD_UNUSED(bucket_count) /*unused*/ = 0, const Hash& h = Hash{}, + const KeyEqual& equal = KeyEqual{}) + : WHash(h) + , WKeyEqual(equal) { + ROBIN_HOOD_TRACE(this) + insert(initlist.begin(), initlist.end()); + } + + Table(Table&& o) noexcept + : WHash(std::move(static_cast(o))) + , WKeyEqual(std::move(static_cast(o))) + , DataPool(std::move(static_cast(o))) { + ROBIN_HOOD_TRACE(this) + if (o.mMask) { + mHashMultiplier = std::move(o.mHashMultiplier); + mKeyVals = std::move(o.mKeyVals); + mInfo = std::move(o.mInfo); + mNumElements = std::move(o.mNumElements); + mMask = std::move(o.mMask); + mMaxNumElementsAllowed = std::move(o.mMaxNumElementsAllowed); + mInfoInc = std::move(o.mInfoInc); + mInfoHashShift = std::move(o.mInfoHashShift); + // set other's mask to 0 so its destructor won't do anything + o.init(); + } + } + + Table& operator=(Table&& o) noexcept { + ROBIN_HOOD_TRACE(this) + if (&o != this) { + if (o.mMask) { + // only move stuff if the other map actually has some data + destroy(); + mHashMultiplier = std::move(o.mHashMultiplier); + mKeyVals = std::move(o.mKeyVals); + mInfo = std::move(o.mInfo); + mNumElements = std::move(o.mNumElements); + mMask = std::move(o.mMask); + mMaxNumElementsAllowed = std::move(o.mMaxNumElementsAllowed); + mInfoInc = std::move(o.mInfoInc); + mInfoHashShift = std::move(o.mInfoHashShift); + WHash::operator=(std::move(static_cast(o))); + WKeyEqual::operator=(std::move(static_cast(o))); + DataPool::operator=(std::move(static_cast(o))); + + o.init(); + + } else { + // nothing in the other map => just clear us. + clear(); + } + } + return *this; + } + + Table(const Table& o) + : WHash(static_cast(o)) + , WKeyEqual(static_cast(o)) + , DataPool(static_cast(o)) { + ROBIN_HOOD_TRACE(this) + if (!o.empty()) { + // not empty: create an exact copy. it is also possible to just iterate through all + // elements and insert them, but copying is probably faster. + + auto const numElementsWithBuffer = calcNumElementsWithBuffer(o.mMask + 1); + auto const numBytesTotal = calcNumBytesTotal(numElementsWithBuffer); + + ROBIN_HOOD_LOG("std::malloc " << numBytesTotal << " = calcNumBytesTotal(" + << numElementsWithBuffer << ")") + mHashMultiplier = o.mHashMultiplier; + mKeyVals = static_cast( + detail::assertNotNull(std::malloc(numBytesTotal))); + // no need for calloc because clonData does memcpy + mInfo = reinterpret_cast(mKeyVals + numElementsWithBuffer); + mNumElements = o.mNumElements; + mMask = o.mMask; + mMaxNumElementsAllowed = o.mMaxNumElementsAllowed; + mInfoInc = o.mInfoInc; + mInfoHashShift = o.mInfoHashShift; + cloneData(o); + } + } + + // Creates a copy of the given map. Copy constructor of each entry is used. + // Not sure why clang-tidy thinks this doesn't handle self assignment, it does + // NOLINTNEXTLINE(bugprone-unhandled-self-assignment,cert-oop54-cpp) + Table& operator=(Table const& o) { + ROBIN_HOOD_TRACE(this) + if (&o == this) { + // prevent assigning of itself + return *this; + } + + // we keep using the old allocator and not assign the new one, because we want to keep + // the memory available. when it is the same size. + if (o.empty()) { + if (0 == mMask) { + // nothing to do, we are empty too + return *this; + } + + // not empty: destroy what we have there + // clear also resets mInfo to 0, that's sometimes not necessary. + destroy(); + init(); + WHash::operator=(static_cast(o)); + WKeyEqual::operator=(static_cast(o)); + DataPool::operator=(static_cast(o)); + + return *this; + } + + // clean up old stuff + Destroyer::value>{}.nodes(*this); + + if (mMask != o.mMask) { + // no luck: we don't have the same array size allocated, so we need to realloc. + if (0 != mMask) { + // only deallocate if we actually have data! + ROBIN_HOOD_LOG("std::free") + std::free(mKeyVals); + } + + auto const numElementsWithBuffer = calcNumElementsWithBuffer(o.mMask + 1); + auto const numBytesTotal = calcNumBytesTotal(numElementsWithBuffer); + ROBIN_HOOD_LOG("std::malloc " << numBytesTotal << " = calcNumBytesTotal(" + << numElementsWithBuffer << ")") + mKeyVals = static_cast( + detail::assertNotNull(std::malloc(numBytesTotal))); + + // no need for calloc here because cloneData performs a memcpy. + mInfo = reinterpret_cast(mKeyVals + numElementsWithBuffer); + // sentinel is set in cloneData + } + WHash::operator=(static_cast(o)); + WKeyEqual::operator=(static_cast(o)); + DataPool::operator=(static_cast(o)); + mHashMultiplier = o.mHashMultiplier; + mNumElements = o.mNumElements; + mMask = o.mMask; + mMaxNumElementsAllowed = o.mMaxNumElementsAllowed; + mInfoInc = o.mInfoInc; + mInfoHashShift = o.mInfoHashShift; + cloneData(o); + + return *this; + } + + // Swaps everything between the two maps. + void swap(Table& o) { + ROBIN_HOOD_TRACE(this) + using std::swap; + swap(o, *this); + } + + // Clears all data, without resizing. + void clear() { + ROBIN_HOOD_TRACE(this) + if (empty()) { + // don't do anything! also important because we don't want to write to + // DummyInfoByte::b, even though we would just write 0 to it. + return; + } + + Destroyer::value>{}.nodes(*this); + + auto const numElementsWithBuffer = calcNumElementsWithBuffer(mMask + 1); + // clear everything, then set the sentinel again + uint8_t const z = 0; + std::fill(mInfo, mInfo + calcNumBytesInfo(numElementsWithBuffer), z); + mInfo[numElementsWithBuffer] = 1; + + mInfoInc = InitialInfoInc; + mInfoHashShift = InitialInfoHashShift; + } + + // Destroys the map and all it's contents. + ~Table() { + ROBIN_HOOD_TRACE(this) + destroy(); + } + + // Checks if both tables contain the same entries. Order is irrelevant. + bool operator==(const Table& other) const { + ROBIN_HOOD_TRACE(this) + if (other.size() != size()) { + return false; + } + for (auto const& otherEntry : other) { + if (!has(otherEntry)) { + return false; + } + } + + return true; + } + + bool operator!=(const Table& other) const { + ROBIN_HOOD_TRACE(this) + return !operator==(other); + } + + template + typename std::enable_if::value, Q&>::type operator[](const key_type& key) { + ROBIN_HOOD_TRACE(this) + auto idxAndState = insertKeyPrepareEmptySpot(key); + switch (idxAndState.second) { + case InsertionState::key_found: + break; + + case InsertionState::new_node: + ::new (static_cast(&mKeyVals[idxAndState.first])) + Node(*this, std::piecewise_construct, std::forward_as_tuple(key), + std::forward_as_tuple()); + break; + + case InsertionState::overwrite_node: + mKeyVals[idxAndState.first] = Node(*this, std::piecewise_construct, + std::forward_as_tuple(key), std::forward_as_tuple()); + break; + + case InsertionState::overflow_error: + throwOverflowError(); + } + + return mKeyVals[idxAndState.first].getSecond(); + } + + template + typename std::enable_if::value, Q&>::type operator[](key_type&& key) { + ROBIN_HOOD_TRACE(this) + auto idxAndState = insertKeyPrepareEmptySpot(key); + switch (idxAndState.second) { + case InsertionState::key_found: + break; + + case InsertionState::new_node: + ::new (static_cast(&mKeyVals[idxAndState.first])) + Node(*this, std::piecewise_construct, std::forward_as_tuple(std::move(key)), + std::forward_as_tuple()); + break; + + case InsertionState::overwrite_node: + mKeyVals[idxAndState.first] = + Node(*this, std::piecewise_construct, std::forward_as_tuple(std::move(key)), + std::forward_as_tuple()); + break; + + case InsertionState::overflow_error: + throwOverflowError(); + } + + return mKeyVals[idxAndState.first].getSecond(); + } + + template + void insert(Iter first, Iter last) { + for (; first != last; ++first) { + // value_type ctor needed because this might be called with std::pair's + insert(value_type(*first)); + } + } + + void insert(std::initializer_list ilist) { + for (auto&& vt : ilist) { + insert(std::move(vt)); + } + } + + template + std::pair emplace(Args&&... args) { + ROBIN_HOOD_TRACE(this) + Node n{*this, std::forward(args)...}; + auto idxAndState = insertKeyPrepareEmptySpot(getFirstConst(n)); + switch (idxAndState.second) { + case InsertionState::key_found: + n.destroy(*this); + break; + + case InsertionState::new_node: + ::new (static_cast(&mKeyVals[idxAndState.first])) Node(*this, std::move(n)); + break; + + case InsertionState::overwrite_node: + mKeyVals[idxAndState.first] = std::move(n); + break; + + case InsertionState::overflow_error: + n.destroy(*this); + throwOverflowError(); + break; + } + + return std::make_pair(iterator(mKeyVals + idxAndState.first, mInfo + idxAndState.first), + InsertionState::key_found != idxAndState.second); + } + + template + iterator emplace_hint(const_iterator position, Args&&... args) { + (void)position; + return emplace(std::forward(args)...).first; + } + + template + std::pair try_emplace(const key_type& key, Args&&... args) { + return try_emplace_impl(key, std::forward(args)...); + } + + template + std::pair try_emplace(key_type&& key, Args&&... args) { + return try_emplace_impl(std::move(key), std::forward(args)...); + } + + template + iterator try_emplace(const_iterator hint, const key_type& key, Args&&... args) { + (void)hint; + return try_emplace_impl(key, std::forward(args)...).first; + } + + template + iterator try_emplace(const_iterator hint, key_type&& key, Args&&... args) { + (void)hint; + return try_emplace_impl(std::move(key), std::forward(args)...).first; + } + + template + std::pair insert_or_assign(const key_type& key, Mapped&& obj) { + return insertOrAssignImpl(key, std::forward(obj)); + } + + template + std::pair insert_or_assign(key_type&& key, Mapped&& obj) { + return insertOrAssignImpl(std::move(key), std::forward(obj)); + } + + template + iterator insert_or_assign(const_iterator hint, const key_type& key, Mapped&& obj) { + (void)hint; + return insertOrAssignImpl(key, std::forward(obj)).first; + } + + template + iterator insert_or_assign(const_iterator hint, key_type&& key, Mapped&& obj) { + (void)hint; + return insertOrAssignImpl(std::move(key), std::forward(obj)).first; + } + + std::pair insert(const value_type& keyval) { + ROBIN_HOOD_TRACE(this) + return emplace(keyval); + } + + iterator insert(const_iterator hint, const value_type& keyval) { + (void)hint; + return emplace(keyval).first; + } + + std::pair insert(value_type&& keyval) { + return emplace(std::move(keyval)); + } + + iterator insert(const_iterator hint, value_type&& keyval) { + (void)hint; + return emplace(std::move(keyval)).first; + } + + // Returns 1 if key is found, 0 otherwise. + size_t count(const key_type& key) const { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + auto kv = mKeyVals + findIdx(key); + if (kv != reinterpret_cast_no_cast_align_warning(mInfo)) { + return 1; + } + return 0; + } + + template + // NOLINTNEXTLINE(modernize-use-nodiscard) + typename std::enable_if::type count(const OtherKey& key) const { + ROBIN_HOOD_TRACE(this) + auto kv = mKeyVals + findIdx(key); + if (kv != reinterpret_cast_no_cast_align_warning(mInfo)) { + return 1; + } + return 0; + } + + bool contains(const key_type& key) const { // NOLINT(modernize-use-nodiscard) + return 1U == count(key); + } + + template + // NOLINTNEXTLINE(modernize-use-nodiscard) + typename std::enable_if::type contains(const OtherKey& key) const { + return 1U == count(key); + } + + // Returns a reference to the value found for key. + // Throws std::out_of_range if element cannot be found + template + // NOLINTNEXTLINE(modernize-use-nodiscard) + typename std::enable_if::value, Q&>::type at(key_type const& key) { + ROBIN_HOOD_TRACE(this) + auto kv = mKeyVals + findIdx(key); + if (kv == reinterpret_cast_no_cast_align_warning(mInfo)) { + doThrow("key not found"); + } + return kv->getSecond(); + } + + // Returns a reference to the value found for key. + // Throws std::out_of_range if element cannot be found + template + // NOLINTNEXTLINE(modernize-use-nodiscard) + typename std::enable_if::value, Q const&>::type at(key_type const& key) const { + ROBIN_HOOD_TRACE(this) + auto kv = mKeyVals + findIdx(key); + if (kv == reinterpret_cast_no_cast_align_warning(mInfo)) { + doThrow("key not found"); + } + return kv->getSecond(); + } + + const_iterator find(const key_type& key) const { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + const size_t idx = findIdx(key); + return const_iterator{mKeyVals + idx, mInfo + idx}; + } + + template + const_iterator find(const OtherKey& key, is_transparent_tag /*unused*/) const { + ROBIN_HOOD_TRACE(this) + const size_t idx = findIdx(key); + return const_iterator{mKeyVals + idx, mInfo + idx}; + } + + template + typename std::enable_if::type // NOLINT(modernize-use-nodiscard) + find(const OtherKey& key) const { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + const size_t idx = findIdx(key); + return const_iterator{mKeyVals + idx, mInfo + idx}; + } + + iterator find(const key_type& key) { + ROBIN_HOOD_TRACE(this) + const size_t idx = findIdx(key); + return iterator{mKeyVals + idx, mInfo + idx}; + } + + template + iterator find(const OtherKey& key, is_transparent_tag /*unused*/) { + ROBIN_HOOD_TRACE(this) + const size_t idx = findIdx(key); + return iterator{mKeyVals + idx, mInfo + idx}; + } + + template + typename std::enable_if::type find(const OtherKey& key) { + ROBIN_HOOD_TRACE(this) + const size_t idx = findIdx(key); + return iterator{mKeyVals + idx, mInfo + idx}; + } + + iterator begin() { + ROBIN_HOOD_TRACE(this) + if (empty()) { + return end(); + } + return iterator(mKeyVals, mInfo, fast_forward_tag{}); + } + const_iterator begin() const { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + return cbegin(); + } + const_iterator cbegin() const { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + if (empty()) { + return cend(); + } + return const_iterator(mKeyVals, mInfo, fast_forward_tag{}); + } + + iterator end() { + ROBIN_HOOD_TRACE(this) + // no need to supply valid info pointer: end() must not be dereferenced, and only node + // pointer is compared. + return iterator{reinterpret_cast_no_cast_align_warning(mInfo), nullptr}; + } + const_iterator end() const { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + return cend(); + } + const_iterator cend() const { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + return const_iterator{reinterpret_cast_no_cast_align_warning(mInfo), nullptr}; + } + + iterator erase(const_iterator pos) { + ROBIN_HOOD_TRACE(this) + // its safe to perform const cast here + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + return erase(iterator{const_cast(pos.mKeyVals), const_cast(pos.mInfo)}); + } + + // Erases element at pos, returns iterator to the next element. + iterator erase(iterator pos) { + ROBIN_HOOD_TRACE(this) + // we assume that pos always points to a valid entry, and not end(). + auto const idx = static_cast(pos.mKeyVals - mKeyVals); + + shiftDown(idx); + --mNumElements; + + if (*pos.mInfo) { + // we've backward shifted, return this again + return pos; + } + + // no backward shift, return next element + return ++pos; + } + + size_t erase(const key_type& key) { + ROBIN_HOOD_TRACE(this) + size_t idx{}; + InfoType info{}; + keyToIdx(key, &idx, &info); + + // check while info matches with the source idx + do { + if (info == mInfo[idx] && WKeyEqual::operator()(key, mKeyVals[idx].getFirst())) { + shiftDown(idx); + --mNumElements; + return 1; + } + next(&info, &idx); + } while (info <= mInfo[idx]); + + // nothing found to delete + return 0; + } + + // reserves space for the specified number of elements. Makes sure the old data fits. + // exactly the same as reserve(c). + void rehash(size_t c) { + // forces a reserve + reserve(c, true); + } + + // reserves space for the specified number of elements. Makes sure the old data fits. + // Exactly the same as rehash(c). Use rehash(0) to shrink to fit. + void reserve(size_t c) { + // reserve, but don't force rehash + reserve(c, false); + } + + // If possible reallocates the map to a smaller one. This frees the underlying table. + // Does not do anything if load_factor is too large for decreasing the table's size. + void compact() { + ROBIN_HOOD_TRACE(this) + auto newSize = InitialNumElements; + while (calcMaxNumElementsAllowed(newSize) < mNumElements && newSize != 0) { + newSize *= 2; + } + if (ROBIN_HOOD_UNLIKELY(newSize == 0)) { + throwOverflowError(); + } + + ROBIN_HOOD_LOG("newSize > mMask + 1: " << newSize << " > " << mMask << " + 1") + + // only actually do anything when the new size is bigger than the old one. This prevents to + // continuously allocate for each reserve() call. + if (newSize < mMask + 1) { + rehashPowerOfTwo(newSize, true); + } + } + + size_type size() const noexcept { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + return mNumElements; + } + + size_type max_size() const noexcept { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + return static_cast(-1); + } + + ROBIN_HOOD(NODISCARD) bool empty() const noexcept { + ROBIN_HOOD_TRACE(this) + return 0 == mNumElements; + } + + float max_load_factor() const noexcept { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + return MaxLoadFactor100 / 100.0F; + } + + // Average number of elements per bucket. Since we allow only 1 per bucket + float load_factor() const noexcept { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + return static_cast(size()) / static_cast(mMask + 1); + } + + ROBIN_HOOD(NODISCARD) size_t mask() const noexcept { + ROBIN_HOOD_TRACE(this) + return mMask; + } + + ROBIN_HOOD(NODISCARD) size_t calcMaxNumElementsAllowed(size_t maxElements) const noexcept { + if (ROBIN_HOOD_LIKELY(maxElements <= (std::numeric_limits::max)() / 100)) { + return maxElements * MaxLoadFactor100 / 100; + } + + // we might be a bit inprecise, but since maxElements is quite large that doesn't matter + return (maxElements / 100) * MaxLoadFactor100; + } + + ROBIN_HOOD(NODISCARD) size_t calcNumBytesInfo(size_t numElements) const noexcept { + // we add a uint64_t, which houses the sentinel (first byte) and padding so we can load + // 64bit types. + return numElements + sizeof(uint64_t); + } + + ROBIN_HOOD(NODISCARD) + size_t calcNumElementsWithBuffer(size_t numElements) const noexcept { + auto maxNumElementsAllowed = calcMaxNumElementsAllowed(numElements); + return numElements + (std::min)(maxNumElementsAllowed, (static_cast(0xFF))); + } + + // calculation only allowed for 2^n values + ROBIN_HOOD(NODISCARD) size_t calcNumBytesTotal(size_t numElements) const { +#if ROBIN_HOOD(BITNESS) == 64 + return numElements * sizeof(Node) + calcNumBytesInfo(numElements); +#else + // make sure we're doing 64bit operations, so we are at least safe against 32bit overflows. + auto const ne = static_cast(numElements); + auto const s = static_cast(sizeof(Node)); + auto const infos = static_cast(calcNumBytesInfo(numElements)); + + auto const total64 = ne * s + infos; + auto const total = static_cast(total64); + + if (ROBIN_HOOD_UNLIKELY(static_cast(total) != total64)) { + throwOverflowError(); + } + return total; +#endif + } + +private: + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::value, bool>::type has(const value_type& e) const { + ROBIN_HOOD_TRACE(this) + auto it = find(e.first); + return it != end() && it->second == e.second; + } + + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::value, bool>::type has(const value_type& e) const { + ROBIN_HOOD_TRACE(this) + return find(e) != end(); + } + + void reserve(size_t c, bool forceRehash) { + ROBIN_HOOD_TRACE(this) + auto const minElementsAllowed = (std::max)(c, mNumElements); + auto newSize = InitialNumElements; + while (calcMaxNumElementsAllowed(newSize) < minElementsAllowed && newSize != 0) { + newSize *= 2; + } + if (ROBIN_HOOD_UNLIKELY(newSize == 0)) { + throwOverflowError(); + } + + ROBIN_HOOD_LOG("newSize > mMask + 1: " << newSize << " > " << mMask << " + 1") + + // only actually do anything when the new size is bigger than the old one. This prevents to + // continuously allocate for each reserve() call. + if (forceRehash || newSize > mMask + 1) { + rehashPowerOfTwo(newSize, false); + } + } + + // reserves space for at least the specified number of elements. + // only works if numBuckets if power of two + // True on success, false otherwise + void rehashPowerOfTwo(size_t numBuckets, bool forceFree) { + ROBIN_HOOD_TRACE(this) + + Node* const oldKeyVals = mKeyVals; + uint8_t const* const oldInfo = mInfo; + + const size_t oldMaxElementsWithBuffer = calcNumElementsWithBuffer(mMask + 1); + + // resize operation: move stuff + initData(numBuckets); + if (oldMaxElementsWithBuffer > 1) { + for (size_t i = 0; i < oldMaxElementsWithBuffer; ++i) { + if (oldInfo[i] != 0) { + // might throw an exception, which is really bad since we are in the middle of + // moving stuff. + insert_move(std::move(oldKeyVals[i])); + // destroy the node but DON'T destroy the data. + oldKeyVals[i].~Node(); + } + } + + // this check is not necessary as it's guarded by the previous if, but it helps + // silence g++'s overeager "attempt to free a non-heap object 'map' + // [-Werror=free-nonheap-object]" warning. + if (oldKeyVals != reinterpret_cast_no_cast_align_warning(&mMask)) { + // don't destroy old data: put it into the pool instead + if (forceFree) { + std::free(oldKeyVals); + } else { + DataPool::addOrFree(oldKeyVals, calcNumBytesTotal(oldMaxElementsWithBuffer)); + } + } + } + } + + ROBIN_HOOD(NOINLINE) void throwOverflowError() const { +#if ROBIN_HOOD(HAS_EXCEPTIONS) + throw std::overflow_error("robin_hood::map overflow"); +#else + abort(); +#endif + } + + template + std::pair try_emplace_impl(OtherKey&& key, Args&&... args) { + ROBIN_HOOD_TRACE(this) + auto idxAndState = insertKeyPrepareEmptySpot(key); + switch (idxAndState.second) { + case InsertionState::key_found: + break; + + case InsertionState::new_node: + ::new (static_cast(&mKeyVals[idxAndState.first])) Node( + *this, std::piecewise_construct, std::forward_as_tuple(std::forward(key)), + std::forward_as_tuple(std::forward(args)...)); + break; + + case InsertionState::overwrite_node: + mKeyVals[idxAndState.first] = Node(*this, std::piecewise_construct, + std::forward_as_tuple(std::forward(key)), + std::forward_as_tuple(std::forward(args)...)); + break; + + case InsertionState::overflow_error: + throwOverflowError(); + break; + } + + return std::make_pair(iterator(mKeyVals + idxAndState.first, mInfo + idxAndState.first), + InsertionState::key_found != idxAndState.second); + } + + template + std::pair insertOrAssignImpl(OtherKey&& key, Mapped&& obj) { + ROBIN_HOOD_TRACE(this) + auto idxAndState = insertKeyPrepareEmptySpot(key); + switch (idxAndState.second) { + case InsertionState::key_found: + mKeyVals[idxAndState.first].getSecond() = std::forward(obj); + break; + + case InsertionState::new_node: + ::new (static_cast(&mKeyVals[idxAndState.first])) Node( + *this, std::piecewise_construct, std::forward_as_tuple(std::forward(key)), + std::forward_as_tuple(std::forward(obj))); + break; + + case InsertionState::overwrite_node: + mKeyVals[idxAndState.first] = Node(*this, std::piecewise_construct, + std::forward_as_tuple(std::forward(key)), + std::forward_as_tuple(std::forward(obj))); + break; + + case InsertionState::overflow_error: + throwOverflowError(); + break; + } + + return std::make_pair(iterator(mKeyVals + idxAndState.first, mInfo + idxAndState.first), + InsertionState::key_found != idxAndState.second); + } + + void initData(size_t max_elements) { + mNumElements = 0; + mMask = max_elements - 1; + mMaxNumElementsAllowed = calcMaxNumElementsAllowed(max_elements); + + auto const numElementsWithBuffer = calcNumElementsWithBuffer(max_elements); + + // malloc & zero mInfo. Faster than calloc everything. + auto const numBytesTotal = calcNumBytesTotal(numElementsWithBuffer); + ROBIN_HOOD_LOG("std::calloc " << numBytesTotal << " = calcNumBytesTotal(" + << numElementsWithBuffer << ")") + mKeyVals = reinterpret_cast( + detail::assertNotNull(std::malloc(numBytesTotal))); + mInfo = reinterpret_cast(mKeyVals + numElementsWithBuffer); + std::memset(mInfo, 0, numBytesTotal - numElementsWithBuffer * sizeof(Node)); + + // set sentinel + mInfo[numElementsWithBuffer] = 1; + + mInfoInc = InitialInfoInc; + mInfoHashShift = InitialInfoHashShift; + } + + enum class InsertionState { overflow_error, key_found, new_node, overwrite_node }; + + // Finds key, and if not already present prepares a spot where to pot the key & value. + // This potentially shifts nodes out of the way, updates mInfo and number of inserted + // elements, so the only operation left to do is create/assign a new node at that spot. + template + std::pair insertKeyPrepareEmptySpot(OtherKey&& key) { + for (int i = 0; i < 256; ++i) { + size_t idx{}; + InfoType info{}; + keyToIdx(key, &idx, &info); + nextWhileLess(&info, &idx); + + // while we potentially have a match + while (info == mInfo[idx]) { + if (WKeyEqual::operator()(key, mKeyVals[idx].getFirst())) { + // key already exists, do NOT insert. + // see http://en.cppreference.com/w/cpp/container/unordered_map/insert + return std::make_pair(idx, InsertionState::key_found); + } + next(&info, &idx); + } + + // unlikely that this evaluates to true + if (ROBIN_HOOD_UNLIKELY(mNumElements >= mMaxNumElementsAllowed)) { + if (!increase_size()) { + return std::make_pair(size_t(0), InsertionState::overflow_error); + } + continue; + } + + // key not found, so we are now exactly where we want to insert it. + auto const insertion_idx = idx; + auto const insertion_info = info; + if (ROBIN_HOOD_UNLIKELY(insertion_info + mInfoInc > 0xFF)) { + mMaxNumElementsAllowed = 0; + } + + // find an empty spot + while (0 != mInfo[idx]) { + next(&info, &idx); + } + + if (idx != insertion_idx) { + shiftUp(idx, insertion_idx); + } + // put at empty spot + mInfo[insertion_idx] = static_cast(insertion_info); + ++mNumElements; + return std::make_pair(insertion_idx, idx == insertion_idx + ? InsertionState::new_node + : InsertionState::overwrite_node); + } + + // enough attempts failed, so finally give up. + return std::make_pair(size_t(0), InsertionState::overflow_error); + } + + bool try_increase_info() { + ROBIN_HOOD_LOG("mInfoInc=" << mInfoInc << ", numElements=" << mNumElements + << ", maxNumElementsAllowed=" + << calcMaxNumElementsAllowed(mMask + 1)) + if (mInfoInc <= 2) { + // need to be > 2 so that shift works (otherwise undefined behavior!) + return false; + } + // we got space left, try to make info smaller + mInfoInc = static_cast(mInfoInc >> 1U); + + // remove one bit of the hash, leaving more space for the distance info. + // This is extremely fast because we can operate on 8 bytes at once. + ++mInfoHashShift; + auto const numElementsWithBuffer = calcNumElementsWithBuffer(mMask + 1); + + for (size_t i = 0; i < numElementsWithBuffer; i += 8) { + auto val = unaligned_load(mInfo + i); + val = (val >> 1U) & UINT64_C(0x7f7f7f7f7f7f7f7f); + std::memcpy(mInfo + i, &val, sizeof(val)); + } + // update sentinel, which might have been cleared out! + mInfo[numElementsWithBuffer] = 1; + + mMaxNumElementsAllowed = calcMaxNumElementsAllowed(mMask + 1); + return true; + } + + // True if resize was possible, false otherwise + bool increase_size() { + // nothing allocated yet? just allocate InitialNumElements + if (0 == mMask) { + initData(InitialNumElements); + return true; + } + + auto const maxNumElementsAllowed = calcMaxNumElementsAllowed(mMask + 1); + if (mNumElements < maxNumElementsAllowed && try_increase_info()) { + return true; + } + + ROBIN_HOOD_LOG("mNumElements=" << mNumElements << ", maxNumElementsAllowed=" + << maxNumElementsAllowed << ", load=" + << (static_cast(mNumElements) * 100.0 / + (static_cast(mMask) + 1))) + + if (mNumElements * 2 < calcMaxNumElementsAllowed(mMask + 1)) { + // we have to resize, even though there would still be plenty of space left! + // Try to rehash instead. Delete freed memory so we don't steadyily increase mem in case + // we have to rehash a few times + nextHashMultiplier(); + rehashPowerOfTwo(mMask + 1, true); + } else { + // we've reached the capacity of the map, so the hash seems to work nice. Keep using it. + rehashPowerOfTwo((mMask + 1) * 2, false); + } + return true; + } + + void nextHashMultiplier() { + // adding an *even* number, so that the multiplier will always stay odd. This is necessary + // so that the hash stays a mixing function (and thus doesn't have any information loss). + mHashMultiplier += UINT64_C(0xc4ceb9fe1a85ec54); + } + + void destroy() { + if (0 == mMask) { + // don't deallocate! + return; + } + + Destroyer::value>{} + .nodesDoNotDeallocate(*this); + + // This protection against not deleting mMask shouldn't be needed as it's sufficiently + // protected with the 0==mMask check, but I have this anyways because g++ 7 otherwise + // reports a compile error: attempt to free a non-heap object 'fm' + // [-Werror=free-nonheap-object] + if (mKeyVals != reinterpret_cast_no_cast_align_warning(&mMask)) { + ROBIN_HOOD_LOG("std::free") + std::free(mKeyVals); + } + } + + void init() noexcept { + mKeyVals = reinterpret_cast_no_cast_align_warning(&mMask); + mInfo = reinterpret_cast(&mMask); + mNumElements = 0; + mMask = 0; + mMaxNumElementsAllowed = 0; + mInfoInc = InitialInfoInc; + mInfoHashShift = InitialInfoHashShift; + } + + // members are sorted so no padding occurs + uint64_t mHashMultiplier = UINT64_C(0xc4ceb9fe1a85ec53); // 8 byte 8 + Node* mKeyVals = reinterpret_cast_no_cast_align_warning(&mMask); // 8 byte 16 + uint8_t* mInfo = reinterpret_cast(&mMask); // 8 byte 24 + size_t mNumElements = 0; // 8 byte 32 + size_t mMask = 0; // 8 byte 40 + size_t mMaxNumElementsAllowed = 0; // 8 byte 48 + InfoType mInfoInc = InitialInfoInc; // 4 byte 52 + InfoType mInfoHashShift = InitialInfoHashShift; // 4 byte 56 + // 16 byte 56 if NodeAllocator +}; + +} // namespace detail + +// map + +template , + typename KeyEqual = std::equal_to, size_t MaxLoadFactor100 = 80> +using unordered_flat_map = detail::Table; + +template , + typename KeyEqual = std::equal_to, size_t MaxLoadFactor100 = 80> +using unordered_node_map = detail::Table; + +template , + typename KeyEqual = std::equal_to, size_t MaxLoadFactor100 = 80> +using unordered_map = + detail::Table) <= sizeof(size_t) * 6 && + std::is_nothrow_move_constructible>::value && + std::is_nothrow_move_assignable>::value, + MaxLoadFactor100, Key, T, Hash, KeyEqual>; + +// set + +template , typename KeyEqual = std::equal_to, + size_t MaxLoadFactor100 = 80> +using unordered_flat_set = detail::Table; + +template , typename KeyEqual = std::equal_to, + size_t MaxLoadFactor100 = 80> +using unordered_node_set = detail::Table; + +template , typename KeyEqual = std::equal_to, + size_t MaxLoadFactor100 = 80> +using unordered_set = detail::Table::value && + std::is_nothrow_move_assignable::value, + MaxLoadFactor100, Key, void, Hash, KeyEqual>; + +} // namespace robin_hood + +#endif diff --git a/src/backend/search_engine/index_builder/index_builder.cpp b/src/backend/search_engine/index_builder/index_builder.cpp index 5c45cbe..f15f567 100644 --- a/src/backend/search_engine/index_builder/index_builder.cpp +++ b/src/backend/search_engine/index_builder/index_builder.cpp @@ -1,780 +1,469 @@ -#include -#include -#include -#include -#include +#include #include +#include #include -#include -#include -#include +#include #include -#include -#include +#include #include -#include -#include -#include +#include #include -#include -#include "libstemmer.h" - -// NOTE: build and read on same architecture (endianness, size of types) - -namespace fs = std::filesystem; - -// constants for memory size estimation (64 bit system) -const size_t MAP_NODE_OVERHEAD = 32; -const size_t VECTOR_OVERHEAD = 24; - -struct DocInfo { - std::string url; - std::string title; -}; - -struct ParsedDoc { - uint32_t doc_id; - std::string url; - std::string title; - std::string body; -}; +#include +#include +#include "include/robin_hood.h" +// not encoded as neglectably small class DocStoreWriter { private: - std::ofstream out_stream; - std::ofstream offset_stream; - uint64_t current_byte_offset; // offset where next doc will be written/read - uint32_t doc_count; + std::ofstream outStream; + std::ofstream offsetStream; + uint64_t currentByteOffset; // offset where next doc will be written/read + uint32_t docCount; public: void init(const std::string& filename_base) { - out_stream.open(filename_base + ".docstore", std::ios::binary | std::ios::out | std::ios::trunc); - offset_stream.open(filename_base + ".docstore_offsets", std::ios::binary | std::ios::out | std::ios::trunc); + outStream.open(filename_base + "/docstore.bin", std::ios::binary | std::ios::out | std::ios::trunc); + offsetStream.open(filename_base + "/docstore_offsets.bin", std::ios::binary | std::ios::out | std::ios::trunc); - current_byte_offset = 0; - doc_count = 0; + currentByteOffset = 0; + docCount = 0; - out_stream.write(reinterpret_cast(&doc_count), sizeof(doc_count)); - current_byte_offset += sizeof(doc_count); + outStream.write(reinterpret_cast(&docCount), sizeof(docCount)); + currentByteOffset += sizeof(docCount); } - void add_document(uint32_t doc_id, const std::string& url, const std::string& title) { + void addDocument(uint32_t docId, const std::string& url, const std::string& title) { /* - offset_stream: - 0 (offset of doc_id 0) - 18 (offset of doc_id 1) + offsetStream: + [0-3] docId = 42 + [4-11] offset = 0 (start of this doc in outStream) + + [12-15] docId = 105 + [16-23] offset = 18 (start of this doc in outStream) ... - out_stream: - [0-3] url_len = 5 + outStream: + [0-3] urlLen = 5 [4-8] 'a' '.' 'c' 'o' 'm' - [9-12] title_len = 5 + [9-12] titleLen = 5 [13-17] 'H' 'e' 'l' 'l' 'o' - [18-21] url_len = 11 + [18-21] urlLen = 11 [22-32] 'e' 'x' 'a' 'm' 'p' 'l' 'e' '.' 'o' 'r' 'g' - [33-34] title_len = 2 + [33-34] titleLen = 2 [35-36] 'H' 'i' */ - offset_stream.write(reinterpret_cast(¤t_byte_offset), sizeof(current_byte_offset)); - - uint32_t url_len = url.size(); - out_stream.write(reinterpret_cast(&url_len), sizeof(url_len)); - out_stream.write(url.data(), url_len); + offsetStream.write(reinterpret_cast(&docId), sizeof(docId)); + offsetStream.write(reinterpret_cast(¤tByteOffset), sizeof(currentByteOffset)); + + uint32_t urlLen = url.size(); + outStream.write(reinterpret_cast(&urlLen), sizeof(urlLen)); + outStream.write(url.data(), urlLen); - uint32_t title_len = title.size(); - out_stream.write(reinterpret_cast(&title_len), sizeof(title_len)); - out_stream.write(title.data(), title_len); + uint32_t titleLen = title.size(); + outStream.write(reinterpret_cast(&titleLen), sizeof(titleLen)); + outStream.write(title.data(), titleLen); // faster than tellp() - current_byte_offset += sizeof(uint32_t) + url_len + sizeof(uint32_t) + title_len; + currentByteOffset += sizeof(uint32_t) + urlLen + sizeof(uint32_t) + titleLen; - doc_count++; + docCount++; } void close() { - if (out_stream.is_open()) { + if (outStream.is_open()) { // write actual doc count at the beginning of the data file which was 0 before - out_stream.seekp(0); - out_stream.write(reinterpret_cast(&doc_count), sizeof(doc_count)); - out_stream.close(); - offset_stream.close(); + outStream.seekp(0); + outStream.write(reinterpret_cast(&docCount), sizeof(docCount)); + outStream.close(); + offsetStream.close(); } } }; -// NOTE: Snowball stemmer instance is not thread-safe -struct SnowballStemmer { - struct sb_stemmer* stemmer; - SnowballStemmer() { - stemmer = sb_stemmer_new("english", nullptr); - } - ~SnowballStemmer() { - sb_stemmer_delete(stemmer); - } - std::string stem(const std::string& word) { - const sb_symbol* stemmed = sb_stemmer_stem(stemmer, - reinterpret_cast(word.c_str()), word.size()); - int out_len = sb_stemmer_length(stemmer); - if (stemmed == nullptr || out_len <= 0) return std::string(); - return std::string(reinterpret_cast(stemmed), static_cast(out_len)); - } +struct Posting { + int docId; + std::vector positions; +}; + +static const robin_hood::unordered_flat_set STOP_WORDS = { + "a", "an", "and", "are", "as", "at", "be", "but", "by", "for", "if", "in", + "into", "is", "it", "no", "not", "of", "on", "or", "such", "that", "the", + "their", "then", "there", "these", "they", "this", "to", "was", "will", "with" }; -SnowballStemmer stemmer; -struct PostingList { - std::vector postings; - std::vector term_frequencies; - std::vector> positions; - std::unordered_map skip_pointers; +class Tokenizer { +private: + struct sb_stemmer* stemmer; + std::string tokenBuffer; - size_t add_document_occurrences(uint32_t doc_id, const std::vector& new_positions) { - size_t memory_delta = 0; - - if (!postings.empty() && postings.back() == doc_id) { - // existing last doc: extend its positions vector - term_frequencies.back() += new_positions.size(); - auto& pos_vec = positions.back(); - - size_t old_cap = pos_vec.capacity(); - pos_vec.insert(pos_vec.end(), new_positions.begin(), new_positions.end()); - size_t new_cap = pos_vec.capacity(); - - memory_delta += (new_cap - old_cap) * sizeof(uint32_t); - } else { - // new doc entry - postings.push_back(doc_id); - term_frequencies.push_back(new_positions.size()); - - positions.emplace_back(); - auto& pos_vec = positions.back(); - pos_vec.reserve(new_positions.size()); - pos_vec.insert(pos_vec.end(), new_positions.begin(), new_positions.end()); - - memory_delta += sizeof(uint32_t) * 2; // doc_id + tf (stored elsewhere) - memory_delta += VECTOR_OVERHEAD; // vector structure overhead - memory_delta += pos_vec.capacity() * sizeof(uint32_t); +public: + Tokenizer() { + stemmer = sb_stemmer_new("english", "UTF_8"); + if (!stemmer) { + throw std::runtime_error("Failed to create English stemmer"); } - - return memory_delta; + tokenBuffer.reserve(64); } - - void build_skip_pointers() { - if (postings.empty()) return; - size_t skip_interval = static_cast(std::sqrt(postings.size())); - if (skip_interval < 2) return; - - skip_pointers.clear(); - for (size_t i = 0; i + skip_interval < postings.size(); i += skip_interval) { - skip_pointers[i] = i + skip_interval; - } + + ~Tokenizer() { + if (stemmer) sb_stemmer_delete(stemmer); } - static PostingList merge(const PostingList& a, const PostingList& b) { - PostingList result; - size_t i = 0, j = 0; + // non-copyable + Tokenizer(const Tokenizer&) = delete; + Tokenizer& operator=(const Tokenizer&) = delete; + + template + void tokenize(const char* text, size_t len, Callback&& callback) { + int position = 0; + size_t i = 0; - // reserve memory to avoid reallocations - result.postings.reserve(a.postings.size() + b.postings.size()); - result.term_frequencies.reserve(a.term_frequencies.size() + b.term_frequencies.size()); - result.positions.reserve(a.positions.size() + b.positions.size()); - - while (i < a.postings.size() || j < b.postings.size()) { - - uint32_t doc_id_a = (i < a.postings.size()) ? a.postings[i] : UINT32_MAX; - uint32_t doc_id_b = (j < b.postings.size()) ? b.postings[j] : UINT32_MAX; - - if (doc_id_a < doc_id_b) { - result.postings.push_back(doc_id_a); - result.term_frequencies.push_back(a.term_frequencies[i]); - result.positions.push_back(a.positions[i]); + while (i < len) { + while (i < len && !std::isalpha(static_cast(text[i]))) { i++; - } else if (doc_id_b < doc_id_a) { - result.postings.push_back(doc_id_b); - result.term_frequencies.push_back(b.term_frequencies[j]); - result.positions.push_back(b.positions[j]); - j++; - } else if (doc_id_a != UINT32_MAX) { - uint32_t doc_id = doc_id_a; - result.postings.push_back(doc_id); - - result.term_frequencies.push_back(a.term_frequencies[i] + b.term_frequencies[j]); - - std::vector pos_a = a.positions[i]; - const auto& pos_b = b.positions[j]; - pos_a.insert(pos_a.end(), pos_b.begin(), pos_b.end()); - result.positions.push_back(std::move(pos_a)); - + } + if (i >= len) break; + + tokenBuffer.clear(); + while (i < len && std::isalpha(static_cast(text[i]))) { + tokenBuffer.push_back(std::tolower(static_cast(text[i]))); i++; - j++; - } else { - // both UINT32_MAX -> end - break; } + + if (tokenBuffer.empty()) continue; + + if (STOP_WORDS.count(tokenBuffer)) { + position++; + continue; + } + + const sb_symbol* stemmed = sb_stemmer_stem( + stemmer, + reinterpret_cast(tokenBuffer.data()), + tokenBuffer.size() + ); + int stemLen = sb_stemmer_length(stemmer); + + std::string term(reinterpret_cast(stemmed), stemLen); + + callback(std::move(term), position); + position++; } - - return result; - } - - // estimate total memory size used for dumping logic - size_t memory_size() const { - size_t size = 0; - - size += postings.capacity() * sizeof(uint32_t); - size += term_frequencies.capacity() * sizeof(uint32_t); - - size += positions.capacity() * VECTOR_OVERHEAD; - - for (const auto& pos : positions) { - size += pos.capacity() * sizeof(uint32_t); - } - - size += skip_pointers.size() * (MAP_NODE_OVERHEAD + sizeof(uint32_t) * 2); - - return size; } }; -class MMapReader { -public: - const char* data; - size_t size; - int fd; - - MMapReader(const std::string& filename) { - fd = open(filename.c_str(), O_RDONLY); - if (fd == -1) throw std::runtime_error("Could not open file for mmap"); - - struct stat sb; - if (fstat(fd, &sb) == -1) throw std::runtime_error("Could not stat file"); - size = sb.st_size; - - if (size == 0) { - data = nullptr; - return; +void spillToDisk( + robin_hood::unordered_flat_map>& termPostings, + const robin_hood::unordered_flat_map& termDictionary, + const std::string& postingsFile, + const std::string& dictFile) +{ + std::vector> sortedTerms; + sortedTerms.reserve(termDictionary.size()); + for (const auto& kv : termDictionary) + sortedTerms.emplace_back(kv.first, kv.second); + + // sort for more efficient merging of the spilled files later + std::sort(sortedTerms.begin(), sortedTerms.end(), + [](const auto& a, const auto& b) { return a.first < b.first; }); + + std::ofstream postOut(postingsFile, std::ios::binary); + std::ofstream dictOut(dictFile, std::ios::binary); + if (!postOut || !dictOut) + throw std::runtime_error("Failed to open output files"); + + static char postBuffer[8 * 1024 * 1024]; // 8MB + static char dictBuffer[8 * 1024 * 1024]; + postOut.rdbuf()->pubsetbuf(postBuffer, sizeof(postBuffer)); + dictOut.rdbuf()->pubsetbuf(dictBuffer, sizeof(dictBuffer)); + + uint64_t offset = 0; + + for (const auto& [term, termId] : sortedTerms) { + auto it = termPostings.find(termId); + if (it == termPostings.end()) + continue; + + // sort for search and union of posting lists + std::vector& postings = it->second; + std::sort(postings.begin(), postings.end(), + [](const Posting& a, const Posting& b) { return a.docId < b.docId; }); + + uint64_t startOffset = offset; + + uint32_t docFreq = postings.size(); + + for (const auto& posting : postings) { + // write docId + postOut.write(reinterpret_cast(&posting.docId), sizeof(uint32_t)); + offset += sizeof(uint32_t); + + // write posCount + uint32_t posCount = posting.positions.size(); + postOut.write(reinterpret_cast(&posCount), sizeof(uint32_t)); + offset += sizeof(uint32_t); + + // write positions + for (uint32_t pos : posting.positions) { + postOut.write(reinterpret_cast(&pos), sizeof(uint32_t)); + offset += sizeof(uint32_t); + } } - void* mapped = mmap(nullptr, size, PROT_READ, MAP_PRIVATE, fd, 0); - if (mapped == MAP_FAILED) throw std::runtime_error("mmap failed"); + /* + Example: apple + [0-3]: 5 (termLen, 4 bytes) + [4-8]: a p p l e + [9-16]: startOffset (8 bytes) where posting list starts in postings file + [17-20]: docFreq (4 bytes) - data = static_cast(mapped); - madvise(mapped, size, MADV_SEQUENTIAL); + Term frequency will be in the merge step + */ + uint32_t termLen = term.size(); + dictOut.write(reinterpret_cast(&termLen), sizeof(termLen)); + dictOut.write(term.data(), termLen); + dictOut.write(reinterpret_cast(&startOffset), sizeof(startOffset)); + dictOut.write(reinterpret_cast(&docFreq), sizeof(docFreq)); } - ~MMapReader() { - if (data) munmap(const_cast(data), size); - if (fd != -1) close(fd); + postOut.close(); + dictOut.close(); +} + +int main(int argc, char* argv[]) { + if (argc < 2 || argc > 3) { + std::cerr << "Usage: " << argv[0] << " " << std::endl; + return 1; + } + std::cout << "Starting index building with memory limit: " << argv[1] << " MB" << std::endl; + int32_t maxDocs = -1; + if (argc == 3) { + maxDocs = static_cast(std::stoll(argv[2])); } - - MMapReader(const MMapReader&) = delete; - MMapReader& operator=(const MMapReader&) = delete; -}; -template -T read_val(const char*& ptr) { - T val; - std::memcpy(&val, ptr, sizeof(T)); - ptr += sizeof(T); - return val; -} + // high allocator and fragmentation overhead, often making real memory usage + // 2–4× larger than the raw data size + // --> use only 20% of the given limit for raw data + uint64_t MaxMb = std::stoull(argv[1]); + uint64_t MEMORYLIMIT = static_cast(MaxMb * 1024ull * 1024ull * 0.20); -uint64_t write_posting_list(std::ofstream& out, const PostingList& pl, bool with_skip_pointers = false) { - /* - with_skip_pointers: we don't need skip pointers for temp spilling to disk, only for final index - - Example: - pl: - postings = [3, 7] - term_frequencies = {3: 4, 7: 1} - positions = { - 3: [0, 5, 9, 20], - 7: [13] - } + using namespace std::chrono; + auto start = high_resolution_clock::now(); + std::filesystem::path exePath = std::filesystem::absolute(argv[0]).parent_path(); + std::filesystem::path projectRoot = exePath.parent_path().parent_path(); - In binary file: - count_docs = 2 + std::string dataDir = "/data"; + + const char* test_env = std::getenv("ENV"); // for integration tests, test with controlled and small dataset in test_data + if (test_env && std::string(test_env) == "TEST_ENV") { + std::cout << "TEST ENVIRONMENT, building index with test data." << std::endl; + dataDir = "/test_data"; + } - 3 (first doc_id) - 4 (tf for doc_id 3) - 4 (pos_count for doc_id 3) - [0, 5, 9, 20] (positions for doc_id 3) + std::string projectDir = projectRoot.string(); + std::string partialIndexPostingsDir = projectDir + dataDir + "/partial_indices/postings"; + std::string partialIndexDictDir = projectDir + dataDir + "/partial_indices/dictionaries"; + std::string metadataDir = projectDir + dataDir + "/index"; - 7 (second doc_id) - 1 (tf for doc_id 7) - 1 (pos_count for doc_id 7) - [13] (positions for doc_id 7) - */ + std::filesystem::create_directories(partialIndexPostingsDir); + std::filesystem::create_directories(partialIndexDictDir); + std::filesystem::create_directories(metadataDir); - uint64_t offset = out.tellp(); - uint32_t count_docs = pl.postings.size(); - out.write(reinterpret_cast(&count_docs), sizeof(count_docs)); - - for (size_t i = 0; i < count_docs; ++i) { - uint32_t doc_id = pl.postings[i]; - uint32_t tf = pl.term_frequencies[i]; - const auto& pos = pl.positions[i]; - - out.write(reinterpret_cast(&doc_id), sizeof(doc_id)); - out.write(reinterpret_cast(&tf), sizeof(tf)); - - uint32_t pos_count = pos.size(); - out.write(reinterpret_cast(&pos_count), sizeof(pos_count)); - out.write(reinterpret_cast(pos.data()), pos_count * sizeof(uint32_t)); - } + Tokenizer tokenizer; + std::ifstream infile(projectDir + dataDir + "/msmarco-docs.tsv"); - if (with_skip_pointers) { - uint32_t skip_count = pl.skip_pointers.size(); - out.write(reinterpret_cast(&skip_count), sizeof(skip_count)); - for (const auto& [from_idx, to_idx] : pl.skip_pointers) { - out.write(reinterpret_cast(&from_idx), sizeof(from_idx)); - out.write(reinterpret_cast(&to_idx), sizeof(to_idx)); - } + if (!infile.is_open()) { + std::cerr << "Failed to open input file\n"; + return 1; } - - return offset; -} -PostingList read_posting_list(const char*& ptr) { - PostingList pl; - - uint32_t count_docs = read_val(ptr); - pl.postings.reserve(count_docs); - pl.term_frequencies.reserve(count_docs); - pl.positions.reserve(count_docs); - - for (uint32_t i = 0; i < count_docs; i++) { - uint32_t doc_id = read_val(ptr); - uint32_t tf = read_val(ptr); - uint32_t pos_count = read_val(ptr); - - pl.postings.push_back(doc_id); - pl.term_frequencies.push_back(tf); - - std::vector positions(pos_count); - std::memcpy(positions.data(), ptr, pos_count * sizeof(uint32_t)); - ptr += pos_count * sizeof(uint32_t); - - pl.positions.push_back(std::move(positions)); - } - - return pl; -} + DocStoreWriter docStore; + std::string docstoreBase = projectDir + dataDir + "/docstore"; + std::filesystem::create_directories(docstoreBase); + docStore.init(docstoreBase); + + robin_hood::unordered_flat_map termDictionary; + robin_hood::unordered_flat_map> termPostings; + + // reserve space to avoid frequent rehashes, assuming 500k unique terms + termDictionary.reserve(500'000); + termPostings.reserve(500'000); + + std::vector docLengths; + docLengths.reserve(3'300'000); // ~3.2M documents in corpus + + std::string line; + line.reserve(16384); + size_t lineNumber = 0; + uint32_t partialIndexesCount = 0; + size_t memoryBytes = 0; + + while (std::getline(infile, line)) { + lineNumber++; + + if (memoryBytes > MEMORYLIMIT) { + std::string postingsFile = partialIndexPostingsDir + "/postings_" + + std::to_string(partialIndexesCount) + ".bin"; + std::string dictFile = partialIndexDictDir + "/dictionary_" + + std::to_string(partialIndexesCount) + ".bin"; + try { + spillToDisk(termPostings, termDictionary, postingsFile, dictFile); + } catch (const std::exception& e) { + std::cerr << "Error writing index: " << e.what() << std::endl; + return 1; + } + termPostings.clear(); + termDictionary.clear(); + partialIndexesCount++; + memoryBytes = 0; + auto elapsed = duration(high_resolution_clock::now() - start).count(); + std::cout << "[Partial Index #" << partialIndexesCount << "] " + << "Lines processed: " << lineNumber + << " Time: " << elapsed << "s\n"; + } -struct MergeState { - std::string term; - const char* current_ptr; - int file_index; + size_t pos1 = line.find('\t'); + size_t pos2 = line.find('\t', pos1 + 1); + size_t pos3 = line.find('\t', pos2 + 1); + if (pos3 == std::string::npos) + continue; + + // parse docId + int docId = -1; + if (pos1 >= 2 && line[0] == 'D') { + docId = 0; + for (size_t i = 1; i < pos1; i++) { + char c = line[i]; + if (c >= '0' && c <= '9') + docId = docId * 10 + (c - '0'); + else { + docId = -1; + break; + } + } + } + if (docId < 0) continue; - // std::priority_queue is max-heap by default, so we invert the comparison - bool operator>(const MergeState& other) const { - return term > other.term; - } -}; + std::string url = line.substr(pos1 + 1, pos2 - pos1 - 1); + std::string title = line.substr(pos2 + 1, pos3 - pos2 - 1); -class InvertedIndexBuilder { -private: - std::unordered_map term_to_index; - std::vector posting_lists; - size_t current_memory; - size_t memory_limit; - std::vector spilled_files; - int spill_counter; - std::string temp_dir; - DocStoreWriter doc_store; - std::string temp_docstore_base; + docStore.addDocument(docId, url, title); -public: - InvertedIndexBuilder(size_t mem_limit_mb = 1024) - : current_memory(0), - memory_limit(mem_limit_mb * 1024 * 1024), - spill_counter(0), - temp_dir("temp_index") { - fs::create_directories(temp_dir); - // writing docstore to temp location instead of final location in case of crashes - temp_docstore_base = temp_dir + "/temp_docstore"; - doc_store.init(temp_docstore_base); - } - - void add_document(uint32_t doc_id, const std::vector& tokens, - const std::string& url = "", const std::string& title = "") { - std::unordered_map> term_positions; - for (uint32_t pos = 0; pos < tokens.size(); pos++) { - term_positions[tokens[pos]].push_back(pos); + // ensure docLengths vector is large enough + if (static_cast(docId) >= docLengths.size()) { + docLengths.resize(docId + 1, 0); } + uint32_t docTermCount = 0; + + // tokenize title + content directly + // title is from pos2+1 to pos3, content is from pos3+1 to end + const char* titleStart = line.data() + pos2 + 1; + size_t titleLen = pos3 - pos2 - 1; + const char* contentStart = line.data() + pos3 + 1; + size_t contentLen = line.size() - pos3 - 1; - for (const auto& [term, positions] : term_positions) { - uint32_t pl_index; - bool is_new_term = false; + // process title + tokenizer.tokenize(titleStart, titleLen, [&](std::string&& term, int position) { + docTermCount++; - auto it = term_to_index.find(term); - if (it == term_to_index.end()) { - is_new_term = true; - pl_index = posting_lists.size(); - term_to_index[term] = pl_index; - posting_lists.emplace_back(); + uint32_t termId; + auto it = termDictionary.find(term); + if (it == termDictionary.end()) { + termId = termDictionary.size(); + memoryBytes += sizeof(uint32_t) + term.size(); + termDictionary.emplace(std::move(term), termId); } else { - pl_index = it->second; - } - - auto& pl = posting_lists[pl_index]; - size_t bytes_added = pl.add_document_occurrences(doc_id, positions); - - if (is_new_term) { - bytes_added += MAP_NODE_OVERHEAD + term.size() + sizeof(PostingList); + termId = it->second; } - - current_memory += bytes_added; - } - - if (current_memory > memory_limit) { - spill_to_disk(); - } - - doc_store.add_document(doc_id, url, title); - } - - void spill_to_disk() { - /* - * Example binary file layout for spill_to_disk() with 2 terms: "apple" and "banana" - * - * Example: - * - Term "apple" appears in doc_id 3 (tf=4, positions=[0,5]) and doc_id 7 (tf=2, positions=[10]) - * - Term "banana" appears in doc_id 2 (tf=1, positions=[1]) - * - * Byte-level layout (offsets in bytes): - * - * [0-3] : term_count = 2 (uint32_t) - * - * Term 1: "apple" - * [4-7] : term_len = 5 (uint32_t) - * [8-12] : 'a' 'p' 'p' 'l' 'e' (5 bytes) - * [13-16] : posting_count = 2 (uint32_t) - * Doc 1: - * [17-20] : doc_id = 3 - * [21-24] : tf = 4 - * [25-28] : pos_count = 2 - * [29-36] : positions = 0, 5 (2 * 4 bytes) - * Doc 2: - * [37-40] : doc_id = 7 - * [41-44] : tf = 2 - * [45-48] : pos_count = 1 - * [49-52] : positions = 10 - * [53-56] : skip_count = 0 (no skip pointers) - * - * Term 2: "banana" - * [57-60] : term_len = 6 (uint32_t) - * [61-66] : 'b' 'a' 'n' 'a' 'n' 'a' (6 bytes) - * [67-70] : posting_count = 1 (uint32_t) - * Doc 1: - * [71-74] : doc_id = 2 - * [75-78] : tf = 1 - * [79-82] : pos_count = 1 - * [83-86] : positions = 1 - * [87-90] : skip_count = 0 - */ - std::string filename = temp_dir + "/spill_" + std::to_string(spill_counter++) + ".bin"; - std::ofstream out(filename, std::ios::binary); - - std::vector> sorted_terms; - sorted_terms.reserve(term_to_index.size()); - - for (const auto& [term, idx] : term_to_index) { - sorted_terms.emplace_back(term, idx); - } - - std::sort(sorted_terms.begin(), sorted_terms.end(), - [](const auto& a, const auto& b) { return a.first < b.first; }); - - uint32_t term_count = sorted_terms.size(); - out.write(reinterpret_cast(&term_count), sizeof(term_count)); - - for (const auto& [term, idx] : sorted_terms) { - uint32_t term_len = term.size(); - out.write(reinterpret_cast(&term_len), sizeof(term_len)); - out.write(term.data(), term_len); - - const PostingList& pl = posting_lists[idx]; - write_posting_list(out, pl); - } - - out.close(); - spilled_files.push_back(filename); - - std::cout << "Spilled " << term_to_index.size() << " terms to " << filename - << " (" << current_memory / (1024*1024) << " MB)" << std::endl; - - term_to_index.clear(); - posting_lists.clear(); - current_memory = 0; - } - - void finalize(const std::string& output_file_base) { - if (!term_to_index.empty()) { - spill_to_disk(); - } - - doc_store.close(); - - // move docstore files to final location - if (fs::exists(output_file_base + ".docstore")) fs::remove(output_file_base + ".docstore"); - if (fs::exists(output_file_base + ".docstore_offsets")) fs::remove(output_file_base + ".docstore_offsets"); - fs::rename(temp_docstore_base + ".docstore", output_file_base + ".docstore"); - fs::rename(temp_docstore_base + ".docstore_offsets", output_file_base + ".docstore_offsets"); - - std::cout << "Merging " << spilled_files.size() << " spilled files using Priority Queue..." << std::endl; - auto merge_start = std::chrono::high_resolution_clock::now(); - - std::vector> readers; - std::vector file_ptrs; // track current position per file - - // queue is max-heap by default, implement min-heap by inverting comparison with MergeState::operator> - std::priority_queue, std::greater> merge_queue; - - // open all temp files and push the FIRST term of each file into the queue - // queue contains term, offset in file, file index - // result: smallest term per file, so if e.g. "and" is smallest for 8/10 files, "or" for 2 files, - // we have 8 entries with "and" and 2 with "or" in the queue - for (size_t i = 0; i < spilled_files.size(); ++i) { - readers.push_back(std::make_unique(spilled_files[i])); - - const char* ptr = readers.back()->data; // start of file - const char* end = ptr + readers.back()->size; // end of file - - // every spill file starts with uint32_t term_count - // --> check that not empty - if (ptr + sizeof(uint32_t) <= end) { - uint32_t term_count = read_val(ptr); - - if (term_count > 0 && ptr + sizeof(uint32_t) <= end) { - // read first term - uint32_t term_len = read_val(ptr); - if (ptr + term_len <= end) { - std::string term(ptr, term_len); - ptr += term_len; - - merge_queue.push({term, ptr, (int)i}); - file_ptrs.push_back(ptr); - continue; - } - } + auto& postings = termPostings[termId]; + if (postings.empty() || postings.back().docId != docId) { + postings.push_back({docId, {}}); + postings.back().positions.reserve(8); + postings.back().positions.push_back(position); + memoryBytes += sizeof(Posting) + sizeof(int); + } else { + postings.back().positions.push_back(position); + memoryBytes += sizeof(int); } - - // if file is empty or malformed, track it as exhausted - file_ptrs.push_back(nullptr); - } - - // merging process - std::ofstream postings_out(output_file_base + ".postinglists", std::ios::binary); - std::ofstream index_out(output_file_base + ".index", std::ios::binary); + }); - uint64_t term_counter = 0; - - while (!merge_queue.empty()) { - MergeState min_state = merge_queue.top(); - merge_queue.pop(); - std::string min_term = min_state.term; + // process content (positions continue from title) + tokenizer.tokenize(contentStart, contentLen, [&](std::string&& term, int position) { + docTermCount++; - term_counter++; - if (term_counter % 100000 == 0) { - auto now = std::chrono::high_resolution_clock::now(); - auto elapsed = std::chrono::duration_cast(now - merge_start).count(); - std::cout << "Merged " << term_counter << " terms..." - << " | elapsed time: " << elapsed << " s" << std::endl; + uint32_t termId; + auto it = termDictionary.find(term); + if (it == termDictionary.end()) { + termId = termDictionary.size(); + memoryBytes += sizeof(uint32_t) + term.size(); + termDictionary.emplace(std::move(term), termId); + } else { + termId = it->second; } - - PostingList merged_pl; - - do { - int file_idx = min_state.file_index; - - const char* ptr = min_state.current_ptr; - - PostingList pl = read_posting_list(ptr); - - if (merged_pl.postings.empty()) merged_pl = std::move(pl); - else merged_pl = PostingList::merge(merged_pl, pl); - - const char* end = readers[file_idx]->data + readers[file_idx]->size; - - // after every postinglist comes next term length (uint32_t) + term if there is a next term - if (ptr + sizeof(uint32_t) <= end) { - uint32_t term_len; - std::memcpy(&term_len, ptr, sizeof(uint32_t)); - - if (ptr + sizeof(uint32_t) + term_len <= end) { - - ptr += sizeof(uint32_t); - std::string next_term(ptr, term_len); - ptr += term_len; - - merge_queue.push({next_term, ptr, file_idx}); - } - } - - if (merge_queue.empty() || merge_queue.top().term != min_term) break; - - min_state = merge_queue.top(); - merge_queue.pop(); - } while (true); - - merged_pl.build_skip_pointers(); - uint64_t offset = write_posting_list(postings_out, merged_pl, true); - - uint32_t term_len = min_term.size(); - index_out.write(reinterpret_cast(&term_len), sizeof(term_len)); - index_out.write(min_term.data(), term_len); - index_out.write(reinterpret_cast(&offset), sizeof(offset)); - } - - postings_out.close(); - index_out.close(); - readers.clear(); - - auto merge_end = std::chrono::high_resolution_clock::now(); - auto total_elapsed = std::chrono::duration_cast(merge_end - merge_start).count(); - std::cout << "Merging finished. Total terms: " << term_counter - << ", elapsed time: " << total_elapsed << "s" << std::endl; + auto& postings = termPostings[termId]; + if (postings.empty() || postings.back().docId != docId) { + postings.push_back({docId, {}}); + postings.back().positions.reserve(8); + postings.back().positions.push_back(position); + memoryBytes += sizeof(Posting) + sizeof(int); + } else { + postings.back().positions.push_back(position); + memoryBytes += sizeof(int); + } + }); - std::cout << "Cleaning up temporary files..." << std::endl; - for (const auto& filename : spilled_files) fs::remove(filename); - fs::remove(temp_dir); - std::cout << "Index built successfully: " << output_file_base << std::endl; + docLengths[docId] = docTermCount; + + if (maxDocs != -1 && lineNumber >= maxDocs) break; } -}; -std::vector tokenize(const std::string& text) { - std::vector tokens; - std::string token; + // final flush if remaining data + std::string postingsFile = partialIndexPostingsDir + "/postings_final.bin"; + std::string dictFile = partialIndexDictDir + "/dictionary_final.bin"; + spillToDisk(termPostings, termDictionary, postingsFile, dictFile); + + docStore.close(); - for (char c : text) { - if (std::isalnum(c)) { - token += std::tolower(c); - } else if (!token.empty()) { - tokens.push_back(stemmer.stem(token)); - token.clear(); + // calculate statistics and write metadata file + uint64_t totalTerms = 0; + uint32_t numDocs = 0; + for (size_t i = 0; i < docLengths.size(); i++) { + if (docLengths[i] > 0) { + totalTerms += docLengths[i]; + numDocs++; } } - if (!token.empty()) { - tokens.push_back(stemmer.stem(token)); - } - return tokens; -} - -uint32_t next_doc_id = 0; - -ParsedDoc parse_line(const std::string& line) { - size_t first_tab = line.find('\t'); - size_t second_tab = line.find('\t', first_tab + 1); - size_t third_tab = line.find('\t', second_tab + 1); - - ParsedDoc doc; - std::string docid_str = line.substr(0, first_tab); - doc.doc_id = next_doc_id++; - doc.url = line.substr(first_tab + 1, second_tab - first_tab - 1); - doc.title = line.substr(second_tab + 1, third_tab - second_tab - 1); - doc.body = line.substr(third_tab + 1); - return doc; -} + double avgDocLength = numDocs > 0 ? static_cast(totalTerms) / numDocs : 0.0; -int main(int argc, char* argv[]) { - if (argc < 2) { - std::cerr << "Usage: " << argv[0] << " " << std::endl; + std::string metadataFile = metadataDir + "/metadata.bin"; + std::ofstream metaOut(metadataFile, std::ios::binary); + if (!metaOut) { + std::cerr << "Failed to open metadata file for writing\n"; return 1; } - - std::cout << "Starting index building with memory limit: " << argv[1] << " MB" << std::endl; - - auto start_total = std::chrono::high_resolution_clock::now(); - - std::filesystem::path exe_path = std::filesystem::absolute(argv[0]).parent_path(); - std::filesystem::path project_root = exe_path.parent_path().parent_path(); - std::string data_dir_str = "data"; + // write header: numDocs, avgDocLength + metaOut.write(reinterpret_cast(&numDocs), sizeof(numDocs)); + metaOut.write(reinterpret_cast(&avgDocLength), sizeof(avgDocLength)); - const char* test_env = std::getenv("ENV"); // for integration tests, test with controlled and small dataset in test_data - if (test_env && std::string(test_env) == "TEST_ENV") { - std::cout << "TEST ENVIRONMENT, building index with test data." << std::endl; - data_dir_str = "test_data"; - } - - std::filesystem::path input_file = project_root / data_dir_str / "msmarco.tsv.gz"; - std::filesystem::path output_dir = project_root.parent_path() / "index" / "bin"; // put in parallel directory index/ where python code expects it - - InvertedIndexBuilder builder(std::stoull(argv[1])); - - gzFile in = gzopen(input_file.string().c_str(), "rb"); - if (!in) { - std::cerr << "Failed to open msmarco.tsv.gz" << std::endl; - return 1; - } - - const size_t CHUNK_SIZE = 4 * 1024 * 1024; - std::vector decompressed_data(CHUNK_SIZE + 1, 0); - std::string line_buffer; - - uint32_t doc_count = 0; - auto start_read = std::chrono::high_resolution_clock::now(); - auto start_chunk = start_read; - - uint32_t max_docs = -1; // set to a positive number for testing with limited docs - while (true) { - int bytes_read = gzread(in, decompressed_data.data(), CHUNK_SIZE); - if (bytes_read <= 0) break; - - decompressed_data[bytes_read] = 0; // null terminator - line_buffer.append(decompressed_data.data(), bytes_read); - - size_t start = 0; - for (size_t pos; (pos = line_buffer.find('\n', start)) != std::string::npos; start = pos + 1) { - std::string_view line(&line_buffer[start], pos - start); - while (!line.empty() && line.back() == '\r') line.remove_suffix(1); - - ParsedDoc doc = parse_line(std::string(line)); - std::vector tokens = tokenize(doc.body); - builder.add_document(doc.doc_id, tokens, doc.url, doc.title); - - doc_count++; - if (doc_count % 10000 == 0) { - auto now = std::chrono::high_resolution_clock::now(); - double chunk_sec = std::chrono::duration(now - start_chunk).count(); - double total_sec = std::chrono::duration(now - start_read).count(); - double docs_per_sec = doc_count / total_sec; - std::cout << "Processed " << doc_count << " docs | " - << "Chunk: " << chunk_sec << "s, Total: " << total_sec << "s | " - << "Speed: " << (int)docs_per_sec << " docs/s" << std::endl; - start_chunk = now; - } - if (max_docs != -1 && doc_count >= max_docs) break; + // write document lengths array (only non-zero entries with their docIds) + for (size_t docId = 0; docId < docLengths.size(); docId++) { + if (docLengths[docId] > 0) { + uint32_t id = static_cast(docId); + uint32_t len = docLengths[docId]; + metaOut.write(reinterpret_cast(&id), sizeof(id)); + metaOut.write(reinterpret_cast(&len), sizeof(len)); } - - line_buffer = line_buffer.substr(start); - if (max_docs != -1 && doc_count >= max_docs) break; } + metaOut.close(); - // process last line if no \n at end - if (doc_count < max_docs && !line_buffer.empty()) { - ParsedDoc doc = parse_line(line_buffer); - std::vector tokens = tokenize(doc.body); - builder.add_document(doc.doc_id, tokens, doc.url, doc.title); - } - - gzclose(in); - - auto end_read = std::chrono::high_resolution_clock::now(); - std::cout << "Finished reading. Time: " - << std::chrono::duration_cast(end_read - start_read).count() - << "s" << std::endl; - - std::filesystem::create_directories(output_dir); - - builder.finalize((output_dir / "inverted_index").string()); - - auto end_total = std::chrono::high_resolution_clock::now(); - std::cout << "Total time: " - << std::chrono::duration_cast(end_total - start_total).count() - << "s" << std::endl; + std::cout << "Metadata written: " << numDocs << " documents, avg length: " << avgDocLength << std::endl; + double totalTime = duration(high_resolution_clock::now() - start).count(); + std::cout << "Indexing completed in " << totalTime << " seconds.\n"; + std::cout << "Total lines processed: " << lineNumber << std::endl; return 0; } \ No newline at end of file diff --git a/src/backend/search_engine/index_builder/merge_partial_indices.cpp b/src/backend/search_engine/index_builder/merge_partial_indices.cpp new file mode 100644 index 0000000..919f993 --- /dev/null +++ b/src/backend/search_engine/index_builder/merge_partial_indices.cpp @@ -0,0 +1,341 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +struct DictEntry { + std::string term; + uint64_t offset; + uint32_t docFreq; + uint64_t postingSize; // size of the posting in bytes +}; + +struct PostingEntry { + uint32_t docId; + std::vector positions; +}; + +std::vector readDictionary(const std::string& dictFile, const std::string& postingsFile) { + std::vector entries; + + std::ifstream dictIn(dictFile, std::ios::binary); + if (!dictIn) { + std::cerr << "Failed to open dictionary file: " << dictFile << std::endl; + return entries; + } + + // get the size of the postings file to determine the last posting's size + std::ifstream postIn(postingsFile, std::ios::binary | std::ios::ate); + uint64_t postingsFileSize = postIn.tellg(); + postIn.close(); + + // read all dictionary entries + while (dictIn.peek() != EOF) { + DictEntry entry; + + // read term length + uint32_t termLen; + dictIn.read(reinterpret_cast(&termLen), sizeof(termLen)); + if (!dictIn) break; + + // read term + entry.term.resize(termLen); + dictIn.read(&entry.term[0], termLen); + if (!dictIn) break; + + // read offset + dictIn.read(reinterpret_cast(&entry.offset), sizeof(entry.offset)); + if (!dictIn) break; + + // read docFreq + dictIn.read(reinterpret_cast(&entry.docFreq), sizeof(entry.docFreq)); + if (!dictIn) break; + + entries.push_back(entry); + } + + // calculate posting sizes using next entry's offset + for (size_t i = 0; i < entries.size(); i++) { + if (i + 1 < entries.size()) { + entries[i].postingSize = entries[i + 1].offset - entries[i].offset; + } else { + entries[i].postingSize = postingsFileSize - entries[i].offset; + } + } + + return entries; +} + +std::vector readAndParsePosting(const std::string& postingsFile, uint64_t offset, uint64_t size) { + std::vector entries; + + std::ifstream postIn(postingsFile, std::ios::binary); + if (!postIn) { + std::cerr << "Failed to open postings file: " << postingsFile << std::endl; + return entries; + } + + postIn.seekg(offset); + + uint64_t bytesRead = 0; + while (bytesRead < size) { + PostingEntry entry; + + // read docId + postIn.read(reinterpret_cast(&entry.docId), sizeof(entry.docId)); + bytesRead += sizeof(entry.docId); + + // read posCount + uint32_t posCount; + postIn.read(reinterpret_cast(&posCount), sizeof(posCount)); + bytesRead += sizeof(posCount); + + // read positions + entry.positions.resize(posCount); + postIn.read(reinterpret_cast(entry.positions.data()), posCount * sizeof(uint32_t)); + bytesRead += posCount * sizeof(uint32_t); + + entries.push_back(entry); + } + + return entries; +} + +void writePosting(std::ofstream& out, const std::vector& entries) { + for (const auto& entry : entries) { + // write docId + out.write(reinterpret_cast(&entry.docId), sizeof(entry.docId)); + + // write posCount + uint32_t posCount = entry.positions.size(); + out.write(reinterpret_cast(&posCount), sizeof(posCount)); + + // write positions + out.write(reinterpret_cast(entry.positions.data()), + posCount * sizeof(uint32_t)); + } +} + +std::vector mergePostings(const std::vector>& allPostings) { + // use a map to merge postings by docId (automatically sorted) + std::map> mergedMap; + + for (const auto& postings : allPostings) { + for (const auto& entry : postings) { + auto& positions = mergedMap[entry.docId]; + positions.insert(positions.end(), entry.positions.begin(), entry.positions.end()); + } + } + + // convert map to vector and sort positions within each document + std::vector result; + result.reserve(mergedMap.size()); + + for (auto& [docId, positions] : mergedMap) { + // sort positions for this document + std::sort(positions.begin(), positions.end()); + + PostingEntry entry; + entry.docId = docId; + entry.positions = std::move(positions); + result.push_back(std::move(entry)); + } + + return result; +} + +/* +Final Format: + +offset file: + term_len (4 bytes) + term (term_len bytes) + offset (8 bytes) + docFreq (4 bytes) + +posting file: + docId (4 bytes) + posCount (4 bytes) + positions (posCount * 4 bytes) +*/ +int main(int argc, char* argv[]) { + using namespace std::chrono; + auto start = high_resolution_clock::now(); + size_t termsProcessed = 0; + + std::filesystem::path exePath = std::filesystem::absolute(argv[0]).parent_path(); + std::filesystem::path projectRoot = exePath.parent_path().parent_path(); + std::string projectDir = projectRoot.string(); + + std::string dataDir = "/data"; + + const char* test_env = std::getenv("ENV"); // for integration tests, test with controlled and small dataset in test_data + if (test_env && std::string(test_env) == "TEST_ENV") { + std::cout << "TEST ENVIRONMENT, merging index with test data." << std::endl; + dataDir = "/test_data"; + } + std::string partialIndexPostingsDir = projectDir + dataDir + "/partial_indices/postings"; + std::string partialIndexDictDir = projectDir + dataDir + "/partial_indices/dictionaries"; + std::string metadataDir = projectDir + dataDir + "/index"; + std::string outputDir = (projectRoot.parent_path() / "index" / "bin").string(); // put in parallel directory index/ where python code expects it + std::filesystem::create_directories(outputDir); + + // copy from building dir to output dir + if (std::filesystem::exists(metadataDir + "/metadata.bin")) std::filesystem::remove(outputDir + "/metadata.bin"); + std::filesystem::copy(metadataDir + "/metadata.bin", outputDir + "/metadata.bin"); + + if (std::filesystem::exists(outputDir + "/docstore.bin")) std::filesystem::remove(outputDir + "/docstore.bin"); + if (std::filesystem::exists(outputDir + "/docstore_offsets.bin")) std::filesystem::remove(outputDir + "/docstore_offsets.bin"); + std::filesystem::copy(projectDir + dataDir + "/docstore/docstore.bin", outputDir + "/docstore.bin"); + std::filesystem::copy(projectDir + dataDir + "/docstore/docstore_offsets.bin", outputDir + "/docstore_offsets.bin"); + + // get number of partial indices + size_t partialIndexCount = 0; + while (true) { + std::string dictFile = partialIndexDictDir + "/dictionary_" + std::to_string(partialIndexCount) + ".bin"; + std::ifstream testIn(dictFile); + if (!testIn.is_open()) break; + testIn.close(); + partialIndexCount++; + } + partialIndexCount++; + std::cout << "Found " << partialIndexCount << " partial indices to merge.\n"; + + auto allDicts = std::vector>(partialIndexCount); + for (size_t i = 0; i < partialIndexCount - 1; i++) { + std::string dictFile = partialIndexDictDir + "/dictionary_" + std::to_string(i) + ".bin"; + std::string postingsFile = partialIndexPostingsDir + "/postings_" + std::to_string(i) + ".bin"; + allDicts[i] = readDictionary(dictFile, postingsFile); + } + allDicts[partialIndexCount - 1] = readDictionary( + partialIndexDictDir + "/dictionary_final.bin", + partialIndexPostingsDir + "/postings_final.bin" + ); + + struct HeapEntry { + std::string term; + size_t dictIndex; + size_t entryIndex; + bool operator>(const HeapEntry& other) const { + return term > other.term; + } + }; + std::priority_queue, std::greater> minHeap; + // initialize heap with first entry from each dictionary + for (size_t i = 0; i < allDicts.size(); i++) { + if (!allDicts[i].empty()) { + minHeap.push({allDicts[i][0].term, i, 0}); + } + } + // open final output files + std::string finalPostingsFile = outputDir + "/postinglists.bin"; + std::string finalDictFile = outputDir + "/index.bin"; + std::cout << "Writing merged postings to " << finalPostingsFile << std::endl; + std::cout << "Writing merged dictionary to " << finalDictFile << std::endl; + std::ofstream finalPostOut(finalPostingsFile, std::ios::binary); + std::ofstream finalDictOut(finalDictFile, std::ios::binary); + uint64_t finalOffset = 0; + + while (!minHeap.empty()) { + auto current = minHeap.top(); + minHeap.pop(); + + const std::string& term = current.term; + + // collect all postings for this term from all dictionaries + std::vector> postingsToMerge; + std::vector docFreqs; + + // add the current entry's posting + { + size_t dictIndex = current.dictIndex; + size_t entryIndex = current.entryIndex; + const DictEntry& entry = allDicts[dictIndex][entryIndex]; + docFreqs.push_back(entry.docFreq); + + std::string postingsFile = (dictIndex == partialIndexCount - 1) + ? partialIndexPostingsDir + "/postings_final.bin" + : partialIndexPostingsDir + "/postings_" + std::to_string(dictIndex) + ".bin"; + + postingsToMerge.push_back(readAndParsePosting(postingsFile, entry.offset, entry.postingSize)); + + if (entryIndex + 1 < allDicts[dictIndex].size()) { + const DictEntry& nextEntry = allDicts[dictIndex][entryIndex + 1]; + minHeap.push({nextEntry.term, dictIndex, entryIndex + 1}); + } + } + + // check if the next entries in the heap have the same term + while (!minHeap.empty() && minHeap.top().term == term) { + auto same = minHeap.top(); + minHeap.pop(); + + size_t dictIndex = same.dictIndex; + size_t entryIndex = same.entryIndex; + const DictEntry& entry = allDicts[dictIndex][entryIndex]; + docFreqs.push_back(entry.docFreq); + + std::string postingsFile = (dictIndex == partialIndexCount - 1) + ? partialIndexPostingsDir + "/postings_final.bin" + : partialIndexPostingsDir + "/postings_" + std::to_string(dictIndex) + ".bin"; + + postingsToMerge.push_back(readAndParsePosting(postingsFile, entry.offset, entry.postingSize)); + + // push next entry from this dictionary into the heap + if (entryIndex + 1 < allDicts[dictIndex].size()) { + const DictEntry& nextEntry = allDicts[dictIndex][entryIndex + 1]; + minHeap.push({nextEntry.term, dictIndex, entryIndex + 1}); + } + } + + // merge postings properly (sorted by docId, with positions merged and sorted) + std::vector mergedPostings = mergePostings(postingsToMerge); + + // calculate actual docFreq (number of unique documents) + uint32_t totalDocFreq = mergedPostings.size(); + + // write posting data to final postings file + writePosting(finalPostOut, mergedPostings); + + // calculate size of merged posting + uint64_t postingSize = 0; + for (const auto& entry : mergedPostings) { + postingSize += sizeof(uint32_t) + sizeof(uint32_t) + entry.positions.size() * sizeof(uint32_t); + } + + // write dictionary entry to final dictionary file + uint32_t termLen = term.size(); + finalDictOut.write(reinterpret_cast(&termLen), sizeof(termLen)); + finalDictOut.write(term.data(), termLen); + finalDictOut.write(reinterpret_cast(&finalOffset), sizeof(finalOffset)); + finalDictOut.write(reinterpret_cast(&totalDocFreq), sizeof(totalDocFreq)); + + finalOffset += postingSize; + + termsProcessed++; + + if (termsProcessed % 100000 == 0) { + auto now = high_resolution_clock::now(); + auto elapsed = std::chrono::duration_cast(now - start).count(); + std::cout << "Processed " << termsProcessed << " terms, elapsed time: " + << elapsed << "s" << std::endl; + } + } + + std::cout << "Merging completed successfully.\n"; + std::cout << "Time taken: " + << duration_cast(high_resolution_clock::now() - start).count() + << " seconds.\n"; + finalPostOut.close(); + finalDictOut.close(); + return 0; +} \ No newline at end of file diff --git a/src/backend/search_engine/index_builder/test_data/msmarco.tsv b/src/backend/search_engine/index_builder/test_data/msmarco-docs.tsv similarity index 100% rename from src/backend/search_engine/index_builder/test_data/msmarco.tsv rename to src/backend/search_engine/index_builder/test_data/msmarco-docs.tsv diff --git a/src/backend/search_engine/index_builder/test_data/msmarco.tsv.gz b/src/backend/search_engine/index_builder/test_data/msmarco.tsv.gz deleted file mode 100644 index ac7713a9ff349a2284042e9a61f7ec1f734a4b39..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 219 zcmV<103`n(iwFqP+$L!P|7~+^VRB<{E_8Et0Bq4q4uUWc2HV!6Ffsl$EUs9*BB(s@x@qg*xX&WwAU|{<;6$L5;XW0vbyZe=Ltfp9v>?#Lbo!eJK#T&Qkt(5n% VUccZA009600{}{;90d{r001ZOWQhO( diff --git a/src/backend/search_engine/scripts/build-index.sh b/src/backend/search_engine/scripts/build-index.sh index c120a26..08c9bbe 100755 --- a/src/backend/search_engine/scripts/build-index.sh +++ b/src/backend/search_engine/scripts/build-index.sh @@ -9,4 +9,6 @@ cd build cmake .. cmake --build . -./index_builder "$@" \ No newline at end of file +./index_builder "$@" + +./merge_partial_indices \ No newline at end of file diff --git a/src/backend/uv.lock b/src/backend/uv.lock index 1167892..8d79a12 100644 --- a/src/backend/uv.lock +++ b/src/backend/uv.lock @@ -40,10 +40,8 @@ source = { editable = "." } dependencies = [ { name = "fastapi" }, { name = "pydantic" }, - { name = "pytest-cov" }, { name = "requests" }, { name = "tqdm" }, - { name = "typer" }, { name = "uvicorn" }, ] @@ -61,10 +59,8 @@ dev = [ requires-dist = [ { name = "fastapi", specifier = ">=0.119.1" }, { name = "pydantic", specifier = ">=2.12.3" }, - { name = "pytest-cov", specifier = ">=7.0.0" }, { name = "requests", specifier = ">=2.32.5" }, { name = "tqdm", specifier = ">=4.67.1" }, - { name = "typer", specifier = ">=0.20.0" }, { name = "uvicorn", specifier = ">=0.38.0" }, ] @@ -149,67 +145,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, ] -[[package]] -name = "coverage" -version = "7.11.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d2/59/9698d57a3b11704c7b89b21d69e9d23ecf80d538cabb536c8b63f4a12322/coverage-7.11.3.tar.gz", hash = "sha256:0f59387f5e6edbbffec2281affb71cdc85e0776c1745150a3ab9b6c1d016106b", size = 815210, upload-time = "2025-11-10T00:13:17.18Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/6d/f6/d8572c058211c7d976f24dab71999a565501fb5b3cdcb59cf782f19c4acb/coverage-7.11.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:84b892e968164b7a0498ddc5746cdf4e985700b902128421bb5cec1080a6ee36", size = 216694, upload-time = "2025-11-10T00:11:34.296Z" }, - { url = "https://files.pythonhosted.org/packages/4a/f6/b6f9764d90c0ce1bce8d995649fa307fff21f4727b8d950fa2843b7b0de5/coverage-7.11.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f761dbcf45e9416ec4698e1a7649248005f0064ce3523a47402d1bff4af2779e", size = 217065, upload-time = "2025-11-10T00:11:36.281Z" }, - { url = "https://files.pythonhosted.org/packages/a5/8d/a12cb424063019fd077b5be474258a0ed8369b92b6d0058e673f0a945982/coverage-7.11.3-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:1410bac9e98afd9623f53876fae7d8a5db9f5a0ac1c9e7c5188463cb4b3212e2", size = 248062, upload-time = "2025-11-10T00:11:37.903Z" }, - { url = "https://files.pythonhosted.org/packages/7f/9c/dab1a4e8e75ce053d14259d3d7485d68528a662e286e184685ea49e71156/coverage-7.11.3-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:004cdcea3457c0ea3233622cd3464c1e32ebba9b41578421097402bee6461b63", size = 250657, upload-time = "2025-11-10T00:11:39.509Z" }, - { url = "https://files.pythonhosted.org/packages/3f/89/a14f256438324f33bae36f9a1a7137729bf26b0a43f5eda60b147ec7c8c7/coverage-7.11.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8f067ada2c333609b52835ca4d4868645d3b63ac04fb2b9a658c55bba7f667d3", size = 251900, upload-time = "2025-11-10T00:11:41.372Z" }, - { url = "https://files.pythonhosted.org/packages/04/07/75b0d476eb349f1296486b1418b44f2d8780cc8db47493de3755e5340076/coverage-7.11.3-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:07bc7745c945a6d95676953e86ba7cebb9f11de7773951c387f4c07dc76d03f5", size = 248254, upload-time = "2025-11-10T00:11:43.27Z" }, - { url = "https://files.pythonhosted.org/packages/5a/4b/0c486581fa72873489ca092c52792d008a17954aa352809a7cbe6cf0bf07/coverage-7.11.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:8bba7e4743e37484ae17d5c3b8eb1ce78b564cb91b7ace2e2182b25f0f764cb5", size = 250041, upload-time = "2025-11-10T00:11:45.274Z" }, - { url = "https://files.pythonhosted.org/packages/af/a3/0059dafb240ae3e3291f81b8de00e9c511d3dd41d687a227dd4b529be591/coverage-7.11.3-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:fbffc22d80d86fbe456af9abb17f7a7766e7b2101f7edaacc3535501691563f7", size = 248004, upload-time = "2025-11-10T00:11:46.93Z" }, - { url = "https://files.pythonhosted.org/packages/83/93/967d9662b1eb8c7c46917dcc7e4c1875724ac3e73c3cb78e86d7a0ac719d/coverage-7.11.3-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:0dba4da36730e384669e05b765a2c49f39514dd3012fcc0398dd66fba8d746d5", size = 247828, upload-time = "2025-11-10T00:11:48.563Z" }, - { url = "https://files.pythonhosted.org/packages/4c/1c/5077493c03215701e212767e470b794548d817dfc6247a4718832cc71fac/coverage-7.11.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ae12fe90b00b71a71b69f513773310782ce01d5f58d2ceb2b7c595ab9d222094", size = 249588, upload-time = "2025-11-10T00:11:50.581Z" }, - { url = "https://files.pythonhosted.org/packages/7f/a5/77f64de461016e7da3e05d7d07975c89756fe672753e4cf74417fc9b9052/coverage-7.11.3-cp313-cp313-win32.whl", hash = "sha256:12d821de7408292530b0d241468b698bce18dd12ecaf45316149f53877885f8c", size = 219223, upload-time = "2025-11-10T00:11:52.184Z" }, - { url = "https://files.pythonhosted.org/packages/ed/1c/ec51a3c1a59d225b44bdd3a4d463135b3159a535c2686fac965b698524f4/coverage-7.11.3-cp313-cp313-win_amd64.whl", hash = "sha256:6bb599052a974bb6cedfa114f9778fedfad66854107cf81397ec87cb9b8fbcf2", size = 220033, upload-time = "2025-11-10T00:11:53.871Z" }, - { url = "https://files.pythonhosted.org/packages/01/ec/e0ce39746ed558564c16f2cc25fa95ce6fc9fa8bfb3b9e62855d4386b886/coverage-7.11.3-cp313-cp313-win_arm64.whl", hash = "sha256:bb9d7efdb063903b3fdf77caec7b77c3066885068bdc0d44bc1b0c171033f944", size = 218661, upload-time = "2025-11-10T00:11:55.597Z" }, - { url = "https://files.pythonhosted.org/packages/46/cb/483f130bc56cbbad2638248915d97b185374d58b19e3cc3107359715949f/coverage-7.11.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:fb58da65e3339b3dbe266b607bb936efb983d86b00b03eb04c4ad5b442c58428", size = 217389, upload-time = "2025-11-10T00:11:57.59Z" }, - { url = "https://files.pythonhosted.org/packages/cb/ae/81f89bae3afef75553cf10e62feb57551535d16fd5859b9ee5a2a97ddd27/coverage-7.11.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:8d16bbe566e16a71d123cd66382c1315fcd520c7573652a8074a8fe281b38c6a", size = 217742, upload-time = "2025-11-10T00:11:59.519Z" }, - { url = "https://files.pythonhosted.org/packages/db/6e/a0fb897041949888191a49c36afd5c6f5d9f5fd757e0b0cd99ec198a324b/coverage-7.11.3-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:a8258f10059b5ac837232c589a350a2df4a96406d6d5f2a09ec587cbdd539655", size = 259049, upload-time = "2025-11-10T00:12:01.592Z" }, - { url = "https://files.pythonhosted.org/packages/d9/b6/d13acc67eb402d91eb94b9bd60593411799aed09ce176ee8d8c0e39c94ca/coverage-7.11.3-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:4c5627429f7fbff4f4131cfdd6abd530734ef7761116811a707b88b7e205afd7", size = 261113, upload-time = "2025-11-10T00:12:03.639Z" }, - { url = "https://files.pythonhosted.org/packages/ea/07/a6868893c48191d60406df4356aa7f0f74e6de34ef1f03af0d49183e0fa1/coverage-7.11.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:465695268414e149bab754c54b0c45c8ceda73dd4a5c3ba255500da13984b16d", size = 263546, upload-time = "2025-11-10T00:12:05.485Z" }, - { url = "https://files.pythonhosted.org/packages/24/e5/28598f70b2c1098332bac47925806353b3313511d984841111e6e760c016/coverage-7.11.3-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:4ebcddfcdfb4c614233cff6e9a3967a09484114a8b2e4f2c7a62dc83676ba13f", size = 258260, upload-time = "2025-11-10T00:12:07.137Z" }, - { url = "https://files.pythonhosted.org/packages/0e/58/58e2d9e6455a4ed746a480c4b9cf96dc3cb2a6b8f3efbee5efd33ae24b06/coverage-7.11.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:13b2066303a1c1833c654d2af0455bb009b6e1727b3883c9964bc5c2f643c1d0", size = 261121, upload-time = "2025-11-10T00:12:09.138Z" }, - { url = "https://files.pythonhosted.org/packages/17/57/38803eefb9b0409934cbc5a14e3978f0c85cb251d2b6f6a369067a7105a0/coverage-7.11.3-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:d8750dd20362a1b80e3cf84f58013d4672f89663aee457ea59336df50fab6739", size = 258736, upload-time = "2025-11-10T00:12:11.195Z" }, - { url = "https://files.pythonhosted.org/packages/a8/f3/f94683167156e93677b3442be1d4ca70cb33718df32a2eea44a5898f04f6/coverage-7.11.3-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:ab6212e62ea0e1006531a2234e209607f360d98d18d532c2fa8e403c1afbdd71", size = 257625, upload-time = "2025-11-10T00:12:12.843Z" }, - { url = "https://files.pythonhosted.org/packages/87/ed/42d0bf1bc6bfa7d65f52299a31daaa866b4c11000855d753857fe78260ac/coverage-7.11.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:a6b17c2b5e0b9bb7702449200f93e2d04cb04b1414c41424c08aa1e5d352da76", size = 259827, upload-time = "2025-11-10T00:12:15.128Z" }, - { url = "https://files.pythonhosted.org/packages/d3/76/5682719f5d5fbedb0c624c9851ef847407cae23362deb941f185f489c54e/coverage-7.11.3-cp313-cp313t-win32.whl", hash = "sha256:426559f105f644b69290ea414e154a0d320c3ad8a2bb75e62884731f69cf8e2c", size = 219897, upload-time = "2025-11-10T00:12:17.274Z" }, - { url = "https://files.pythonhosted.org/packages/10/e0/1da511d0ac3d39e6676fa6cc5ec35320bbf1cebb9b24e9ee7548ee4e931a/coverage-7.11.3-cp313-cp313t-win_amd64.whl", hash = "sha256:90a96fcd824564eae6137ec2563bd061d49a32944858d4bdbae5c00fb10e76ac", size = 220959, upload-time = "2025-11-10T00:12:19.292Z" }, - { url = "https://files.pythonhosted.org/packages/e5/9d/e255da6a04e9ec5f7b633c54c0fdfa221a9e03550b67a9c83217de12e96c/coverage-7.11.3-cp313-cp313t-win_arm64.whl", hash = "sha256:1e33d0bebf895c7a0905fcfaff2b07ab900885fc78bba2a12291a2cfbab014cc", size = 219234, upload-time = "2025-11-10T00:12:21.251Z" }, - { url = "https://files.pythonhosted.org/packages/84/d6/634ec396e45aded1772dccf6c236e3e7c9604bc47b816e928f32ce7987d1/coverage-7.11.3-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:fdc5255eb4815babcdf236fa1a806ccb546724c8a9b129fd1ea4a5448a0bf07c", size = 216746, upload-time = "2025-11-10T00:12:23.089Z" }, - { url = "https://files.pythonhosted.org/packages/28/76/1079547f9d46f9c7c7d0dad35b6873c98bc5aa721eeabceafabd722cd5e7/coverage-7.11.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:fe3425dc6021f906c6325d3c415e048e7cdb955505a94f1eb774dafc779ba203", size = 217077, upload-time = "2025-11-10T00:12:24.863Z" }, - { url = "https://files.pythonhosted.org/packages/2d/71/6ad80d6ae0d7cb743b9a98df8bb88b1ff3dc54491508a4a97549c2b83400/coverage-7.11.3-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:4ca5f876bf41b24378ee67c41d688155f0e54cdc720de8ef9ad6544005899240", size = 248122, upload-time = "2025-11-10T00:12:26.553Z" }, - { url = "https://files.pythonhosted.org/packages/20/1d/784b87270784b0b88e4beec9d028e8d58f73ae248032579c63ad2ac6f69a/coverage-7.11.3-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:9061a3e3c92b27fd8036dafa26f25d95695b6aa2e4514ab16a254f297e664f83", size = 250638, upload-time = "2025-11-10T00:12:28.555Z" }, - { url = "https://files.pythonhosted.org/packages/f5/26/b6dd31e23e004e9de84d1a8672cd3d73e50f5dae65dbd0f03fa2cdde6100/coverage-7.11.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:abcea3b5f0dc44e1d01c27090bc32ce6ffb7aa665f884f1890710454113ea902", size = 251972, upload-time = "2025-11-10T00:12:30.246Z" }, - { url = "https://files.pythonhosted.org/packages/c9/ef/f9c64d76faac56b82daa036b34d4fe9ab55eb37f22062e68e9470583e688/coverage-7.11.3-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:68c4eb92997dbaaf839ea13527be463178ac0ddd37a7ac636b8bc11a51af2428", size = 248147, upload-time = "2025-11-10T00:12:32.195Z" }, - { url = "https://files.pythonhosted.org/packages/b6/eb/5b666f90a8f8053bd264a1ce693d2edef2368e518afe70680070fca13ecd/coverage-7.11.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:149eccc85d48c8f06547534068c41d69a1a35322deaa4d69ba1561e2e9127e75", size = 249995, upload-time = "2025-11-10T00:12:33.969Z" }, - { url = "https://files.pythonhosted.org/packages/eb/7b/871e991ffb5d067f8e67ffb635dabba65b231d6e0eb724a4a558f4a702a5/coverage-7.11.3-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:08c0bcf932e47795c49f0406054824b9d45671362dfc4269e0bc6e4bff010704", size = 247948, upload-time = "2025-11-10T00:12:36.341Z" }, - { url = "https://files.pythonhosted.org/packages/0a/8b/ce454f0af9609431b06dbe5485fc9d1c35ddc387e32ae8e374f49005748b/coverage-7.11.3-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:39764c6167c82d68a2d8c97c33dba45ec0ad9172570860e12191416f4f8e6e1b", size = 247770, upload-time = "2025-11-10T00:12:38.167Z" }, - { url = "https://files.pythonhosted.org/packages/61/8f/79002cb58a61dfbd2085de7d0a46311ef2476823e7938db80284cedd2428/coverage-7.11.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:3224c7baf34e923ffc78cb45e793925539d640d42c96646db62dbd61bbcfa131", size = 249431, upload-time = "2025-11-10T00:12:40.354Z" }, - { url = "https://files.pythonhosted.org/packages/58/cc/d06685dae97468ed22999440f2f2f5060940ab0e7952a7295f236d98cce7/coverage-7.11.3-cp314-cp314-win32.whl", hash = "sha256:c713c1c528284d636cd37723b0b4c35c11190da6f932794e145fc40f8210a14a", size = 219508, upload-time = "2025-11-10T00:12:42.231Z" }, - { url = "https://files.pythonhosted.org/packages/5f/ed/770cd07706a3598c545f62d75adf2e5bd3791bffccdcf708ec383ad42559/coverage-7.11.3-cp314-cp314-win_amd64.whl", hash = "sha256:c381a252317f63ca0179d2c7918e83b99a4ff3101e1b24849b999a00f9cd4f86", size = 220325, upload-time = "2025-11-10T00:12:44.065Z" }, - { url = "https://files.pythonhosted.org/packages/ee/ac/6a1c507899b6fb1b9a56069954365f655956bcc648e150ce64c2b0ecbed8/coverage-7.11.3-cp314-cp314-win_arm64.whl", hash = "sha256:3e33a968672be1394eded257ec10d4acbb9af2ae263ba05a99ff901bb863557e", size = 218899, upload-time = "2025-11-10T00:12:46.18Z" }, - { url = "https://files.pythonhosted.org/packages/9a/58/142cd838d960cd740654d094f7b0300d7b81534bb7304437d2439fb685fb/coverage-7.11.3-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:f9c96a29c6d65bd36a91f5634fef800212dff69dacdb44345c4c9783943ab0df", size = 217471, upload-time = "2025-11-10T00:12:48.392Z" }, - { url = "https://files.pythonhosted.org/packages/bc/2c/2f44d39eb33e41ab3aba80571daad32e0f67076afcf27cb443f9e5b5a3ee/coverage-7.11.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:2ec27a7a991d229213c8070d31e3ecf44d005d96a9edc30c78eaeafaa421c001", size = 217742, upload-time = "2025-11-10T00:12:50.182Z" }, - { url = "https://files.pythonhosted.org/packages/32/76/8ebc66c3c699f4de3174a43424c34c086323cd93c4930ab0f835731c443a/coverage-7.11.3-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:72c8b494bd20ae1c58528b97c4a67d5cfeafcb3845c73542875ecd43924296de", size = 259120, upload-time = "2025-11-10T00:12:52.451Z" }, - { url = "https://files.pythonhosted.org/packages/19/89/78a3302b9595f331b86e4f12dfbd9252c8e93d97b8631500888f9a3a2af7/coverage-7.11.3-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:60ca149a446da255d56c2a7a813b51a80d9497a62250532598d249b3cdb1a926", size = 261229, upload-time = "2025-11-10T00:12:54.667Z" }, - { url = "https://files.pythonhosted.org/packages/07/59/1a9c0844dadef2a6efac07316d9781e6c5a3f3ea7e5e701411e99d619bfd/coverage-7.11.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:eb5069074db19a534de3859c43eec78e962d6d119f637c41c8e028c5ab3f59dd", size = 263642, upload-time = "2025-11-10T00:12:56.841Z" }, - { url = "https://files.pythonhosted.org/packages/37/86/66c15d190a8e82eee777793cabde730640f555db3c020a179625a2ad5320/coverage-7.11.3-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ac5d5329c9c942bbe6295f4251b135d860ed9f86acd912d418dce186de7c19ac", size = 258193, upload-time = "2025-11-10T00:12:58.687Z" }, - { url = "https://files.pythonhosted.org/packages/c7/c7/4a4aeb25cb6f83c3ec4763e5f7cc78da1c6d4ef9e22128562204b7f39390/coverage-7.11.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:e22539b676fafba17f0a90ac725f029a309eb6e483f364c86dcadee060429d46", size = 261107, upload-time = "2025-11-10T00:13:00.502Z" }, - { url = "https://files.pythonhosted.org/packages/ed/91/b986b5035f23cf0272446298967ecdd2c3c0105ee31f66f7e6b6948fd7f8/coverage-7.11.3-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:2376e8a9c889016f25472c452389e98bc6e54a19570b107e27cde9d47f387b64", size = 258717, upload-time = "2025-11-10T00:13:02.747Z" }, - { url = "https://files.pythonhosted.org/packages/f0/c7/6c084997f5a04d050c513545d3344bfa17bd3b67f143f388b5757d762b0b/coverage-7.11.3-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:4234914b8c67238a3c4af2bba648dc716aa029ca44d01f3d51536d44ac16854f", size = 257541, upload-time = "2025-11-10T00:13:04.689Z" }, - { url = "https://files.pythonhosted.org/packages/3b/c5/38e642917e406930cb67941210a366ccffa767365c8f8d9ec0f465a8b218/coverage-7.11.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:f0b4101e2b3c6c352ff1f70b3a6fcc7c17c1ab1a91ccb7a33013cb0782af9820", size = 259872, upload-time = "2025-11-10T00:13:06.559Z" }, - { url = "https://files.pythonhosted.org/packages/b7/67/5e812979d20c167f81dbf9374048e0193ebe64c59a3d93d7d947b07865fa/coverage-7.11.3-cp314-cp314t-win32.whl", hash = "sha256:305716afb19133762e8cf62745c46c4853ad6f9eeba54a593e373289e24ea237", size = 220289, upload-time = "2025-11-10T00:13:08.635Z" }, - { url = "https://files.pythonhosted.org/packages/24/3a/b72573802672b680703e0df071faadfab7dcd4d659aaaffc4626bc8bbde8/coverage-7.11.3-cp314-cp314t-win_amd64.whl", hash = "sha256:9245bd392572b9f799261c4c9e7216bafc9405537d0f4ce3ad93afe081a12dc9", size = 221398, upload-time = "2025-11-10T00:13:10.734Z" }, - { url = "https://files.pythonhosted.org/packages/f8/4e/649628f28d38bad81e4e8eb3f78759d20ac173e3c456ac629123815feb40/coverage-7.11.3-cp314-cp314t-win_arm64.whl", hash = "sha256:9a1d577c20b4334e5e814c3d5fe07fa4a8c3ae42a601945e8d7940bab811d0bd", size = 219435, upload-time = "2025-11-10T00:13:12.712Z" }, - { url = "https://files.pythonhosted.org/packages/19/8f/92bdd27b067204b99f396a1414d6342122f3e2663459baf787108a6b8b84/coverage-7.11.3-py3-none-any.whl", hash = "sha256:351511ae28e2509c8d8cae5311577ea7dd511ab8e746ffc8814a0896c3d33fbe", size = 208478, upload-time = "2025-11-10T00:13:14.908Z" }, -] - [[package]] name = "fastapi" version = "0.121.2" @@ -252,27 +187,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, ] -[[package]] -name = "markdown-it-py" -version = "4.0.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "mdurl" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" }, -] - -[[package]] -name = "mdurl" -version = "0.1.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, -] - [[package]] name = "mypy" version = "1.18.2" @@ -437,20 +351,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0b/8b/6300fb80f858cda1c51ffa17075df5d846757081d11ab4aa35cef9e6258b/pytest-9.0.1-py3-none-any.whl", hash = "sha256:67be0030d194df2dfa7b556f2e56fb3c3315bd5c8822c6951162b92b32ce7dad", size = 373668, upload-time = "2025-11-12T13:05:07.379Z" }, ] -[[package]] -name = "pytest-cov" -version = "7.0.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "coverage" }, - { name = "pluggy" }, - { name = "pytest" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/5e/f7/c933acc76f5208b3b00089573cf6a2bc26dc80a8aece8f52bb7d6b1855ca/pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1", size = 54328, upload-time = "2025-09-09T10:57:02.113Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" }, -] - [[package]] name = "requests" version = "2.32.5" @@ -466,19 +366,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, ] -[[package]] -name = "rich" -version = "14.2.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "markdown-it-py" }, - { name = "pygments" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/fb/d2/8920e102050a0de7bfabeb4c4614a49248cf8d5d7a8d01885fbb24dc767a/rich-14.2.0.tar.gz", hash = "sha256:73ff50c7c0c1c77c8243079283f4edb376f0f6442433aecb8ce7e6d0b92d1fe4", size = 219990, upload-time = "2025-10-09T14:16:53.064Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/25/7a/b0178788f8dc6cafce37a212c99565fa1fe7872c70c6c9c1e1a372d9d88f/rich-14.2.0-py3-none-any.whl", hash = "sha256:76bc51fe2e57d2b1be1f96c524b890b816e334ab4c1e45888799bfaab0021edd", size = 243393, upload-time = "2025-10-09T14:16:51.245Z" }, -] - [[package]] name = "ruff" version = "0.14.5" @@ -505,15 +392,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e5/80/69756670caedcf3b9be597a6e12276a6cf6197076eb62aad0c608f8efce0/ruff-0.14.5-py3-none-win_arm64.whl", hash = "sha256:4b700459d4649e2594b31f20a9de33bc7c19976d4746d8d0798ad959621d64a4", size = 13433331, upload-time = "2025-11-13T19:58:48.434Z" }, ] -[[package]] -name = "shellingham" -version = "1.5.4" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310, upload-time = "2023-10-24T04:13:40.426Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" }, -] - [[package]] name = "sniffio" version = "1.3.1" @@ -547,21 +425,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" }, ] -[[package]] -name = "typer" -version = "0.20.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "click" }, - { name = "rich" }, - { name = "shellingham" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/8f/28/7c85c8032b91dbe79725b6f17d2fffc595dff06a35c7a30a37bef73a1ab4/typer-0.20.0.tar.gz", hash = "sha256:1aaf6494031793e4876fb0bacfa6a912b551cf43c1e63c800df8b1a866720c37", size = 106492, upload-time = "2025-10-20T17:03:49.445Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/78/64/7713ffe4b5983314e9d436a90d5bd4f63b6054e2aca783a3cfc44cb95bbf/typer-0.20.0-py3-none-any.whl", hash = "sha256:5b463df6793ec1dca6213a3cf4c0f03bc6e322ac5e16e13ddd622a889489784a", size = 47028, upload-time = "2025-10-20T17:03:47.617Z" }, -] - [[package]] name = "types-requests" version = "2.32.4.20250913" diff --git a/tests/test_index_builder/test_index_builder.py b/tests/test_index_builder/test_index_builder.py index f6fa11f..447f416 100644 --- a/tests/test_index_builder/test_index_builder.py +++ b/tests/test_index_builder/test_index_builder.py @@ -77,6 +77,34 @@ } +def test_real_index_metadata(): + inverted_index = get_index() + metadata = inverted_index.metadata + + expected_num_docs = 10 + # counting title + body + expected_doc_lengths = { + 0: 6, + 1: 5, + 2: 6, + 3: 6, + 4: 5, + 5: 3, + 6: 4, + 7: 5, + 8: 4, + 9: 6, + } + expected_avg_length = sum(expected_doc_lengths.values()) / expected_num_docs + + assert metadata.num_docs == expected_num_docs + assert metadata.avg_doc_length == expected_avg_length + assert metadata.doc_lengths == expected_doc_lengths + + for doc_id, length in expected_doc_lengths.items(): + assert metadata.get_doc_length(doc_id) == length + + @pytest.mark.parametrize("term", ["alpha", "beta", "gamma", "delta"]) def test_real_index_postings(term): inverted_index = get_index() From 257f527bf01c2de1a463f4bebba0942714ea493c Mon Sep 17 00:00:00 2001 From: JanSkn Date: Sun, 21 Dec 2025 15:13:44 +0100 Subject: [PATCH 2/5] improve devops and add snippets --- .dockerignore | 12 +- .github/actions/setup/action.yml | 2 +- docker-compose.yml | 19 + justfile | 18 +- src/backend/.clang-format | 3 + src/backend/Dockerfile | 29 + src/backend/bindings/utils.cpp | 545 +++++++++++++----- .../index_builder/index_builder.cpp | 185 +++--- .../index_builder/merge_partial_indices.cpp | 181 +++--- src/backend/search_engine/models/index.py | 1 + .../search_engine/query/query_engine.py | 27 +- .../query/query_preprocessing.py | 7 +- src/frontend/Dockerfile | 11 + src/frontend/src/pages/Index.tsx | 18 +- src/frontend/vite.config.ts | 32 +- tests/Dockerfile | 20 +- tests/docker-compose.yml | 2 +- tests/entrypoint.sh | 6 +- 18 files changed, 727 insertions(+), 391 deletions(-) create mode 100644 docker-compose.yml create mode 100644 src/backend/.clang-format create mode 100644 src/backend/Dockerfile create mode 100644 src/frontend/Dockerfile diff --git a/.dockerignore b/.dockerignore index ac62239..6d4e5a3 100644 --- a/.dockerignore +++ b/.dockerignore @@ -50,16 +50,16 @@ dist/ .DS_Store Thumbs.db -# Node / frontend (falls frontend im Projekt ist) node_modules/ dist/ **/*.tsv **/*.gz # allow test data files as they are small and necessary for tests -!src/backend/search_engine/index_builder/test_data/*.tsv -!src/backend/search_engine/index_builder/test_data/*.gz +!search_engine/index_builder/test_data/*.tsv +!search_engine/index_builder/test_data/*.gz -/src/backend/search_engine/index_builder/build/ -/src/backend/search_engine/index_builder/data/ -/src/backend/search_engine/index/bin/ +search_engine/index_builder/build/ +search_engine/index_builder/data/ +search_engine/index/bin/ +search_engine/models/neuspell-scrnn-probwordnoise/ diff --git a/.github/actions/setup/action.yml b/.github/actions/setup/action.yml index 7270cfa..45c6359 100644 --- a/.github/actions/setup/action.yml +++ b/.github/actions/setup/action.yml @@ -16,4 +16,4 @@ runs: shell: bash run: | sudo apt-get update -y - sudo apt-get install -y zlib1g-dev libstemmer-dev build-essential git + sudo apt-get install -y zlib1g-dev libstemmer-dev build-essential git clang-format diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..37ee3c2 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,19 @@ +version: "3.9" +name: "Seekr. Search Engine" +services: + backend: + build: src/backend + container_name: seekr-backend + ports: + - "8000:8000" # expose backend for FE-less API access + volumes: # bind mount index data + - ./src/backend/index_builder:/app/src/backend/index_builder + frontend: + build: src/frontend + container_name: seekr-frontend + depends_on: + - backend + ports: + - "8080:8080" + environment: + ENV: "DOCKER" \ No newline at end of file diff --git a/justfile b/justfile index caf8abf..9839ec8 100644 --- a/justfile +++ b/justfile @@ -18,6 +18,9 @@ local *uvicorn-args: chmod +x local.sh && \ ./local.sh {{uvicorn-args}} +deploy: + docker compose up -d + build-index memory-limit="1024" max-docs="-1": cd src/backend/search_engine/scripts/ && \ chmod +x build-index.sh && \ @@ -37,12 +40,23 @@ generate-stubs: ./generate-stubs.sh lint: - @echo "Linting with Ruff..." + @echo "Linting Python code..." cd src/backend && uv run ruff check api/ search_engine/ ../../tests/ cd src/backend && uv run ruff format --check --diff api/ search_engine/ ../../tests/ + @echo "Linting C++ code..." # only format-check instead of linting to avoid dependency-related failures + clang-format --dry-run --Werror \ + src/backend/bindings/utils.cpp \ + src/backend/search_engine/index_builder/index_builder.cpp \ + src/backend/search_engine/index_builder/merge_partial_indices.cpp format: + @echo "Formatting Python code..." cd src/backend && uv run ruff format api/ search_engine/ ../../tests/ + @echo "Formatting C++ code..." + clang-format -i \ + src/backend/bindings/utils.cpp \ + src/backend/search_engine/index_builder/index_builder.cpp \ + src/backend/search_engine/index_builder/merge_partial_indices.cpp mypy: @echo "Type checking with MyPy..." @@ -50,4 +64,4 @@ mypy: cd src/backend && uv run mypy search_engine/ test: - just -f tests/justfile test \ No newline at end of file + just -f src/backend/tests/justfile test \ No newline at end of file diff --git a/src/backend/.clang-format b/src/backend/.clang-format new file mode 100644 index 0000000..4ca6bcb --- /dev/null +++ b/src/backend/.clang-format @@ -0,0 +1,3 @@ +BasedOnStyle: Google +IndentWidth: 4 +ColumnLimit: 100 diff --git a/src/backend/Dockerfile b/src/backend/Dockerfile new file mode 100644 index 0000000..dd7d502 --- /dev/null +++ b/src/backend/Dockerfile @@ -0,0 +1,29 @@ +FROM python:3.13-slim + +# system dependencies +# build-essential for packages using c extensions +# git for packages installed from git +# libstemmer for CMake build of index_builder +RUN apt-get update && apt-get install -y \ + build-essential \ + git \ + curl \ + cmake \ + libstemmer-dev \ + && rm -rf /var/lib/apt/lists/* + +ENV PATH="/root/.local/bin:$PATH" + +# set workdir to where pyproject.toml is located for uv +WORKDIR /app/src/backend + +# copy first to cache dependencies +COPY pyproject.toml . +COPY uv.lock . +COPY bindings/ ./bindings/ + +RUN uv sync + +COPY . . + +CMD ["uv", "run", "uvicorn", "api.v1.app:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file diff --git a/src/backend/bindings/utils.cpp b/src/backend/bindings/utils.cpp index a6618e8..3dd14ff 100644 --- a/src/backend/bindings/utils.cpp +++ b/src/backend/bindings/utils.cpp @@ -1,27 +1,26 @@ #include -#include // for automatic conversion of STL containers -#include -#include +#include // for automatic conversion of STL containers + +#include #include #include -#include -#include #include +#include +#include +#include +#include + #include "libstemmer.h" namespace py = pybind11; struct SnowballStemmer { struct sb_stemmer* stemmer; - SnowballStemmer() { - stemmer = sb_stemmer_new("english", nullptr); - } - ~SnowballStemmer() { - sb_stemmer_delete(stemmer); - } + SnowballStemmer() { stemmer = sb_stemmer_new("english", nullptr); } + ~SnowballStemmer() { sb_stemmer_delete(stemmer); } std::string stem(const std::string& word) { - const sb_symbol* stemmed = sb_stemmer_stem(stemmer, - reinterpret_cast(word.c_str()), word.size()); + const sb_symbol* stemmed = + sb_stemmer_stem(stemmer, reinterpret_cast(word.c_str()), word.size()); int out_len = sb_stemmer_length(stemmer); if (stemmed == nullptr || out_len <= 0) return std::string(); return std::string(reinterpret_cast(stemmed), static_cast(out_len)); @@ -33,14 +32,14 @@ const std::unordered_set KEEP_TOKENS = {"AND", "&", "OR", "|", "NOT std::vector normalize_search_query(const std::string& text) { std::vector tokens; - std::string token; // lowercase version for stemming - std::string token_original; // exact original casing + std::string token; // lowercase version for stemming + std::string token_original; // exact original casing auto flush_token = [&]() { if (token.empty()) return; if (KEEP_TOKENS.find(token_original) != KEEP_TOKENS.end()) { - tokens.push_back(token_original); // keep original casing for operators/parentheses + tokens.push_back(token_original); // keep original casing for operators/parentheses } else { tokens.push_back(stemmer.stem(token)); } @@ -52,7 +51,7 @@ std::vector normalize_search_query(const std::string& text) { for (char c : text) { if (std::isalnum(static_cast(c))) { token += std::tolower(static_cast(c)); - token_original += c; // keep original case + token_original += c; // keep original case continue; } @@ -67,7 +66,7 @@ std::vector normalize_search_query(const std::string& text) { std::string special(1, c); if (KEEP_TOKENS.find(special) != KEEP_TOKENS.end()) { - tokens.push_back(special); // operator/punctuation + tokens.push_back(special); // operator/punctuation } } @@ -113,11 +112,9 @@ struct PostingList { PostingList() = default; - PostingList( - const std::vector& p, - const std::unordered_map& tf, - const std::unordered_map>& pos - ) : postings(p), term_frequencies(tf), positions(pos) {} + PostingList(const std::vector& p, const std::unordered_map& tf, + const std::unordered_map>& pos) + : postings(p), term_frequencies(tf), positions(pos) {} void build_skip_pointers() { size_t skip_interval = static_cast(std::sqrt(postings.size())); @@ -129,119 +126,111 @@ struct PostingList { } }; -PostingList read_posting_list(std::ifstream& in, uint64_t offset, uint32_t docFreq) { +PostingList read_posting_list(std::ifstream& in, uint64_t offset, uint32_t doc_freq) { PostingList pl; - pl.doc_frequency = docFreq; + pl.doc_frequency = doc_freq; in.seekg(offset); - - pl.postings.resize(docFreq); - - for (uint32_t i = 0; i < docFreq; i++) { + + pl.postings.resize(doc_freq); + + for (uint32_t i = 0; i < doc_freq; i++) { uint32_t doc_id, pos_count; in.read(reinterpret_cast(&doc_id), sizeof(doc_id)); in.read(reinterpret_cast(&pos_count), sizeof(pos_count)); - + pl.postings[i] = doc_id; pl.term_frequencies[doc_id] = pos_count; - + std::vector positions(pos_count); in.read(reinterpret_cast(positions.data()), pos_count * sizeof(uint32_t)); pl.positions[doc_id] = std::move(positions); } - + pl.build_skip_pointers(); - + return pl; } struct DocInfo { std::string url; std::string title; + std::string snippet; DocInfo() = default; - DocInfo(const std::string& u, const std::string& t) - : url(u), title(t) {} + DocInfo(const std::string& u, const std::string& t, const std::string& s) + : url(u), title(t), snippet(s) {} }; +class InvertedIndex; // forward declaration + class DocStore { -private: - std::unordered_map offsets; + private: + InvertedIndex* parent; + struct DocOffset { + uint64_t docstore_offset; // docstore data offset + uint64_t tsv_offset; // offset into the msmarco tsv for body retrieval + }; + std::unordered_map offsets; std::ifstream data_in; - uint32_t total_docs; + std::ifstream tsv_in; + uint32_t total_docs = 0; -public: - void open(const std::string& dir_name) { - data_in.open(dir_name + "/docstore.bin", std::ios::binary); - std::ifstream off(dir_name + "/docstore_offsets.bin", std::ios::binary); + struct Hit { + uint32_t pos; + std::string term; + }; - if (!data_in || !off) - throw std::runtime_error("Could not open docstore"); + struct SubsnippetResult { + uint32_t start; + uint32_t end; + std::vector remaining_hits; + }; - // docCount at the beginning - data_in.read(reinterpret_cast(&total_docs), sizeof(total_docs)); + SubsnippetResult find_subsnippet(const std::vector& hits, int max_window_size, + size_t required_term_count); - while (true) { - uint32_t id; - uint64_t off64; + public: + std::vector query_terms; - if (!off.read(reinterpret_cast(&id), sizeof(id))) break; - if (!off.read(reinterpret_cast(&off64), sizeof(off64))) break; + DocStore(InvertedIndex* p) : parent(p) {} - offsets[id] = off64; - } - } - - std::optional get(uint32_t doc_id) { + void open(const std::string& dir_name); + std::string load_snippet(uint32_t doc_id, + std::vector>& snippet_windows, + uint64_t tsv_offset); + std::string get_snippet(uint32_t doc_id, uint64_t tsv_offset); + std::optional get(uint32_t doc_id); + std::optional get_tsv_offset(uint32_t doc_id) { auto it = offsets.find(doc_id); if (it == offsets.end()) return std::nullopt; - - uint64_t offset = it->second; - data_in.seekg(offset); - - uint32_t url_len; - data_in.read(reinterpret_cast(&url_len), sizeof(url_len)); - - std::string url(url_len, '\0'); - data_in.read(url.data(), url_len); - - uint32_t title_len; - data_in.read(reinterpret_cast(&title_len), sizeof(title_len)); - - std::string title(title_len, '\0'); - data_in.read(title.data(), title_len); - - return DocInfo{url, title}; + return it->second.tsv_offset; } - uint32_t size() const { return total_docs; } }; -class InvertedIndex; // forward - class IndexAccessor { -private: + private: InvertedIndex* parent; -public: + + public: IndexAccessor(InvertedIndex* p) : parent(p) {} std::optional get(const std::string& term); }; class InvertedIndex { -private: + private: std::unordered_map term_to_offset; std::unordered_map term_to_docfreq; std::ifstream postings_file; -public: + public: Metadata metadata; DocStore doc_store; IndexAccessor index; - InvertedIndex(const std::string& base_path) - : index(this) - { + InvertedIndex(const std::string& base_path) : doc_store(this), index(this) { std::ifstream index_file(base_path + "/index.bin", std::ios::binary); while (true) { uint32_t term_len; @@ -253,11 +242,11 @@ class InvertedIndex { uint64_t offset; if (!index_file.read(reinterpret_cast(&offset), sizeof(offset))) break; - uint32_t docFreq; - index_file.read(reinterpret_cast(&docFreq), sizeof(docFreq)); - + uint32_t doc_freq; + index_file.read(reinterpret_cast(&doc_freq), sizeof(doc_freq)); + term_to_offset[term] = offset; - term_to_docfreq[term] = docFreq; + term_to_docfreq[term] = doc_freq; } postings_file.open(base_path + "/postinglists.bin", std::ios::binary); @@ -267,22 +256,292 @@ class InvertedIndex { doc_store.open(base_path); } + friend class DocStore; friend class IndexAccessor; }; +// --- Docstore --- +void DocStore::open(const std::string& dir_name) { + data_in.open(dir_name + "/docstore.bin", std::ios::binary); + tsv_in.open(dir_name + "../../index_builder/data/msmarco-docs.tsv", std::ios::binary); + std::ifstream off(dir_name + "/docstore_offsets.bin", std::ios::binary); + + if (!data_in || !tsv_in || !off) throw std::runtime_error("Could not open docstore"); + + // docCount at the beginning + data_in.read(reinterpret_cast(&total_docs), sizeof(total_docs)); + + while (true) { + uint32_t id; + uint64_t off64; + uint64_t tsvOff; + + if (!off.read(reinterpret_cast(&id), sizeof(id))) break; + if (!off.read(reinterpret_cast(&off64), sizeof(off64))) break; + if (!off.read(reinterpret_cast(&tsvOff), sizeof(tsvOff))) break; + + offsets[id] = {off64, tsvOff}; + } +} + +std::string DocStore::load_snippet(uint32_t doc_id, + std::vector>& snippet_windows, + uint64_t tsv_offset) { + if (snippet_windows.empty()) return ""; + + std::sort(snippet_windows.begin(), snippet_windows.end()); + tsv_in.clear(); + tsv_in.seekg(tsv_offset); + std::string line; + if (!std::getline(tsv_in, line)) { + return ""; + } + + // parse line: [DocID] \t [URL] \t [Title] \t [Content] + // we need to find the 3rd tab to get to content + size_t pos = 0; + int tab_count = 0; + while (tab_count < 3) { + pos = line.find('\t', pos); + if (pos == std::string::npos) return ""; // invalid format + pos++; // skip the tab + tab_count++; + } + + size_t content_start = pos; + size_t len = line.size(); + std::string snippet; + snippet.reserve(200); + uint32_t current_word_pos = 0; + size_t i = content_start; + size_t window_idx = 0; + + if (snippet_windows[0].first > 0) { + snippet += "... "; + } + + // helper lambda to check if character is sentence-ending punctuation + auto is_sentence_end = [](char c) { return c == '.' || c == '!' || c == '?'; }; + + // Calculate threshold for last window (last 10%) + uint32_t last_window_start = snippet_windows.back().first; + uint32_t last_window_end = snippet_windows.back().second; + uint32_t last_window_size = last_window_end - last_window_start + 1; + uint32_t last_window_threshold = last_window_end - (last_window_size / 10); + + bool stopped_at_sentence_end = false; + + while (i < len && window_idx < snippet_windows.size()) { + // --- determine word --- + size_t word_start = i; + while (word_start < len && !std::isalpha(static_cast(line[word_start]))) { + word_start++; + } + std::string separator = line.substr(i, word_start - i); + if (word_start >= len) { + // No more words + break; + } + size_t word_end = word_start; + while (word_end < len && std::isalpha(static_cast(line[word_end]))) { + word_end++; + } + std::string word = line.substr(word_start, word_end - word_start); + // --------------------- + + // check if current word is in relevant window + // skip windows that are already passed + while (window_idx < snippet_windows.size() && + current_word_pos > snippet_windows[window_idx].second) { + window_idx++; + if (window_idx < snippet_windows.size()) { + snippet += " ... "; + } + } + + if (window_idx < snippet_windows.size()) { + uint32_t w_start = snippet_windows[window_idx].first; + uint32_t w_end = snippet_windows[window_idx].second; + + if (current_word_pos >= w_start && current_word_pos <= w_end) { + if (current_word_pos == w_start) { + snippet += word; + } else { + snippet += separator + word; + } + + // check if we're in the last window and in its last 10% + bool is_last_window = (window_idx == snippet_windows.size() - 1); + if (is_last_window && current_word_pos >= last_window_threshold && + current_word_pos < w_end) { + // look for sentence-ending punctuation after this word + size_t check_pos = word_end; + while (check_pos < len && check_pos < word_end + 3) { + if (is_sentence_end(line[check_pos])) { + // add it and stop early + snippet += line[check_pos]; + stopped_at_sentence_end = true; + break; + } + check_pos++; + } + if (stopped_at_sentence_end) { + break; + } + } + } + } + + // advance + i = word_end; + current_word_pos++; + } + + // check if there is more text after the snippets (only if we didn't stop at sentence end) + if (!stopped_at_sentence_end && window_idx >= snippet_windows.size()) { + size_t check = i; + while (check < len && !std::isalpha(static_cast(line[check]))) check++; + if (check < len) { + snippet += " ..."; + } + } + + return snippet; +} + +DocStore::SubsnippetResult DocStore::find_subsnippet(const std::vector& hits, + int max_window_size, + size_t required_term_count) { + SubsnippetResult result{}; + result.start = 0; + result.end = 0; + + if (hits.empty()) return result; + + std::unordered_map window_term_count; + + uint32_t left = 0; + uint32_t best_start = hits[0].pos; + uint32_t best_end = hits[0].pos; + uint32_t best_score = 0; + + // mark the best window indices + uint32_t best_left_idx = 0; + uint32_t best_right_idx = 0; + + for (uint32_t right = 0; right < hits.size(); ++right) { + window_term_count[hits[right].term]++; + + // shrink window if too large + while (hits[right].pos - hits[left].pos > max_window_size) { + auto& c = window_term_count[hits[left].term]; + if (--c == 0) window_term_count.erase(hits[left].term); + left++; + } + + // score is number of unique terms in this window + uint32_t score = window_term_count.size(); + + // update score or choose smaller snippet -> terms more together + if (score > best_score || + (score == best_score && (hits[right].pos - hits[left].pos) < (best_end - best_start))) { + best_score = score; + best_start = hits[left].pos; + best_end = hits[right].pos; + best_left_idx = left; + best_right_idx = right; + + if (best_score == required_term_count) break; // perfect snippet found + } + } + + result.start = best_start; + result.end = best_end; + + // collect remaining hits outside the best window + result.remaining_hits.reserve(hits.size()); + for (uint32_t i = 0; i < hits.size(); ++i) { + if (i < best_left_idx || i > best_right_idx) result.remaining_hits.push_back(hits[i]); + } + + return result; +} + +// total snippet length: max. MAX_WINDOW_SIZE x 2 + 1 or 2x "..." +std::string DocStore::get_snippet(uint32_t doc_id, uint64_t tsv_offset) { + uint32_t MAX_WINDOW_SIZE = 15; // num of words PER subsnippet + std::set unique_terms(query_terms.begin(), query_terms.end()); + + // e.g. + // hits = [ {pos: 3, term: "foo"}, {pos: 10, term: "bar"}, {pos: 15, term: "foo"}, {pos: 18, + // term: "bar"} ] + std::vector hits; + for (const auto& term : unique_terms) { + auto termIt = parent->term_to_offset.find(term); + if (termIt == parent->term_to_offset.end()) continue; + auto docIt = parent->term_to_docfreq.find(term); + PostingList pl = read_posting_list(parent->postings_file, termIt->second, docIt->second); + + auto posIt = pl.positions.find(doc_id); + if (posIt == pl.positions.end()) continue; + + for (uint32_t pos : posIt->second) hits.push_back(Hit{pos, term}); + } + std::sort(hits.begin(), hits.end(), [](const Hit& a, const Hit& b) { return a.pos < b.pos; }); + + // create first optimal snippet + SubsnippetResult first_snippet = find_subsnippet(hits, MAX_WINDOW_SIZE, unique_terms.size()); + + // can be one if all terms fit into MAX_WINDOW_SIZE, or at most 2 for remaining terms + // not more than 2 for readability + std::vector> snippet_windows; + snippet_windows.push_back({first_snippet.start, first_snippet.end}); + + if (!first_snippet.remaining_hits.empty()) { + SubsnippetResult second_snippet = + find_subsnippet(first_snippet.remaining_hits, MAX_WINDOW_SIZE, unique_terms.size()); + snippet_windows.push_back({second_snippet.start, second_snippet.end}); + } + + return load_snippet(doc_id, snippet_windows, tsv_offset); +} + +std::optional DocStore::get(uint32_t doc_id) { + auto it = offsets.find(doc_id); + if (it == offsets.end()) return std::nullopt; + + uint64_t docstore_offset = it->second.docstore_offset; + data_in.seekg(docstore_offset); + uint64_t tsv_offset = it->second.tsv_offset; + + uint32_t url_len; + data_in.read(reinterpret_cast(&url_len), sizeof(url_len)); + + std::string url(url_len, '\0'); + data_in.read(url.data(), url_len); + + uint32_t title_len; + data_in.read(reinterpret_cast(&title_len), sizeof(title_len)); + + std::string title(title_len, '\0'); + data_in.read(title.data(), title_len); + + std::string snippet = get_snippet(doc_id, tsv_offset); + + return DocInfo{url, title, snippet}; +} +// -------------------- + std::optional IndexAccessor::get(const std::string& term) { auto it = parent->term_to_offset.find(term); if (it == parent->term_to_offset.end()) return std::nullopt; - uint32_t docFreq = parent->term_to_docfreq.at(term); - PostingList pl = read_posting_list(parent->postings_file, it->second, docFreq); + uint32_t doc_freq = parent->term_to_docfreq.at(term); + PostingList pl = read_posting_list(parent->postings_file, it->second, doc_freq); return pl; } -PostingList positional_intersect( - const PostingList& pl1, - const PostingList& pl2, - uint32_t distance -) { +PostingList positional_intersect(const PostingList& pl1, const PostingList& pl2, + uint32_t distance) { PostingList result; const auto& p1 = pl1.postings; @@ -296,7 +555,7 @@ PostingList positional_intersect( const size_t n1 = p1.size(); const size_t n2 = p2.size(); - result.postings.reserve(std::min(n1, n2)); // conservative + result.postings.reserve(std::min(n1, n2)); // conservative while (i < n1 && j < n2) { uint32_t doc1 = p1[i]; @@ -335,7 +594,6 @@ PostingList positional_intersect( if (!valid_positions.empty()) { result.postings.push_back(doc_id); - result.term_frequencies[doc_id] = valid_positions.size(); result.positions[doc_id] = std::move(valid_positions); } } @@ -354,7 +612,7 @@ PostingList positional_intersect( } } - else { // doc2 < doc1 + else { // doc2 < doc1 // skip pointer support for pl2 auto it_s2 = skip2.find(j); if (it_s2 != skip2.end() && it_s2->second < n2 && p2[it_s2->second] <= doc1) { @@ -370,12 +628,7 @@ PostingList positional_intersect( return result; } -// faster if left posting list is smaller -PostingList find_docs( - const PostingList& pl1, - const PostingList& pl2, - const std::string& mode -) { +PostingList find_docs(const PostingList& pl1, const PostingList& pl2, const std::string& mode) { const auto& p1 = pl1.postings; const auto& p2 = pl2.postings; @@ -389,33 +642,27 @@ PostingList find_docs( const size_t n1 = p1.size(); const size_t n2 = p2.size(); - std::vector result_postings; - result_postings.reserve(std::min(n1, n2)); // most likely - - std::unordered_map result_tf; - if (mode == "AND") { + std::vector intersected; + intersected.reserve(std::min(n1, n2)); // most likely while (i < n1 && j < n2) { uint32_t d1 = p1[i]; uint32_t d2 = p2[j]; if (d1 == d2) { - result_postings.push_back(d1); - result_tf[d1] = tf1.at(d1) + tf2.at(d2); + intersected.push_back(d1); i++; j++; - } - else if (d1 < d2) { + } else if (d1 < d2) { auto it = skip1.find(i); if (it != skip1.end() && p1[it->second] <= d2) { i = it->second; } else { i++; } - } - else { // d2 < d1 + } else { // d2 < d1 auto it = skip2.find(j); if (it != skip2.end() && p2[it->second] <= d1) { j = it->second; @@ -425,7 +672,7 @@ PostingList find_docs( } } - PostingList out(result_postings, result_tf, {}); + PostingList out(intersected, {}, {}); out.build_skip_pointers(); return out; } @@ -442,17 +689,13 @@ PostingList find_docs( if (x == y) { merged.push_back(x); - result_tf[x] = tf1.at(x) + tf2.at(y); - a++; b++; - } - else if (x < y) { + a++; + b++; + } else if (x < y) { merged.push_back(x); - result_tf[x] = tf1.at(x); a++; - } - else { + } else { merged.push_back(y); - result_tf[y] = tf2.at(y); b++; } } @@ -461,16 +704,14 @@ PostingList find_docs( while (a < n1) { uint32_t x = p1[a++]; merged.push_back(x); - result_tf[x] = tf1.at(x); } while (b < n2) { uint32_t y = p2[b++]; merged.push_back(y); - result_tf[y] = tf2.at(y); } - PostingList out(merged, result_tf, {}); + PostingList out(merged, {}, {}); out.build_skip_pointers(); return out; } @@ -489,12 +730,11 @@ PostingList find_docs( if (b == n2 || p2[b] != x) { diff.push_back(x); - result_tf[x] = tf1.at(x); } a++; } - PostingList out(diff, result_tf, {}); + PostingList out(diff, {}, {}); out.build_skip_pointers(); return out; } @@ -505,41 +745,30 @@ PostingList find_docs( PYBIND11_MODULE(_core, m) { m.doc() = "CPP utils for search engine"; - m.def("normalize_search_query", &normalize_search_query, - py::arg("text"), - "Normalize and stem search query into tokens, but keep logical operators and parentheses as is"); + m.def("normalize_search_query", &normalize_search_query, py::arg("text"), + "Normalize and stem search query into tokens, but keep logical operators and parentheses " + "as is"); - m.def("positional_intersect", &positional_intersect, - py::arg("pl1"), py::arg("pl2"), py::arg("distance") = 1, - "Positional intersection of two posting lists with given distance"); + m.def("positional_intersect", &positional_intersect, py::arg("pl1"), py::arg("pl2"), + py::arg("distance") = 1, + "Positional intersection of two posting lists with given distance"); - m.def( - "find_docs", - &find_docs, - py::arg("pl1"), - py::arg("pl2"), - py::arg("mode"), - "Find documents that are in both posting lists" - ); + m.def("find_docs", &find_docs, py::arg("pl1"), py::arg("pl2"), py::arg("mode"), + "Find documents that are in both posting lists"); py::class_(m, "DocInfo") .def(py::init<>()) - .def(py::init(), - py::arg("url"), py::arg("title")) + .def(py::init(), py::arg("url"), + py::arg("title"), py::arg("snippet")) .def_readonly("url", &DocInfo::url) - .def_readonly("title", &DocInfo::title); + .def_readonly("title", &DocInfo::title) + .def_readonly("snippet", &DocInfo::snippet); py::class_(m, "PostingList") .def(py::init<>()) - .def(py::init< - const std::vector&, - const std::unordered_map&, - const std::unordered_map>& - >(), - py::arg("postings"), - py::arg("term_frequencies"), - py::arg("positions") - ) + .def(py::init&, const std::unordered_map&, + const std::unordered_map>&>(), + py::arg("postings"), py::arg("term_frequencies"), py::arg("positions")) .def_readonly("postings", &PostingList::postings) .def_readonly("term_frequencies", &PostingList::term_frequencies) .def_readonly("positions", &PostingList::positions) @@ -553,14 +782,14 @@ PYBIND11_MODULE(_core, m) { .def("get_doc_length", &Metadata::get_doc_length, py::arg("doc_id")); py::class_(m, "DocStore") - .def("get", &DocStore::get, py::arg("doc_id")); + .def("get", &DocStore::get, py::arg("doc_id")) + .def("get_tsv_offset", &DocStore::get_tsv_offset, py::arg("doc_id")); - py::class_(m, "IndexAccessor") - .def("get", &IndexAccessor::get, py::arg("term")); + py::class_(m, "IndexAccessor").def("get", &IndexAccessor::get, py::arg("term")); py::class_(m, "InvertedIndex") .def(py::init()) .def_readonly("index", &InvertedIndex::index) .def_readonly("metadata", &InvertedIndex::metadata) .def_readonly("doc_store", &InvertedIndex::doc_store); -} \ No newline at end of file +} diff --git a/src/backend/search_engine/index_builder/index_builder.cpp b/src/backend/search_engine/index_builder/index_builder.cpp index f15f567..d464078 100644 --- a/src/backend/search_engine/index_builder/index_builder.cpp +++ b/src/backend/search_engine/index_builder/index_builder.cpp @@ -1,45 +1,51 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include #include + +#include #include -#include +#include +#include +#include +#include +#include +#include +#include +#include + #include "include/robin_hood.h" -// not encoded as neglectably small class DocStoreWriter { -private: - std::ofstream outStream; + private: + std::ofstream outStream; std::ofstream offsetStream; - uint64_t currentByteOffset; // offset where next doc will be written/read + uint64_t currentByteOffset; // offset where next doc will be written/read uint32_t docCount; -public: + public: void init(const std::string& filename_base) { - outStream.open(filename_base + "/docstore.bin", std::ios::binary | std::ios::out | std::ios::trunc); - offsetStream.open(filename_base + "/docstore_offsets.bin", std::ios::binary | std::ios::out | std::ios::trunc); - - currentByteOffset = 0; + outStream.open(filename_base + "/docstore.bin", + std::ios::binary | std::ios::out | std::ios::trunc); + offsetStream.open(filename_base + "/docstore_offsets.bin", + std::ios::binary | std::ios::out | std::ios::trunc); + + currentByteOffset = 0; docCount = 0; - + outStream.write(reinterpret_cast(&docCount), sizeof(docCount)); currentByteOffset += sizeof(docCount); } - void addDocument(uint32_t docId, const std::string& url, const std::string& title) { + void addDocument(uint32_t docId, const std::string& url, const std::string& title, + uint64_t tsvOffset) { /* offsetStream: [0-3] docId = 42 - [4-11] offset = 0 (start of this doc in outStream) + [4-11] docStoreOffset = 0 (start of this doc in outStream) + [12-19] tsvOffset = ... (start of the line of this doc in original tsv) - [12-15] docId = 105 - [16-23] offset = 18 (start of this doc in outStream) + [20-23] docId = 105 + [24-31] docStoreOffset = 18 (start of this doc in outStream) + [32-39] tsvOffset = ... ... outStream: @@ -54,8 +60,10 @@ class DocStoreWriter { [35-36] 'H' 'i' */ offsetStream.write(reinterpret_cast(&docId), sizeof(docId)); - offsetStream.write(reinterpret_cast(¤tByteOffset), sizeof(currentByteOffset)); - + offsetStream.write(reinterpret_cast(¤tByteOffset), + sizeof(currentByteOffset)); + offsetStream.write(reinterpret_cast(&tsvOffset), sizeof(tsvOffset)); + uint32_t urlLen = url.size(); outStream.write(reinterpret_cast(&urlLen), sizeof(urlLen)); outStream.write(url.data(), urlLen); @@ -66,7 +74,7 @@ class DocStoreWriter { // faster than tellp() currentByteOffset += sizeof(uint32_t) + urlLen + sizeof(uint32_t) + titleLen; - + docCount++; } @@ -87,17 +95,16 @@ struct Posting { }; static const robin_hood::unordered_flat_set STOP_WORDS = { - "a", "an", "and", "are", "as", "at", "be", "but", "by", "for", "if", "in", - "into", "is", "it", "no", "not", "of", "on", "or", "such", "that", "the", - "their", "then", "there", "these", "they", "this", "to", "was", "will", "with" -}; + "a", "an", "and", "are", "as", "at", "be", "but", "by", "for", "if", + "in", "into", "is", "it", "no", "not", "of", "on", "or", "such", "that", + "the", "their", "then", "there", "these", "they", "this", "to", "was", "will", "with"}; class Tokenizer { -private: + private: struct sb_stemmer* stemmer; std::string tokenBuffer; - -public: + + public: Tokenizer() { stemmer = sb_stemmer_new("english", "UTF_8"); if (!stemmer) { @@ -105,64 +112,58 @@ class Tokenizer { } tokenBuffer.reserve(64); } - + ~Tokenizer() { if (stemmer) sb_stemmer_delete(stemmer); } - + // non-copyable Tokenizer(const Tokenizer&) = delete; Tokenizer& operator=(const Tokenizer&) = delete; - - template + + template void tokenize(const char* text, size_t len, Callback&& callback) { int position = 0; size_t i = 0; - + while (i < len) { while (i < len && !std::isalpha(static_cast(text[i]))) { i++; } if (i >= len) break; - + tokenBuffer.clear(); while (i < len && std::isalpha(static_cast(text[i]))) { tokenBuffer.push_back(std::tolower(static_cast(text[i]))); i++; } - + if (tokenBuffer.empty()) continue; - + if (STOP_WORDS.count(tokenBuffer)) { position++; continue; } - - const sb_symbol* stemmed = sb_stemmer_stem( - stemmer, - reinterpret_cast(tokenBuffer.data()), - tokenBuffer.size() - ); + + const sb_symbol* stemmed = + sb_stemmer_stem(stemmer, reinterpret_cast(tokenBuffer.data()), + tokenBuffer.size()); int stemLen = sb_stemmer_length(stemmer); - + std::string term(reinterpret_cast(stemmed), stemLen); - + callback(std::move(term), position); position++; } } }; -void spillToDisk( - robin_hood::unordered_flat_map>& termPostings, - const robin_hood::unordered_flat_map& termDictionary, - const std::string& postingsFile, - const std::string& dictFile) -{ +void spillToDisk(robin_hood::unordered_flat_map>& termPostings, + const robin_hood::unordered_flat_map& termDictionary, + const std::string& postingsFile, const std::string& dictFile) { std::vector> sortedTerms; sortedTerms.reserve(termDictionary.size()); - for (const auto& kv : termDictionary) - sortedTerms.emplace_back(kv.first, kv.second); + for (const auto& kv : termDictionary) sortedTerms.emplace_back(kv.first, kv.second); // sort for more efficient merging of the spilled files later std::sort(sortedTerms.begin(), sortedTerms.end(), @@ -170,10 +171,9 @@ void spillToDisk( std::ofstream postOut(postingsFile, std::ios::binary); std::ofstream dictOut(dictFile, std::ios::binary); - if (!postOut || !dictOut) - throw std::runtime_error("Failed to open output files"); + if (!postOut || !dictOut) throw std::runtime_error("Failed to open output files"); - static char postBuffer[8 * 1024 * 1024]; // 8MB + static char postBuffer[8 * 1024 * 1024]; // 8MB static char dictBuffer[8 * 1024 * 1024]; postOut.rdbuf()->pubsetbuf(postBuffer, sizeof(postBuffer)); dictOut.rdbuf()->pubsetbuf(dictBuffer, sizeof(dictBuffer)); @@ -182,8 +182,7 @@ void spillToDisk( for (const auto& [term, termId] : sortedTerms) { auto it = termPostings.find(termId); - if (it == termPostings.end()) - continue; + if (it == termPostings.end()) continue; // sort for search and union of posting lists std::vector& postings = it->second; @@ -191,7 +190,7 @@ void spillToDisk( [](const Posting& a, const Posting& b) { return a.docId < b.docId; }); uint64_t startOffset = offset; - + uint32_t docFreq = postings.size(); for (const auto& posting : postings) { @@ -254,32 +253,37 @@ int main(int argc, char* argv[]) { std::filesystem::path projectRoot = exePath.parent_path().parent_path(); std::string dataDir = "/data"; - - const char* test_env = std::getenv("ENV"); // for integration tests, test with controlled and small dataset in test_data + + const char* test_env = std::getenv( + "ENV"); // for integration tests, test with controlled and small dataset in test_data if (test_env && std::string(test_env) == "TEST_ENV") { std::cout << "TEST ENVIRONMENT, building index with test data." << std::endl; dataDir = "/test_data"; - } + } std::string projectDir = projectRoot.string(); std::string partialIndexPostingsDir = projectDir + dataDir + "/partial_indices/postings"; std::string partialIndexDictDir = projectDir + dataDir + "/partial_indices/dictionaries"; - std::string metadataDir = projectDir + dataDir + "/index"; + std::string outputDir = + (projectRoot.parent_path() / "index" / "bin") + .string(); // put in parallel directory index/ where python code expects it + std::string metadataDir = outputDir + "/metadata.bin"; + std::string docstoreBase = outputDir; std::filesystem::create_directories(partialIndexPostingsDir); std::filesystem::create_directories(partialIndexDictDir); std::filesystem::create_directories(metadataDir); + std::filesystem::create_directories(outputDir); Tokenizer tokenizer; std::ifstream infile(projectDir + dataDir + "/msmarco-docs.tsv"); - + if (!infile.is_open()) { std::cerr << "Failed to open input file\n"; return 1; } DocStoreWriter docStore; - std::string docstoreBase = projectDir + dataDir + "/docstore"; std::filesystem::create_directories(docstoreBase); docStore.init(docstoreBase); @@ -299,14 +303,16 @@ int main(int argc, char* argv[]) { uint32_t partialIndexesCount = 0; size_t memoryBytes = 0; + size_t currentLineOffset = + infile.tellg(); // store tsv file offset to restore original doc content for snippets while (std::getline(infile, line)) { lineNumber++; if (memoryBytes > MEMORYLIMIT) { std::string postingsFile = partialIndexPostingsDir + "/postings_" + std::to_string(partialIndexesCount) + ".bin"; - std::string dictFile = partialIndexDictDir + "/dictionary_" + - std::to_string(partialIndexesCount) + ".bin"; + std::string dictFile = + partialIndexDictDir + "/dictionary_" + std::to_string(partialIndexesCount) + ".bin"; try { spillToDisk(termPostings, termDictionary, postingsFile, dictFile); } catch (const std::exception& e) { @@ -319,15 +325,16 @@ int main(int argc, char* argv[]) { memoryBytes = 0; auto elapsed = duration(high_resolution_clock::now() - start).count(); std::cout << "[Partial Index #" << partialIndexesCount << "] " - << "Lines processed: " << lineNumber - << " Time: " << elapsed << "s\n"; + << "Lines processed: " << lineNumber << " Time: " << elapsed << "s\n"; } size_t pos1 = line.find('\t'); size_t pos2 = line.find('\t', pos1 + 1); size_t pos3 = line.find('\t', pos2 + 1); - if (pos3 == std::string::npos) + if (pos3 == std::string::npos) { + currentLineOffset = infile.tellg(); continue; + } // parse docId int docId = -1; @@ -343,12 +350,15 @@ int main(int argc, char* argv[]) { } } } - if (docId < 0) continue; + if (docId < 0) { + currentLineOffset = infile.tellg(); + continue; + } std::string url = line.substr(pos1 + 1, pos2 - pos1 - 1); std::string title = line.substr(pos2 + 1, pos3 - pos2 - 1); - docStore.addDocument(docId, url, title); + docStore.addDocument(docId, url, title, currentLineOffset); // ensure docLengths vector is large enough if (static_cast(docId) >= docLengths.size()) { @@ -362,11 +372,11 @@ int main(int argc, char* argv[]) { size_t titleLen = pos3 - pos2 - 1; const char* contentStart = line.data() + pos3 + 1; size_t contentLen = line.size() - pos3 - 1; - + // process title tokenizer.tokenize(titleStart, titleLen, [&](std::string&& term, int position) { docTermCount++; - + uint32_t termId; auto it = termDictionary.find(term); if (it == termDictionary.end()) { @@ -388,11 +398,11 @@ int main(int argc, char* argv[]) { memoryBytes += sizeof(int); } }); - + // process content (positions continue from title) tokenizer.tokenize(contentStart, contentLen, [&](std::string&& term, int position) { docTermCount++; - + uint32_t termId; auto it = termDictionary.find(term); if (it == termDictionary.end()) { @@ -414,10 +424,12 @@ int main(int argc, char* argv[]) { memoryBytes += sizeof(int); } }); - + docLengths[docId] = docTermCount; - + if (maxDocs != -1 && lineNumber >= maxDocs) break; + + currentLineOffset = infile.tellg(); } // final flush if remaining data @@ -444,11 +456,11 @@ int main(int argc, char* argv[]) { std::cerr << "Failed to open metadata file for writing\n"; return 1; } - + // write header: numDocs, avgDocLength metaOut.write(reinterpret_cast(&numDocs), sizeof(numDocs)); metaOut.write(reinterpret_cast(&avgDocLength), sizeof(avgDocLength)); - + // write document lengths array (only non-zero entries with their docIds) for (size_t docId = 0; docId < docLengths.size(); docId++) { if (docLengths[docId] > 0) { @@ -460,7 +472,8 @@ int main(int argc, char* argv[]) { } metaOut.close(); - std::cout << "Metadata written: " << numDocs << " documents, avg length: " << avgDocLength << std::endl; + std::cout << "Metadata written: " << numDocs << " documents, avg length: " << avgDocLength + << std::endl; double totalTime = duration(high_resolution_clock::now() - start).count(); std::cout << "Indexing completed in " << totalTime << " seconds.\n"; diff --git a/src/backend/search_engine/index_builder/merge_partial_indices.cpp b/src/backend/search_engine/index_builder/merge_partial_indices.cpp index 919f993..f4873b3 100644 --- a/src/backend/search_engine/index_builder/merge_partial_indices.cpp +++ b/src/backend/search_engine/index_builder/merge_partial_indices.cpp @@ -1,16 +1,16 @@ -#include -#include -#include -#include -#include -#include #include + #include -#include #include +#include +#include #include +#include +#include #include - +#include +#include +#include struct DictEntry { std::string term; @@ -24,45 +24,46 @@ struct PostingEntry { std::vector positions; }; -std::vector readDictionary(const std::string& dictFile, const std::string& postingsFile) { +std::vector readDictionary(const std::string& dictFile, + const std::string& postingsFile) { std::vector entries; - + std::ifstream dictIn(dictFile, std::ios::binary); if (!dictIn) { std::cerr << "Failed to open dictionary file: " << dictFile << std::endl; return entries; } - + // get the size of the postings file to determine the last posting's size std::ifstream postIn(postingsFile, std::ios::binary | std::ios::ate); uint64_t postingsFileSize = postIn.tellg(); postIn.close(); - + // read all dictionary entries while (dictIn.peek() != EOF) { DictEntry entry; - + // read term length uint32_t termLen; dictIn.read(reinterpret_cast(&termLen), sizeof(termLen)); if (!dictIn) break; - + // read term entry.term.resize(termLen); dictIn.read(&entry.term[0], termLen); if (!dictIn) break; - + // read offset dictIn.read(reinterpret_cast(&entry.offset), sizeof(entry.offset)); if (!dictIn) break; - + // read docFreq dictIn.read(reinterpret_cast(&entry.docFreq), sizeof(entry.docFreq)); if (!dictIn) break; - + entries.push_back(entry); } - + // calculate posting sizes using next entry's offset for (size_t i = 0; i < entries.size(); i++) { if (i + 1 < entries.size()) { @@ -71,42 +72,43 @@ std::vector readDictionary(const std::string& dictFile, const std::st entries[i].postingSize = postingsFileSize - entries[i].offset; } } - + return entries; } -std::vector readAndParsePosting(const std::string& postingsFile, uint64_t offset, uint64_t size) { +std::vector readAndParsePosting(const std::string& postingsFile, uint64_t offset, + uint64_t size) { std::vector entries; - + std::ifstream postIn(postingsFile, std::ios::binary); if (!postIn) { std::cerr << "Failed to open postings file: " << postingsFile << std::endl; return entries; } - + postIn.seekg(offset); - + uint64_t bytesRead = 0; while (bytesRead < size) { PostingEntry entry; - + // read docId postIn.read(reinterpret_cast(&entry.docId), sizeof(entry.docId)); bytesRead += sizeof(entry.docId); - + // read posCount uint32_t posCount; postIn.read(reinterpret_cast(&posCount), sizeof(posCount)); bytesRead += sizeof(posCount); - + // read positions entry.positions.resize(posCount); postIn.read(reinterpret_cast(entry.positions.data()), posCount * sizeof(uint32_t)); bytesRead += posCount * sizeof(uint32_t); - + entries.push_back(entry); } - + return entries; } @@ -114,13 +116,13 @@ void writePosting(std::ofstream& out, const std::vector& entries) for (const auto& entry : entries) { // write docId out.write(reinterpret_cast(&entry.docId), sizeof(entry.docId)); - + // write posCount uint32_t posCount = entry.positions.size(); out.write(reinterpret_cast(&posCount), sizeof(posCount)); - + // write positions - out.write(reinterpret_cast(entry.positions.data()), + out.write(reinterpret_cast(entry.positions.data()), posCount * sizeof(uint32_t)); } } @@ -128,28 +130,28 @@ void writePosting(std::ofstream& out, const std::vector& entries) std::vector mergePostings(const std::vector>& allPostings) { // use a map to merge postings by docId (automatically sorted) std::map> mergedMap; - + for (const auto& postings : allPostings) { for (const auto& entry : postings) { auto& positions = mergedMap[entry.docId]; positions.insert(positions.end(), entry.positions.begin(), entry.positions.end()); } } - + // convert map to vector and sort positions within each document std::vector result; result.reserve(mergedMap.size()); - + for (auto& [docId, positions] : mergedMap) { // sort positions for this document std::sort(positions.begin(), positions.end()); - + PostingEntry entry; entry.docId = docId; entry.positions = std::move(positions); result.push_back(std::move(entry)); } - + return result; } @@ -177,35 +179,29 @@ int main(int argc, char* argv[]) { std::string projectDir = projectRoot.string(); std::string dataDir = "/data"; - - const char* test_env = std::getenv("ENV"); // for integration tests, test with controlled and small dataset in test_data + + const char* test_env = std::getenv( + "ENV"); // for integration tests, test with controlled and small dataset in test_data if (test_env && std::string(test_env) == "TEST_ENV") { std::cout << "TEST ENVIRONMENT, merging index with test data." << std::endl; dataDir = "/test_data"; - } + } std::string partialIndexPostingsDir = projectDir + dataDir + "/partial_indices/postings"; std::string partialIndexDictDir = projectDir + dataDir + "/partial_indices/dictionaries"; std::string metadataDir = projectDir + dataDir + "/index"; - std::string outputDir = (projectRoot.parent_path() / "index" / "bin").string(); // put in parallel directory index/ where python code expects it - std::filesystem::create_directories(outputDir); - - // copy from building dir to output dir - if (std::filesystem::exists(metadataDir + "/metadata.bin")) std::filesystem::remove(outputDir + "/metadata.bin"); - std::filesystem::copy(metadataDir + "/metadata.bin", outputDir + "/metadata.bin"); - - if (std::filesystem::exists(outputDir + "/docstore.bin")) std::filesystem::remove(outputDir + "/docstore.bin"); - if (std::filesystem::exists(outputDir + "/docstore_offsets.bin")) std::filesystem::remove(outputDir + "/docstore_offsets.bin"); - std::filesystem::copy(projectDir + dataDir + "/docstore/docstore.bin", outputDir + "/docstore.bin"); - std::filesystem::copy(projectDir + dataDir + "/docstore/docstore_offsets.bin", outputDir + "/docstore_offsets.bin"); + std::string outputDir = + (projectRoot.parent_path() / "index" / "bin") + .string(); // put in parallel directory index/ where python code expects it // get number of partial indices size_t partialIndexCount = 0; while (true) { - std::string dictFile = partialIndexDictDir + "/dictionary_" + std::to_string(partialIndexCount) + ".bin"; + std::string dictFile = + partialIndexDictDir + "/dictionary_" + std::to_string(partialIndexCount) + ".bin"; std::ifstream testIn(dictFile); if (!testIn.is_open()) break; testIn.close(); - partialIndexCount++; + partialIndexCount++; } partialIndexCount++; std::cout << "Found " << partialIndexCount << " partial indices to merge.\n"; @@ -213,21 +209,19 @@ int main(int argc, char* argv[]) { auto allDicts = std::vector>(partialIndexCount); for (size_t i = 0; i < partialIndexCount - 1; i++) { std::string dictFile = partialIndexDictDir + "/dictionary_" + std::to_string(i) + ".bin"; - std::string postingsFile = partialIndexPostingsDir + "/postings_" + std::to_string(i) + ".bin"; + std::string postingsFile = + partialIndexPostingsDir + "/postings_" + std::to_string(i) + ".bin"; allDicts[i] = readDictionary(dictFile, postingsFile); } - allDicts[partialIndexCount - 1] = readDictionary( - partialIndexDictDir + "/dictionary_final.bin", - partialIndexPostingsDir + "/postings_final.bin" - ); + allDicts[partialIndexCount - 1] = + readDictionary(partialIndexDictDir + "/dictionary_final.bin", + partialIndexPostingsDir + "/postings_final.bin"); struct HeapEntry { std::string term; size_t dictIndex; size_t entryIndex; - bool operator>(const HeapEntry& other) const { - return term > other.term; - } + bool operator>(const HeapEntry& other) const { return term > other.term; } }; std::priority_queue, std::greater> minHeap; // initialize heap with first entry from each dictionary @@ -248,77 +242,82 @@ int main(int argc, char* argv[]) { while (!minHeap.empty()) { auto current = minHeap.top(); minHeap.pop(); - + const std::string& term = current.term; - + // collect all postings for this term from all dictionaries std::vector> postingsToMerge; std::vector docFreqs; - + // add the current entry's posting { size_t dictIndex = current.dictIndex; size_t entryIndex = current.entryIndex; const DictEntry& entry = allDicts[dictIndex][entryIndex]; docFreqs.push_back(entry.docFreq); - - std::string postingsFile = (dictIndex == partialIndexCount - 1) - ? partialIndexPostingsDir + "/postings_final.bin" - : partialIndexPostingsDir + "/postings_" + std::to_string(dictIndex) + ".bin"; - - postingsToMerge.push_back(readAndParsePosting(postingsFile, entry.offset, entry.postingSize)); - + + std::string postingsFile = + (dictIndex == partialIndexCount - 1) + ? partialIndexPostingsDir + "/postings_final.bin" + : partialIndexPostingsDir + "/postings_" + std::to_string(dictIndex) + ".bin"; + + postingsToMerge.push_back( + readAndParsePosting(postingsFile, entry.offset, entry.postingSize)); + if (entryIndex + 1 < allDicts[dictIndex].size()) { const DictEntry& nextEntry = allDicts[dictIndex][entryIndex + 1]; minHeap.push({nextEntry.term, dictIndex, entryIndex + 1}); } } - + // check if the next entries in the heap have the same term while (!minHeap.empty() && minHeap.top().term == term) { auto same = minHeap.top(); minHeap.pop(); - + size_t dictIndex = same.dictIndex; size_t entryIndex = same.entryIndex; const DictEntry& entry = allDicts[dictIndex][entryIndex]; docFreqs.push_back(entry.docFreq); - - std::string postingsFile = (dictIndex == partialIndexCount - 1) - ? partialIndexPostingsDir + "/postings_final.bin" - : partialIndexPostingsDir + "/postings_" + std::to_string(dictIndex) + ".bin"; - - postingsToMerge.push_back(readAndParsePosting(postingsFile, entry.offset, entry.postingSize)); - + + std::string postingsFile = + (dictIndex == partialIndexCount - 1) + ? partialIndexPostingsDir + "/postings_final.bin" + : partialIndexPostingsDir + "/postings_" + std::to_string(dictIndex) + ".bin"; + + postingsToMerge.push_back( + readAndParsePosting(postingsFile, entry.offset, entry.postingSize)); + // push next entry from this dictionary into the heap if (entryIndex + 1 < allDicts[dictIndex].size()) { const DictEntry& nextEntry = allDicts[dictIndex][entryIndex + 1]; minHeap.push({nextEntry.term, dictIndex, entryIndex + 1}); - } + } } - + // merge postings properly (sorted by docId, with positions merged and sorted) std::vector mergedPostings = mergePostings(postingsToMerge); - + // calculate actual docFreq (number of unique documents) uint32_t totalDocFreq = mergedPostings.size(); - + // write posting data to final postings file writePosting(finalPostOut, mergedPostings); - + // calculate size of merged posting uint64_t postingSize = 0; for (const auto& entry : mergedPostings) { - postingSize += sizeof(uint32_t) + sizeof(uint32_t) + entry.positions.size() * sizeof(uint32_t); + postingSize += + sizeof(uint32_t) + sizeof(uint32_t) + entry.positions.size() * sizeof(uint32_t); } - + // write dictionary entry to final dictionary file uint32_t termLen = term.size(); finalDictOut.write(reinterpret_cast(&termLen), sizeof(termLen)); finalDictOut.write(term.data(), termLen); finalDictOut.write(reinterpret_cast(&finalOffset), sizeof(finalOffset)); finalDictOut.write(reinterpret_cast(&totalDocFreq), sizeof(totalDocFreq)); - + finalOffset += postingSize; termsProcessed++; @@ -326,14 +325,14 @@ int main(int argc, char* argv[]) { if (termsProcessed % 100000 == 0) { auto now = high_resolution_clock::now(); auto elapsed = std::chrono::duration_cast(now - start).count(); - std::cout << "Processed " << termsProcessed << " terms, elapsed time: " - << elapsed << "s" << std::endl; + std::cout << "Processed " << termsProcessed << " terms, elapsed time: " << elapsed + << "s" << std::endl; } } std::cout << "Merging completed successfully.\n"; - std::cout << "Time taken: " - << duration_cast(high_resolution_clock::now() - start).count() + std::cout << "Time taken: " + << duration_cast(high_resolution_clock::now() - start).count() << " seconds.\n"; finalPostOut.close(); finalDictOut.close(); diff --git a/src/backend/search_engine/models/index.py b/src/backend/search_engine/models/index.py index 064fdab..3ddeb9f 100644 --- a/src/backend/search_engine/models/index.py +++ b/src/backend/search_engine/models/index.py @@ -5,3 +5,4 @@ class SearchResult(BaseModel): document_id: int url: HttpUrl title: str + # snippet: str diff --git a/src/backend/search_engine/query/query_engine.py b/src/backend/search_engine/query/query_engine.py index b6dcf18..0328c21 100644 --- a/src/backend/search_engine/query/query_engine.py +++ b/src/backend/search_engine/query/query_engine.py @@ -103,7 +103,7 @@ def _bool_search(self, node: Node | None) -> PostingList: end = time.perf_counter() logger.debug( - f"Node={node.value}, Result docs={len(result.postings)}, " + f"Node={node.value!r}, Result docs={len(result.postings)}, " f"Execution time: {end - start:.6f} seconds" ) return result @@ -136,9 +136,7 @@ def search_results(self, limit: int = 10) -> list[SearchResult]: # positional phrase search logger.debug("Executing positional phrase query search...") normalized_tokens_no_quots = normalize_search_query(raw_query[1:-1]) - posting_lists = self._positional_phrase_search( - normalized_tokens_no_quots - ) + result = self._positional_phrase_search(normalized_tokens_no_quots) else: # any order -> create AND query logger.debug("Executing phrase query search...") @@ -146,22 +144,33 @@ def search_results(self, limit: int = 10) -> list[SearchResult]: logger.debug(f"Converted to AND query: {and_query}") qt.parse_query(and_query) logger.debug(f"Query tree: {qt.root}") - posting_lists = self._bool_search(qt.root) + result = self._bool_search(qt.root) else: logger.debug("Executing bool query search...") try: qt.parse_query(normalized_tokens) + # self.inverted_index.doc_store.query_terms = qt.unique_terms logger.debug(f"Query tree: {qt.root}") - posting_lists = self._bool_search(qt.root) + result = self._bool_search(qt.root) except InvalidOperatorError as e: logger.error(f"Invalid query syntax: {e}") raise - if posting_lists is None or len(posting_lists.postings) == 0: + if result is None or len(result.postings) == 0: return [] + logger.debug( + f"Found {len(result.postings)} results in {time.perf_counter() - start:.6f} seconds" + ) + + # resulting PostingList contains all matched documents + # tf and positions are empty (except positions for positional search) as they + # are term-specific and cannot be merged meaningfully here + + top_n_results = result # TODO will be done by BM25 ranking later + search_results = [] - for doc_id in posting_lists.postings[:limit]: + for doc_id in top_n_results.postings[:limit]: doc_data = self.inverted_index.doc_store.get(doc_id) if doc_data is None: continue @@ -170,12 +179,14 @@ def search_results(self, limit: int = 10) -> list[SearchResult]: if url is None: continue title = doc_data.title or "Untitled" + snippet = doc_data.snippet try: search_result = SearchResult( document_id=doc_id, url=url, # type: ignore[arg-type] title=title, + snippet=snippet, ) search_results.append(search_result) except Exception as e: diff --git a/src/backend/search_engine/query/query_preprocessing.py b/src/backend/search_engine/query/query_preprocessing.py index 03750f4..916e2fe 100644 --- a/src/backend/search_engine/query/query_preprocessing.py +++ b/src/backend/search_engine/query/query_preprocessing.py @@ -38,6 +38,7 @@ class QueryTree: def __init__(self) -> None: self._root: Node | None = None self._warnings_stack: list[ParenthesesWarning] = [] + self.unique_terms = set() # only positive (non-negated) terms for snippeting @property def root(self) -> Node | None: @@ -103,7 +104,7 @@ def _parse_query(self, tokens: list[str]) -> Node: return node # handle NOT and parentheses - def _parse_term(self, tokens: list[str]) -> Node: + def _parse_term(self, tokens: list[str], negated: bool = False) -> Node: if not tokens: raise ValueError("Unexpected end of tokens while parsing term") @@ -111,7 +112,7 @@ def _parse_term(self, tokens: list[str]) -> Node: if token in NOT: tokens.pop(0) - word_node = self._parse_term(tokens) + word_node = self._parse_term(tokens, negated=True) return Node(token, left=None, right=word_node) if token == "(": @@ -136,6 +137,8 @@ def _parse_term(self, tokens: list[str]) -> Node: # leaf node/actual word if token not in (AND | OR | NOT): + if not negated: + self.unique_terms.add(token) tokens.pop(0) return Node(token) diff --git a/src/frontend/Dockerfile b/src/frontend/Dockerfile new file mode 100644 index 0000000..862ee6d --- /dev/null +++ b/src/frontend/Dockerfile @@ -0,0 +1,11 @@ +FROM node:20-slim + +WORKDIR /app + +COPY package.json package-lock.json . + +RUN npm ci + +COPY . . + +CMD ["npm", "run", "dev"] \ No newline at end of file diff --git a/src/frontend/src/pages/Index.tsx b/src/frontend/src/pages/Index.tsx index 17d0762..a61c056 100644 --- a/src/frontend/src/pages/Index.tsx +++ b/src/frontend/src/pages/Index.tsx @@ -125,7 +125,7 @@ const Index = () => { try { const response = await fetch( - `http://127.0.0.1:8000/search?q=${encodeURIComponent(query)}&limit=${customLimit}` + `/search?q=${encodeURIComponent(query)}&limit=${customLimit}` ); if (!response.ok) { @@ -133,7 +133,7 @@ const Index = () => { try { const data = await response.json(); if (data.detail) errorMsg = data.detail; - } catch {} + } catch { } throw new Error(errorMsg); } @@ -272,11 +272,10 @@ const Index = () => { key={rpp} onClick={() => handleResultsPerPageChange(rpp)} disabled={isLoading} - className={`px-3 py-2 text-sm rounded-lg border transition-colors ${ - resultsPerPage === rpp + className={`px-3 py-2 text-sm rounded-lg border transition-colors ${resultsPerPage === rpp ? "bg-primary text-primary-foreground border-primary" : "border-input hover:bg-accent" - } disabled:opacity-50 disabled:cursor-not-allowed`} + } disabled:opacity-50 disabled:cursor-not-allowed`} > {rpp} @@ -361,13 +360,12 @@ const Index = () => { key={index} onClick={() => typeof page === "number" && handlePageChange(page)} disabled={page === "..."} - className={`min-w-[40px] h-10 rounded-lg font-medium transition-colors ${ - page === currentPage + className={`min-w-[40px] h-10 rounded-lg font-medium transition-colors ${page === currentPage ? "bg-primary text-primary-foreground" : page === "..." - ? "cursor-default" - : "hover:bg-accent" - }`} + ? "cursor-default" + : "hover:bg-accent" + }`} > {page} diff --git a/src/frontend/vite.config.ts b/src/frontend/vite.config.ts index da25c6d..af653ec 100644 --- a/src/frontend/vite.config.ts +++ b/src/frontend/vite.config.ts @@ -4,15 +4,25 @@ import path from "path"; import { componentTagger } from "lovable-tagger"; // https://vitejs.dev/config/ -export default defineConfig(({ mode }) => ({ - server: { - host: "::", - port: 8080, - }, - plugins: [react(), mode === "development" && componentTagger()].filter(Boolean), - resolve: { - alias: { - "@": path.resolve(__dirname, "./src"), +export default defineConfig(({ mode }) => { + const isDocker = process.env.ENV === "DOCKER"; + + return { + server: { + host: "::", + port: 8080, + proxy: { + "/search": { + target: isDocker ? "http://backend:8000" : "http://127.0.0.1:8000", + changeOrigin: true, + }, + }, + }, + plugins: [react(), mode === "development" && componentTagger()].filter(Boolean), + resolve: { + alias: { + "@": path.resolve(__dirname, "./src"), + }, }, - }, -})); + }; +}); diff --git a/tests/Dockerfile b/tests/Dockerfile index 4bcdab0..c571591 100644 --- a/tests/Dockerfile +++ b/tests/Dockerfile @@ -3,14 +3,12 @@ FROM python:3.13-slim # system dependencies # build-essential for packages using c extensions # git for packages installed from git -# zlib and libstemmer for CMake build of index_builder +# libstemmer for CMake build of index_builder RUN apt-get update && apt-get install -y \ build-essential \ git \ curl \ cmake \ - just \ - zlib1g-dev \ libstemmer-dev \ && rm -rf /var/lib/apt/lists/* @@ -21,18 +19,14 @@ ENV PATH="/root/.local/bin:$PATH" WORKDIR /app/src/backend # copy first to cache dependencies -COPY src/backend/pyproject.toml . -COPY src/backend/uv.lock . -COPY src/backend/bindings/ ./bindings/ +COPY pyproject.toml . +COPY uv.lock . +COPY bindings/ ./bindings/ RUN uv sync -COPY src/backend/ . +COPY . . -COPY tests/ /app/tests/ +RUN chmod +x /app/src/backend/tests/entrypoint.sh -COPY justfile /app/justfile - -RUN chmod +x /app/tests/entrypoint.sh - -ENTRYPOINT ["/app/tests/entrypoint.sh"] +ENTRYPOINT ["/app/src/backend/tests/entrypoint.sh"] diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index 01b8ba5..c306e1f 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -5,4 +5,4 @@ services: dockerfile: tests/Dockerfile image: seekr-pytest-ci-env environment: - ENV: "TEST_ENV" \ No newline at end of file + ENV: "TEST_ENV" \ No newline at end of file diff --git a/tests/entrypoint.sh b/tests/entrypoint.sh index 71af52a..9ca2a33 100644 --- a/tests/entrypoint.sh +++ b/tests/entrypoint.sh @@ -1,6 +1,8 @@ #!/bin/bash set -e -just -f /app/justfile build-index +cd search_engine/scripts/ && \ +chmod +x build-index.sh && \ +./build-index.sh 1024 -1 -uv run pytest /app/tests \ No newline at end of file +uv run pytest /app/src/backend/tests \ No newline at end of file From d154d221d36421e7ff13bee7de15252f877881fc Mon Sep 17 00:00:00 2001 From: JanSkn Date: Tue, 23 Dec 2025 20:04:16 +0100 Subject: [PATCH 3/5] first stable snippeting version --- .dockerignore | 12 +- README.md | 35 +++- docker-compose.yml | 5 +- justfile | 15 +- local.sh | 2 +- src/backend/.dockerignore | 65 +++++++ src/backend/bindings/cpp_utils/_core.pyi | 1 + src/backend/bindings/utils.cpp | 172 +++++++++++++++--- .../index_builder/index_builder.cpp | 31 +--- .../search_engine/query/query_engine.py | 14 +- {tests => src/backend/tests}/Dockerfile | 3 +- src/backend/tests/conftest.py | 4 + .../backend/tests}/docker-compose.yml | 0 {tests => src/backend/tests}/entrypoint.sh | 4 +- {tests => src/backend/tests}/justfile | 0 .../test_index_builder/test_index_builder.py | 20 +- .../tests}/test_query/test_query_engine.py | 11 +- .../test_query/test_query_preprocessing.py | 0 src/frontend/src/components/SearchResults.tsx | 6 +- src/frontend/src/pages/Index.tsx | 60 ++++-- tests/conftest.py | 4 - 21 files changed, 344 insertions(+), 120 deletions(-) create mode 100644 src/backend/.dockerignore rename {tests => src/backend/tests}/Dockerfile (91%) create mode 100644 src/backend/tests/conftest.py rename {tests => src/backend/tests}/docker-compose.yml (100%) rename {tests => src/backend/tests}/entrypoint.sh (58%) rename {tests => src/backend/tests}/justfile (100%) rename {tests => src/backend/tests}/test_index_builder/test_index_builder.py (80%) rename {tests => src/backend/tests}/test_query/test_query_engine.py (98%) rename {tests => src/backend/tests}/test_query/test_query_preprocessing.py (100%) delete mode 100644 tests/conftest.py diff --git a/.dockerignore b/.dockerignore index 6d4e5a3..f7e4d21 100644 --- a/.dockerignore +++ b/.dockerignore @@ -55,11 +55,9 @@ dist/ **/*.tsv **/*.gz -# allow test data files as they are small and necessary for tests -!search_engine/index_builder/test_data/*.tsv -!search_engine/index_builder/test_data/*.gz -search_engine/index_builder/build/ -search_engine/index_builder/data/ -search_engine/index/bin/ -search_engine/models/neuspell-scrnn-probwordnoise/ +src/backend/search_engine/tests/ +src/backend/search_engine/index_builder/build/ +src/backend/search_engine/index_builder/data/ +src/backend/search_engine/index/bin/ +src/backend/search_engine/models/neuspell-scrnn-probwordnoise/ \ No newline at end of file diff --git a/README.md b/README.md index c5cd424..14d795e 100644 --- a/README.md +++ b/README.md @@ -20,15 +20,44 @@ Seekr consists of several core subsystems working together: * uv * Node.js * npm -* Docker (for containerized integration/unit tests) +* Docker (for containerized integration/unit tests & deployment) * LFS (downloading ML models from GitHub) * CMake (building and compiling the CPP components) * Just (command runner) ## Entrypoints +### Docker -### Build the Index -Before running the system, build the index with a memory limit: +```bash +just deploy +``` +Automated build process. Will download the dataset and build the index if +it does not exist yet. This preprocessing can take up to 2 hours. + +Afterwards, it spins up a frontend and a backend container. + +- Access search engine frontend via `http://localhost:8080`. +- API-only: `http://localhost:8000`. + + **Search Endpoint** + + **GET** `/search` + + Query parameters: + + | Parameter | Type | Description | + | --------- | ------ | ---------------------------------------------- | + | `q` | string | Search query (1–50 characters) | + | `limit` | int | Maximum number of results (1–500, default: 10) | + + +### Manual usage +Download the dataset: +```bash +cd src && uv run --project backend python -m backend.search_engine.scripts.download_dataset +``` + +Build the index with a memory limit: ```bash just build-index ``` diff --git a/docker-compose.yml b/docker-compose.yml index 37ee3c2..a87d7bb 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -6,8 +6,9 @@ services: container_name: seekr-backend ports: - "8000:8000" # expose backend for FE-less API access - volumes: # bind mount index data - - ./src/backend/index_builder:/app/src/backend/index_builder + volumes: # bind mount large data files + - ./src/backend/search_engine/index_builder:/app/src/backend/search_engine/index_builder + - ./src/backend/search_engine/models/neuspell-scrnn-probwordnoise:/app/src/backend/search_engine/models/neuspell-scrnn-probwordnoise frontend: build: src/frontend container_name: seekr-frontend diff --git a/justfile b/justfile index 9839ec8..c423016 100644 --- a/justfile +++ b/justfile @@ -19,6 +19,15 @@ local *uvicorn-args: ./local.sh {{uvicorn-args}} deploy: + if [ ! -f src/backend/search_engine/index_builder/data/msmarco-docs.tsv ]; then + @echo "msmarco-docs.tsv file not found. Starting download..." + cd src && uv run --project backend python -m backend.search_engine.scripts.download_dataset + fi + if [ ! -e src/backend/search_engine/index/bin/* ]; then + @echo "Index binaries not found. Starting build process..." + just build-index + fi + @echo "Spinning up containers..." docker compose up -d build-index memory-limit="1024" max-docs="-1": @@ -41,8 +50,8 @@ generate-stubs: lint: @echo "Linting Python code..." - cd src/backend && uv run ruff check api/ search_engine/ ../../tests/ - cd src/backend && uv run ruff format --check --diff api/ search_engine/ ../../tests/ + cd src/backend && uv run ruff check api/ search_engine/ tests/ + cd src/backend && uv run ruff format --check --diff api/ search_engine/ tests/ @echo "Linting C++ code..." # only format-check instead of linting to avoid dependency-related failures clang-format --dry-run --Werror \ src/backend/bindings/utils.cpp \ @@ -51,7 +60,7 @@ lint: format: @echo "Formatting Python code..." - cd src/backend && uv run ruff format api/ search_engine/ ../../tests/ + cd src/backend && uv run ruff format api/ search_engine/ tests/ @echo "Formatting C++ code..." clang-format -i \ src/backend/bindings/utils.cpp \ diff --git a/local.sh b/local.sh index a410d7a..58e8bf4 100755 --- a/local.sh +++ b/local.sh @@ -6,7 +6,7 @@ export LOG_LEVEL=DEBUG PIDS=() cd src -uv run --project backend uvicorn backend.api.v1.app:app --host 127.0.0.1 --port 8000 "$@" & +uv run --project backend --refresh uvicorn backend.api.v1.app:app --host 127.0.0.1 --port 8000 "$@" & PIDS+=($!) echo "Uvicorn server started with PID ${PIDS[0]}" diff --git a/src/backend/.dockerignore b/src/backend/.dockerignore new file mode 100644 index 0000000..1499f19 --- /dev/null +++ b/src/backend/.dockerignore @@ -0,0 +1,65 @@ +# Python bytecode +__pycache__/ +*.py[cod] +*$py.class + +# Virtual environments +.venv/ +venv/ +env/ +ENV/ +env.bak/ +venv.bak/ + +# Poetry / uv +.python-version +.uv/ +.python-version +.poetry/ +pdm.lock +.pdm-build/ + +# Test / coverage +.coverage +.coverage.* +.pytest_cache/ +htmlcov/ +nosetests.xml +coverage.xml +*.cover +*.py,cover + +# Distribution / packaging +build/ +dist/ +*.egg-info/ +*.egg +*.whl +*.tar.gz + +# IDE / editor +.vscode/ +.idea/ +*.suo +*.ntvs* +*.njsproj +*.sln +*.sw? + +# OS files +.DS_Store +Thumbs.db + +node_modules/ +dist/ + +**/*.tsv +**/*.gz +# allow test data files as they are small and necessary for tests +!search_engine/index_builder/test_data/*.tsv +!search_engine/index_builder/test_data/*.gz + +search_engine/index_builder/build/ +search_engine/index_builder/data/ +search_engine/index/bin/ +search_engine/models/neuspell-scrnn-probwordnoise/ \ No newline at end of file diff --git a/src/backend/bindings/cpp_utils/_core.pyi b/src/backend/bindings/cpp_utils/_core.pyi index 15822d2..56a5dad 100644 --- a/src/backend/bindings/cpp_utils/_core.pyi +++ b/src/backend/bindings/cpp_utils/_core.pyi @@ -34,6 +34,7 @@ class IndexAccessor: class InvertedIndex: def __init__(self, arg0: str) -> None: ... + def clear_cache(self) -> None: ... @property def doc_store(self) -> DocStore: ... diff --git a/src/backend/bindings/utils.cpp b/src/backend/bindings/utils.cpp index 7105546..37469dc 100644 --- a/src/backend/bindings/utils.cpp +++ b/src/backend/bindings/utils.cpp @@ -1,20 +1,70 @@ #include #include // for automatic conversion of STL containers -#include #include #include #include +#include #include #include #include #include #include +#include #include "libstemmer.h" namespace py = pybind11; +bool is_valid_utf8(const std::string& s) { + const unsigned char* b = reinterpret_cast(s.data()); + size_t n = s.size(); + + for (size_t i = 0; i < n; ++i) { + if (b[i] <= 0x7F) continue; + + size_t len = 0; + if ((b[i] & 0xE0) == 0xC0) + len = 1; + else if ((b[i] & 0xF0) == 0xE0) + len = 2; + else if ((b[i] & 0xF8) == 0xF0) + len = 3; + else + return false; + + if (i + len >= n) return false; + + for (size_t j = 1; j <= len; ++j) { + if ((b[i + j] & 0xC0) != 0x80) return false; + } + i += len; + } + return true; +} + +std::string latin1_to_utf8(const std::string& s) { + std::string out; + out.reserve(s.size() * 2); + + for (unsigned char c : s) { + if (c < 0x80) { + out.push_back(static_cast(c)); + } else { + out.push_back(static_cast(0xC0 | (c >> 6))); + out.push_back(static_cast(0x80 | (c & 0x3F))); + } + } + return out; +} + +std::string ensure_utf8(const std::string& s) { + if (is_valid_utf8(s)) { + return s; + } + return latin1_to_utf8(s); +} + struct SnowballStemmer { struct sb_stemmer* stemmer; SnowballStemmer() { stemmer = sb_stemmer_new("english", nullptr); } @@ -83,8 +133,10 @@ struct Metadata { std::unordered_map doc_lengths; void load(const std::string& path) { + std::cout << "Metadata: " << path << std::endl; std::ifstream in(path, std::ios::binary); if (!in.is_open()) throw std::runtime_error("Cannot open metadata file"); + if (in.peek() == EOF) throw std::runtime_error("Metadata file is empty"); in.read(reinterpret_cast(&num_docs), sizeof(num_docs)); in.read(reinterpret_cast(&avg_doc_length), sizeof(avg_doc_length)); @@ -151,6 +203,34 @@ PostingList read_posting_list(std::ifstream& in, uint64_t offset, uint32_t doc_f return pl; } +// Optimized helper to only read positions for one document for snippetting +std::vector scan_posting_list_for_doc(std::ifstream& in, uint64_t offset, + uint32_t doc_freq, uint32_t target_doc_id) { + in.seekg(offset); + + for (uint32_t i = 0; i < doc_freq; i++) { + uint32_t doc_id, pos_count; + in.read(reinterpret_cast(&doc_id), sizeof(doc_id)); + in.read(reinterpret_cast(&pos_count), sizeof(pos_count)); + + if (doc_id == target_doc_id) { + std::vector positions(pos_count); + in.read(reinterpret_cast(positions.data()), pos_count * sizeof(uint32_t)); + return positions; + } + + // if we passed the doc_id (list is sorted), it's not there + if (doc_id > target_doc_id) { + return {}; + } + + // skip positions for this doc + in.seekg(pos_count * sizeof(uint32_t), std::ios::cur); + } + + return {}; +} + struct DocInfo { std::string url; std::string title; @@ -201,11 +281,6 @@ class DocStore { uint64_t tsv_offset); std::string get_snippet(uint32_t doc_id, uint64_t tsv_offset); std::optional get(uint32_t doc_id); - std::optional get_tsv_offset(uint32_t doc_id) { - auto it = offsets.find(doc_id); - if (it == offsets.end()) return std::nullopt; - return it->second.tsv_offset; - } uint32_t size() const { return total_docs; } }; @@ -230,6 +305,10 @@ class InvertedIndex { DocStore doc_store; IndexAccessor index; + // Query cache for performance (especially snippeting of rare + common term combos) + std::unordered_map> cache; + void clear_cache() { cache.clear(); } + InvertedIndex(const std::string& base_path) : doc_store(this), index(this) { std::ifstream index_file(base_path + "/index.bin", std::ios::binary); while (true) { @@ -263,8 +342,14 @@ class InvertedIndex { // --- Docstore --- void DocStore::open(const std::string& dir_name) { data_in.open(dir_name + "/docstore.bin", std::ios::binary); - // TODO falscher pfad? - tsv_in.open(dir_name + "/../../index_builder/data/msmarco-docs.tsv", std::ios::binary); + std::string data_dir = "data"; + const char* test_env = std::getenv( + "ENV"); // for integration tests, test with controlled and small dataset in test_data + if (test_env && std::string(test_env) == "TEST_ENV") { + data_dir = "test_data"; + } + tsv_in.open(dir_name + "/../../index_builder/" + data_dir + "/msmarco-docs.tsv", + std::ios::binary); std::ifstream off(dir_name + "/docstore_offsets.bin", std::ios::binary); if (!data_in || !tsv_in || !off) throw std::runtime_error("Could not open docstore"); @@ -279,7 +364,7 @@ void DocStore::open(const std::string& dir_name) { if (!off.read(reinterpret_cast(&id), sizeof(id))) break; if (!off.read(reinterpret_cast(&off64), sizeof(off64))) break; - // if (!off.read(reinterpret_cast(&tsvOff), sizeof(tsvOff))) break; + if (!off.read(reinterpret_cast(&tsvOff), sizeof(tsvOff))) break; offsets[id] = {off64, tsvOff}; } @@ -407,9 +492,10 @@ std::string DocStore::load_snippet(uint32_t doc_id, } } - return snippet; + return ensure_utf8(snippet); } +// !!!! TODO MARK word with DocStore::SubsnippetResult DocStore::find_subsnippet(const std::vector& hits, int max_window_size, size_t required_term_count) { @@ -478,15 +564,26 @@ std::string DocStore::get_snippet(uint32_t doc_id, uint64_t tsv_offset) { // term: "bar"} ] std::vector hits; for (const auto& term : unique_terms) { + auto cache_it = parent->cache.find(term); + if (cache_it != parent->cache.end()) { + const auto& pl = *cache_it->second; + auto posIt = pl.positions.find(doc_id); + if (posIt != pl.positions.end()) { + for (uint32_t pos : posIt->second) hits.push_back(Hit{pos, term}); + } + continue; + } + auto termIt = parent->term_to_offset.find(term); if (termIt == parent->term_to_offset.end()) continue; auto docIt = parent->term_to_docfreq.find(term); - PostingList pl = read_posting_list(parent->postings_file, termIt->second, docIt->second); - auto posIt = pl.positions.find(doc_id); - if (posIt == pl.positions.end()) continue; + std::vector positions = + scan_posting_list_for_doc(parent->postings_file, termIt->second, docIt->second, doc_id); + + if (positions.empty()) continue; - for (uint32_t pos : posIt->second) hits.push_back(Hit{pos, term}); + for (uint32_t pos : positions) hits.push_back(Hit{pos, term}); } std::sort(hits.begin(), hits.end(), [](const Hit& a, const Hit& b) { return a.pos < b.pos; }); @@ -507,7 +604,8 @@ std::string DocStore::get_snippet(uint32_t doc_id, uint64_t tsv_offset) { return load_snippet(doc_id, snippet_windows, tsv_offset); } -std::optional DocStore::get(uint32_t doc_id) { +std::optional DocStore::get( + uint32_t doc_id) { // only load snippet when required as resource-intensive auto it = offsets.find(doc_id); if (it == offsets.end()) return std::nullopt; @@ -527,17 +625,26 @@ std::optional DocStore::get(uint32_t doc_id) { std::string title(title_len, '\0'); data_in.read(title.data(), title_len); - // std::string snippet = get_snippet(doc_id, tsv_offset); - - return DocInfo{url, title, "snippet"}; + std::string snippet = get_snippet(doc_id, tsv_offset); + return DocInfo{url, title, snippet}; } // -------------------- std::optional IndexAccessor::get(const std::string& term) { + // Check cache + auto cache_it = parent->cache.find(term); + if (cache_it != parent->cache.end()) { + return *cache_it->second; + } + auto it = parent->term_to_offset.find(term); if (it == parent->term_to_offset.end()) return std::nullopt; uint32_t doc_freq = parent->term_to_docfreq.at(term); PostingList pl = read_posting_list(parent->postings_file, it->second, doc_freq); + + // Add to cache + parent->cache[term] = std::make_shared(pl); + return pl; } @@ -595,6 +702,7 @@ PostingList positional_intersect(const PostingList& pl1, const PostingList& pl2, if (!valid_positions.empty()) { result.postings.push_back(doc_id); + result.term_frequencies[doc_id] = valid_positions.size(); result.positions[doc_id] = std::move(valid_positions); } } @@ -643,16 +751,19 @@ PostingList find_docs(const PostingList& pl1, const PostingList& pl2, const std: const size_t n1 = p1.size(); const size_t n2 = p2.size(); - if (mode == "AND") { - std::vector intersected; - intersected.reserve(std::min(n1, n2)); // most likely + std::vector result_postings; + result_postings.reserve(std::min(n1, n2)); // most likely + std::unordered_map result_tf; + + if (mode == "AND") { while (i < n1 && j < n2) { uint32_t d1 = p1[i]; uint32_t d2 = p2[j]; if (d1 == d2) { - intersected.push_back(d1); + result_postings.push_back(d1); + result_tf[d1] = tf1.at(d1) + tf2.at(d2); i++; j++; @@ -673,7 +784,7 @@ PostingList find_docs(const PostingList& pl1, const PostingList& pl2, const std: } } - PostingList out(intersected, {}, {}); + PostingList out(result_postings, result_tf, {}); out.build_skip_pointers(); return out; } @@ -690,13 +801,16 @@ PostingList find_docs(const PostingList& pl1, const PostingList& pl2, const std: if (x == y) { merged.push_back(x); + result_tf[x] = tf1.at(x) + tf2.at(y); a++; b++; } else if (x < y) { merged.push_back(x); + result_tf[x] = tf1.at(x); a++; } else { merged.push_back(y); + result_tf[y] = tf2.at(y); b++; } } @@ -705,14 +819,16 @@ PostingList find_docs(const PostingList& pl1, const PostingList& pl2, const std: while (a < n1) { uint32_t x = p1[a++]; merged.push_back(x); + result_tf[x] = tf1.at(x); } while (b < n2) { uint32_t y = p2[b++]; merged.push_back(y); + result_tf[y] = tf2.at(y); } - PostingList out(merged, {}, {}); + PostingList out(merged, result_tf, {}); out.build_skip_pointers(); return out; } @@ -731,11 +847,12 @@ PostingList find_docs(const PostingList& pl1, const PostingList& pl2, const std: if (b == n2 || p2[b] != x) { diff.push_back(x); + result_tf[x] = tf1.at(x); } a++; } - PostingList out(diff, {}, {}); + PostingList out(diff, result_tf, {}); out.build_skip_pointers(); return out; } @@ -784,7 +901,7 @@ PYBIND11_MODULE(_core, m) { py::class_(m, "DocStore") .def("get", &DocStore::get, py::arg("doc_id")) - .def("get_tsv_offset", &DocStore::get_tsv_offset, py::arg("doc_id")); + .def_readwrite("query_terms", &DocStore::query_terms); py::class_(m, "IndexAccessor").def("get", &IndexAccessor::get, py::arg("term")); @@ -792,5 +909,6 @@ PYBIND11_MODULE(_core, m) { .def(py::init()) .def_readonly("index", &InvertedIndex::index) .def_readonly("metadata", &InvertedIndex::metadata) - .def_readonly("doc_store", &InvertedIndex::doc_store); + .def_readonly("doc_store", &InvertedIndex::doc_store) + .def("clear_cache", &InvertedIndex::clear_cache); } diff --git a/src/backend/search_engine/index_builder/index_builder.cpp b/src/backend/search_engine/index_builder/index_builder.cpp index 5de6633..43221ec 100644 --- a/src/backend/search_engine/index_builder/index_builder.cpp +++ b/src/backend/search_engine/index_builder/index_builder.cpp @@ -267,12 +267,10 @@ int main(int argc, char* argv[]) { std::string outputDir = (projectRoot.parent_path() / "index" / "bin") .string(); // put in parallel directory index/ where python code expects it - std::string metadataDir = outputDir + "/metadata.bin"; std::string docstoreBase = outputDir; std::filesystem::create_directories(partialIndexPostingsDir); std::filesystem::create_directories(partialIndexDictDir); - std::filesystem::create_directories(metadataDir); std::filesystem::create_directories(outputDir); Tokenizer tokenizer; @@ -373,31 +371,8 @@ int main(int argc, char* argv[]) { const char* contentStart = line.data() + pos3 + 1; size_t contentLen = line.size() - pos3 - 1; - // process title - tokenizer.tokenize(titleStart, titleLen, [&](std::string&& term, int position) { - docTermCount++; - - uint32_t termId; - auto it = termDictionary.find(term); - if (it == termDictionary.end()) { - termId = termDictionary.size(); - memoryBytes += sizeof(uint32_t) + term.size(); - termDictionary.emplace(std::move(term), termId); - } else { - termId = it->second; - } - - auto& postings = termPostings[termId]; - if (postings.empty() || postings.back().docId != docId) { - postings.push_back({docId, {}}); - postings.back().positions.reserve(8); - postings.back().positions.push_back(position); - memoryBytes += sizeof(Posting) + sizeof(int); - } else { - postings.back().positions.push_back(position); - memoryBytes += sizeof(int); - } - }); + tokenizer.tokenize(titleStart, titleLen, + [&](std::string&& term, int position) { docTermCount++; }); // process content (positions continue from title) tokenizer.tokenize(contentStart, contentLen, [&](std::string&& term, int position) { @@ -450,7 +425,7 @@ int main(int argc, char* argv[]) { } double avgDocLength = numDocs > 0 ? static_cast(totalTerms) / numDocs : 0.0; - std::string metadataFile = metadataDir + "/metadata.bin"; + std::string metadataFile = outputDir + "/metadata.bin"; std::ofstream metaOut(metadataFile, std::ios::binary); if (!metaOut) { std::cerr << "Failed to open metadata file for writing\n"; diff --git a/src/backend/search_engine/query/query_engine.py b/src/backend/search_engine/query/query_engine.py index d73755e..8cbdc61 100755 --- a/src/backend/search_engine/query/query_engine.py +++ b/src/backend/search_engine/query/query_engine.py @@ -137,6 +137,9 @@ def search_results(self, limit: int = 10) -> SearchResults: correction = repl(self.corrector, raw_query) if not qt._has_operators(normalized_tokens): + self.inverted_index.doc_store.query_terms = list( + set(normalized_tokens) + ) # needed for snippets if (raw_query.startswith('"') and raw_query.endswith('"')) or ( raw_query.startswith("'") and raw_query.endswith("'") ): @@ -156,7 +159,9 @@ def search_results(self, limit: int = 10) -> SearchResults: logger.debug("Executing bool query search...") try: qt.parse_query(normalized_tokens) - # self.inverted_index.doc_store.query_terms = qt.unique_terms + self.inverted_index.doc_store.query_terms = ( + qt.unique_terms + ) # needed for snippets logger.debug(f"Query tree: {qt.root}") result = self._bool_search(qt.root) except InvalidOperatorError as e: @@ -170,10 +175,6 @@ def search_results(self, limit: int = 10) -> SearchResults: f"Found {len(result.postings)} results in {time.perf_counter() - start:.6f} seconds" ) - # resulting PostingList contains all matched documents - # tf and positions are empty (except positions for positional search) as they - # are term-specific and cannot be merged meaningfully here - top_n_results = result # TODO will be done by BM25 ranking later search_results = [] @@ -205,4 +206,7 @@ def search_results(self, limit: int = 10) -> SearchResults: f"Returned {len(search_results)} results. " f"Total execution time: {end - start:.6f} seconds" ) + # clear cache to free memory + self.inverted_index.clear_cache() + return SearchResults(search_results=search_results, correction=correction) diff --git a/tests/Dockerfile b/src/backend/tests/Dockerfile similarity index 91% rename from tests/Dockerfile rename to src/backend/tests/Dockerfile index c571591..43badbd 100644 --- a/tests/Dockerfile +++ b/src/backend/tests/Dockerfile @@ -9,6 +9,7 @@ RUN apt-get update && apt-get install -y \ git \ curl \ cmake \ + just \ libstemmer-dev \ && rm -rf /var/lib/apt/lists/* @@ -29,4 +30,4 @@ COPY . . RUN chmod +x /app/src/backend/tests/entrypoint.sh -ENTRYPOINT ["/app/src/backend/tests/entrypoint.sh"] +ENTRYPOINT ["/app/src/backend/tests/entrypoint.sh"] \ No newline at end of file diff --git a/src/backend/tests/conftest.py b/src/backend/tests/conftest.py new file mode 100644 index 0000000..5c1cbce --- /dev/null +++ b/src/backend/tests/conftest.py @@ -0,0 +1,4 @@ +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) diff --git a/tests/docker-compose.yml b/src/backend/tests/docker-compose.yml similarity index 100% rename from tests/docker-compose.yml rename to src/backend/tests/docker-compose.yml diff --git a/tests/entrypoint.sh b/src/backend/tests/entrypoint.sh similarity index 58% rename from tests/entrypoint.sh rename to src/backend/tests/entrypoint.sh index 9ca2a33..b71f084 100644 --- a/tests/entrypoint.sh +++ b/src/backend/tests/entrypoint.sh @@ -2,7 +2,7 @@ set -e cd search_engine/scripts/ && \ -chmod +x build-index.sh && \ -./build-index.sh 1024 -1 + chmod +x build-index.sh && \ + ./build-index.sh 1024 -1 uv run pytest /app/src/backend/tests \ No newline at end of file diff --git a/tests/justfile b/src/backend/tests/justfile similarity index 100% rename from tests/justfile rename to src/backend/tests/justfile diff --git a/tests/test_index_builder/test_index_builder.py b/src/backend/tests/test_index_builder/test_index_builder.py similarity index 80% rename from tests/test_index_builder/test_index_builder.py rename to src/backend/tests/test_index_builder/test_index_builder.py index 69b852a..0715043 100644 --- a/tests/test_index_builder/test_index_builder.py +++ b/src/backend/tests/test_index_builder/test_index_builder.py @@ -124,16 +124,16 @@ def test_real_index_docstore(): doc_store = inverted_index.doc_store expected_docs = { - 0: DocInfo("http://example.com/0", "Title One"), - 1: DocInfo("http://example.com/1", "Title Two"), - 2: DocInfo("http://example.com/2", "Title Three"), - 3: DocInfo("http://example.com/3", "Title Four"), - 4: DocInfo("http://example.com/4", "Title Five"), - 5: DocInfo("http://example.com/5", "Title Six"), - 6: DocInfo("http://example.com/6", "Title Seven"), - 7: DocInfo("http://example.com/7", "Title Eight"), - 8: DocInfo("http://example.com/8", "Title Nine"), - 9: DocInfo("http://example.com/9", "Title Ten"), + 0: DocInfo("http://example.com/0", "Title One", snippet=""), + 1: DocInfo("http://example.com/1", "Title Two", snippet=""), + 2: DocInfo("http://example.com/2", "Title Three", snippet=""), + 3: DocInfo("http://example.com/3", "Title Four", snippet=""), + 4: DocInfo("http://example.com/4", "Title Five", snippet=""), + 5: DocInfo("http://example.com/5", "Title Six", snippet=""), + 6: DocInfo("http://example.com/6", "Title Seven", snippet=""), + 7: DocInfo("http://example.com/7", "Title Eight", snippet=""), + 8: DocInfo("http://example.com/8", "Title Nine", snippet=""), + 9: DocInfo("http://example.com/9", "Title Ten", snippet=""), } for doc_id, doc_info in expected_docs.items(): diff --git a/tests/test_query/test_query_engine.py b/src/backend/tests/test_query/test_query_engine.py similarity index 98% rename from tests/test_query/test_query_engine.py rename to src/backend/tests/test_query/test_query_engine.py index 510f3be..f9bfe9c 100644 --- a/tests/test_query/test_query_engine.py +++ b/src/backend/tests/test_query/test_query_engine.py @@ -638,9 +638,9 @@ def test_search_results_basic( positions={1: [0], 2: [1], 3: [2]}, ) mock_inverted_index.doc_store.get.side_effect = [ - DocInfo(url="http://example.com/1", title="Doc 1"), - DocInfo(url="http://example.com/2", title="Doc 2"), - DocInfo(url="http://example.com/3", title="Doc 3"), + DocInfo(url="http://example.com/1", title="Doc 1", snippet=""), + DocInfo(url="http://example.com/2", title="Doc 2", snippet=""), + DocInfo(url="http://example.com/3", title="Doc 3", snippet=""), ] mock_spell_corrector.correct.return_value = "test" @@ -664,7 +664,7 @@ def test_search_results_with_limit( positions={i: [0] for i in range(1, 6)}, ) mock_inverted_index.doc_store.get.side_effect = [ - DocInfo(url=f"http://example.com/{i}", title=f"Doc {i}") + DocInfo(url=f"http://example.com/{i}", title=f"Doc {i}", snippet="") for i in range(1, 6) ] @@ -750,7 +750,8 @@ def test_full_workflow( ] mock_inverted_index.doc_store.get.side_effect = [ - DocInfo(url=f"http://example.com/{i}", title=f"Doc {i}") for i in [2, 3, 5] + DocInfo(url=f"http://example.com/{i}", title=f"Doc {i}", snippet="") + for i in [2, 3, 5] ] term1 = Node(value="term1") diff --git a/tests/test_query/test_query_preprocessing.py b/src/backend/tests/test_query/test_query_preprocessing.py similarity index 100% rename from tests/test_query/test_query_preprocessing.py rename to src/backend/tests/test_query/test_query_preprocessing.py diff --git a/src/frontend/src/components/SearchResults.tsx b/src/frontend/src/components/SearchResults.tsx index 1319a34..ac1d4b1 100644 --- a/src/frontend/src/components/SearchResults.tsx +++ b/src/frontend/src/components/SearchResults.tsx @@ -4,7 +4,7 @@ import { Card } from "@/components/ui/card"; interface SearchResult { title: string; url: string; - description?: string; + snippet?: string; } interface SearchResultsProps { @@ -37,9 +37,9 @@ export const SearchResults = ({ results }: SearchResultsProps) => {

{result.url}

- {result.description && ( + {(result.snippet) && (

- {result.description} +

)} diff --git a/src/frontend/src/pages/Index.tsx b/src/frontend/src/pages/Index.tsx index 52174b6..7b975da 100644 --- a/src/frontend/src/pages/Index.tsx +++ b/src/frontend/src/pages/Index.tsx @@ -10,7 +10,7 @@ import seekrLogo from "@/assets/seekr-logo.png"; interface SearchResult { title: string; url: string; - description?: string; + snippet?: string; } const Index = () => { @@ -128,7 +128,7 @@ const Index = () => { try { const response = await fetch( - `http://127.0.0.1:8000/search?q=${encodeURIComponent(query)}&limit=${customLimit}` + `/search?q=${encodeURIComponent(query)}&limit=${customLimit}` ); if (!response.ok) { @@ -283,8 +283,8 @@ const Index = () => { onClick={() => handleResultsPerPageChange(rpp)} disabled={isLoading} className={`px-3 py-2 text-sm rounded-lg border transition-colors ${resultsPerPage === rpp - ? "bg-primary text-primary-foreground border-primary" - : "border-input hover:bg-accent" + ? "bg-primary text-primary-foreground border-primary" + : "border-input hover:bg-accent" } disabled:opacity-50 disabled:cursor-not-allowed`} > {rpp} @@ -370,21 +370,43 @@ const Index = () => { {/* Pagination */} {totalPages > 1 && ( -
- {Array.from({ length: totalPages }, (_, index) => index + 1).map( - (pageNum) => ( - - ) - )} +
+ {/* Previous Button */} + + + {/* Page Numbers */} + {getPageNumbers().map((pageNum, index) => ( + + ))} + + {/* Next Button */} +
)} diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 8d9f898..0000000 --- a/tests/conftest.py +++ /dev/null @@ -1,4 +0,0 @@ -import sys -from pathlib import Path - -sys.path.insert(0, str(Path(__file__).parent.parent / "src")) From 6167b2ccdc615aabe75f6fc09148ecd7215410e9 Mon Sep 17 00:00:00 2001 From: JanSkn Date: Mon, 29 Dec 2025 15:07:57 +0100 Subject: [PATCH 4/5] improved snippets and deployment --- deploy.sh | 15 ++ docker-compose.yml | 5 +- justfile | 12 +- src/backend/Dockerfile | 3 + src/backend/bindings/cpp_utils/_core.pyi | 2 + src/backend/bindings/utils.cpp | 169 ++++++++++++----- .../index_builder/index_builder.cpp | 5 +- .../index_builder/test_data/msmarco-docs.tsv | 3 +- .../search_engine/query/query_engine.py | 2 +- src/backend/tests/test_index.py | 175 ++++++++++++++++++ .../test_index_builder/test_index_builder.py | 141 -------------- 11 files changed, 327 insertions(+), 205 deletions(-) create mode 100755 deploy.sh create mode 100644 src/backend/tests/test_index.py delete mode 100644 src/backend/tests/test_index_builder/test_index_builder.py diff --git a/deploy.sh b/deploy.sh new file mode 100755 index 0000000..273799d --- /dev/null +++ b/deploy.sh @@ -0,0 +1,15 @@ +#!/bin/bash +set -euo pipefail + +if [ ! -f src/backend/search_engine/index_builder/data/msmarco-docs.tsv ]; then + echo "msmarco-docs.tsv file not found. Starting download..." + cd src && uv run --project backend python -m backend.search_engine.scripts.download_dataset +fi + +if [ -z "$(ls -A src/backend/search_engine/index/bin/ 2>/dev/null)" ]; then + echo "Index binaries not found. Starting build process..." + just build-index +fi + +echo "Spinning up containers..." +docker compose up -d \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index a87d7bb..657fefc 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,5 +1,5 @@ version: "3.9" -name: "Seekr. Search Engine" +name: "seekr-search-engine" services: backend: build: src/backend @@ -7,7 +7,8 @@ services: ports: - "8000:8000" # expose backend for FE-less API access volumes: # bind mount large data files - - ./src/backend/search_engine/index_builder:/app/src/backend/search_engine/index_builder + - ./src/backend/search_engine/index/bin:/app/src/backend/search_engine/index/bin + - ./src/backend/search_engine/index_builder/data/msmarco-docs.tsv:/app/src/backend/search_engine/index_builder/data/msmarco-docs.tsv - ./src/backend/search_engine/models/neuspell-scrnn-probwordnoise:/app/src/backend/search_engine/models/neuspell-scrnn-probwordnoise frontend: build: src/frontend diff --git a/justfile b/justfile index c423016..7856c9f 100644 --- a/justfile +++ b/justfile @@ -19,16 +19,8 @@ local *uvicorn-args: ./local.sh {{uvicorn-args}} deploy: - if [ ! -f src/backend/search_engine/index_builder/data/msmarco-docs.tsv ]; then - @echo "msmarco-docs.tsv file not found. Starting download..." - cd src && uv run --project backend python -m backend.search_engine.scripts.download_dataset - fi - if [ ! -e src/backend/search_engine/index/bin/* ]; then - @echo "Index binaries not found. Starting build process..." - just build-index - fi - @echo "Spinning up containers..." - docker compose up -d + chmod +x deploy.sh && \ + ./deploy.sh build-index memory-limit="1024" max-docs="-1": cd src/backend/search_engine/scripts/ && \ diff --git a/src/backend/Dockerfile b/src/backend/Dockerfile index dd7d502..b93b94f 100644 --- a/src/backend/Dockerfile +++ b/src/backend/Dockerfile @@ -12,6 +12,7 @@ RUN apt-get update && apt-get install -y \ libstemmer-dev \ && rm -rf /var/lib/apt/lists/* +RUN curl -LsSf https://astral.sh/uv/install.sh | sh ENV PATH="/root/.local/bin:$PATH" # set workdir to where pyproject.toml is located for uv @@ -26,4 +27,6 @@ RUN uv sync COPY . . +ENV PYTHONPATH="/app/src:$PYTHONPATH" + CMD ["uv", "run", "uvicorn", "api.v1.app:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file diff --git a/src/backend/bindings/cpp_utils/_core.pyi b/src/backend/bindings/cpp_utils/_core.pyi index 56a5dad..f197165 100644 --- a/src/backend/bindings/cpp_utils/_core.pyi +++ b/src/backend/bindings/cpp_utils/_core.pyi @@ -28,6 +28,8 @@ class Metadata: class DocStore: def get(self, doc_id: int) -> DocInfo | None: ... + def get_tsv_offset(self, doc_id: int) -> int | None: + ... class IndexAccessor: def get(self, term: str) -> PostingList | None: ... diff --git a/src/backend/bindings/utils.cpp b/src/backend/bindings/utils.cpp index 37469dc..c080115 100644 --- a/src/backend/bindings/utils.cpp +++ b/src/backend/bindings/utils.cpp @@ -1,3 +1,4 @@ +#include #include #include // for automatic conversion of STL containers @@ -5,12 +6,12 @@ #include #include #include +#include #include #include #include #include #include -#include #include "libstemmer.h" @@ -268,7 +269,8 @@ class DocStore { }; SubsnippetResult find_subsnippet(const std::vector& hits, int max_window_size, - size_t required_term_count); + size_t required_term_count, + std::vector>& term_positions); public: std::vector query_terms; @@ -277,10 +279,12 @@ class DocStore { void open(const std::string& dir_name); std::string load_snippet(uint32_t doc_id, - std::vector>& snippet_windows, + std::vector>& snippet_window_borders, + std::vector>& term_positions, uint64_t tsv_offset); std::string get_snippet(uint32_t doc_id, uint64_t tsv_offset); std::optional get(uint32_t doc_id); + std::optional get_tsv_offset(uint32_t doc_id); uint32_t size() const { return total_docs; } }; @@ -370,12 +374,12 @@ void DocStore::open(const std::string& dir_name) { } } -std::string DocStore::load_snippet(uint32_t doc_id, - std::vector>& snippet_windows, - uint64_t tsv_offset) { - if (snippet_windows.empty()) return ""; +std::string DocStore::load_snippet( + uint32_t doc_id, std::vector>& snippet_window_borders, + std::vector>& term_positions, uint64_t tsv_offset) { + if (snippet_window_borders.empty()) return ""; - std::sort(snippet_windows.begin(), snippet_windows.end()); + std::sort(snippet_window_borders.begin(), snippet_window_borders.end()); tsv_in.clear(); tsv_in.seekg(tsv_offset); std::string line; @@ -402,7 +406,7 @@ std::string DocStore::load_snippet(uint32_t doc_id, size_t i = content_start; size_t window_idx = 0; - if (snippet_windows[0].first > 0) { + if (snippet_window_borders[0].first > 0) { snippet += "... "; } @@ -410,17 +414,22 @@ std::string DocStore::load_snippet(uint32_t doc_id, auto is_sentence_end = [](char c) { return c == '.' || c == '!' || c == '?'; }; // Calculate threshold for last window (last 10%) - uint32_t last_window_start = snippet_windows.back().first; - uint32_t last_window_end = snippet_windows.back().second; + uint32_t last_window_start = snippet_window_borders.back().first; + uint32_t last_window_end = snippet_window_borders.back().second; uint32_t last_window_size = last_window_end - last_window_start + 1; uint32_t last_window_threshold = last_window_end - (last_window_size / 10); bool stopped_at_sentence_end = false; - while (i < len && window_idx < snippet_windows.size()) { + while (i < len && window_idx < snippet_window_borders.size()) { + std::set highlight_positions; + if (window_idx < term_positions.size()) { + highlight_positions = std::set(term_positions[window_idx].begin(), + term_positions[window_idx].end()); + } // --- determine word --- size_t word_start = i; - while (word_start < len && !std::isalpha(static_cast(line[word_start]))) { + while (word_start < len && !std::isalnum(static_cast(line[word_start]))) { word_start++; } std::string separator = line.substr(i, word_start - i); @@ -429,7 +438,7 @@ std::string DocStore::load_snippet(uint32_t doc_id, break; } size_t word_end = word_start; - while (word_end < len && std::isalpha(static_cast(line[word_end]))) { + while (word_end < len && std::isalnum(static_cast(line[word_end]))) { word_end++; } std::string word = line.substr(word_start, word_end - word_start); @@ -437,19 +446,24 @@ std::string DocStore::load_snippet(uint32_t doc_id, // check if current word is in relevant window // skip windows that are already passed - while (window_idx < snippet_windows.size() && - current_word_pos > snippet_windows[window_idx].second) { + while (window_idx < snippet_window_borders.size() && + current_word_pos > snippet_window_borders[window_idx].second) { window_idx++; - if (window_idx < snippet_windows.size()) { + if (window_idx < snippet_window_borders.size()) { snippet += " ... "; } } - if (window_idx < snippet_windows.size()) { - uint32_t w_start = snippet_windows[window_idx].first; - uint32_t w_end = snippet_windows[window_idx].second; + if (window_idx < snippet_window_borders.size()) { + uint32_t w_start = snippet_window_borders[window_idx].first; + uint32_t w_end = snippet_window_borders[window_idx].second; if (current_word_pos >= w_start && current_word_pos <= w_end) { + // highlight a found term + if (highlight_positions.count(current_word_pos)) { + word = "" + word + ""; + } + if (current_word_pos == w_start) { snippet += word; } else { @@ -457,7 +471,7 @@ std::string DocStore::load_snippet(uint32_t doc_id, } // check if we're in the last window and in its last 10% - bool is_last_window = (window_idx == snippet_windows.size() - 1); + bool is_last_window = (window_idx == snippet_window_borders.size() - 1); if (is_last_window && current_word_pos >= last_window_threshold && current_word_pos < w_end) { // look for sentence-ending punctuation after this word @@ -484,9 +498,9 @@ std::string DocStore::load_snippet(uint32_t doc_id, } // check if there is more text after the snippets (only if we didn't stop at sentence end) - if (!stopped_at_sentence_end && window_idx >= snippet_windows.size()) { + if (!stopped_at_sentence_end && window_idx >= snippet_window_borders.size()) { size_t check = i; - while (check < len && !std::isalpha(static_cast(line[check]))) check++; + while (check < len && !std::isalnum(static_cast(line[check]))) check++; if (check < len) { snippet += " ..."; } @@ -495,17 +509,17 @@ std::string DocStore::load_snippet(uint32_t doc_id, return ensure_utf8(snippet); } -// !!!! TODO MARK word with -DocStore::SubsnippetResult DocStore::find_subsnippet(const std::vector& hits, - int max_window_size, - size_t required_term_count) { +DocStore::SubsnippetResult DocStore::find_subsnippet( + const std::vector& hits, int max_window_size, size_t required_term_count, + std::vector>& term_positions) { SubsnippetResult result{}; result.start = 0; result.end = 0; if (hits.empty()) return result; - std::unordered_map window_term_count; + std::unordered_map + window_term_count; // count term occurance in the window uint32_t left = 0; uint32_t best_start = hits[0].pos; @@ -519,7 +533,7 @@ DocStore::SubsnippetResult DocStore::find_subsnippet(const std::vector& hit for (uint32_t right = 0; right < hits.size(); ++right) { window_term_count[hits[right].term]++; - // shrink window if too large + // shrink window if too large, adjust term counts while (hits[right].pos - hits[left].pos > max_window_size) { auto& c = window_term_count[hits[left].term]; if (--c == 0) window_term_count.erase(hits[left].term); @@ -545,18 +559,34 @@ DocStore::SubsnippetResult DocStore::find_subsnippet(const std::vector& hit result.start = best_start; result.end = best_end; - // collect remaining hits outside the best window + // collect left hits (only relevant if called by first window for second window) result.remaining_hits.reserve(hits.size()); for (uint32_t i = 0; i < hits.size(); ++i) { - if (i < best_left_idx || i > best_right_idx) result.remaining_hits.push_back(hits[i]); + if (i > best_right_idx) result.remaining_hits.push_back(hits[i]); + } + + // for highlighting positions bold + std::vector window_positions; + + for (const auto& hit : hits) { + if (hit.pos >= best_start && hit.pos <= best_end) { + window_positions.push_back(hit.pos); + } } + term_positions.push_back(window_positions); + return result; } // total snippet length: max. MAX_WINDOW_SIZE x 2 + 1 or 2x "..." std::string DocStore::get_snippet(uint32_t doc_id, uint64_t tsv_offset) { - uint32_t MAX_WINDOW_SIZE = 15; // num of words PER subsnippet + int MAX_WINDOW_SIZE = 15; // max. num of words PER subsnippet + + if (query_terms.empty()) { + throw std::runtime_error( + "Set query_terms (not empty): InvertedIndex().doc_store.query_terms = ..."); + } std::set unique_terms(query_terms.begin(), query_terms.end()); // e.g. @@ -566,42 +596,84 @@ std::string DocStore::get_snippet(uint32_t doc_id, uint64_t tsv_offset) { for (const auto& term : unique_terms) { auto cache_it = parent->cache.find(term); if (cache_it != parent->cache.end()) { - const auto& pl = *cache_it->second; - auto posIt = pl.positions.find(doc_id); - if (posIt != pl.positions.end()) { - for (uint32_t pos : posIt->second) hits.push_back(Hit{pos, term}); - } - continue; + const auto& pl = *cache_it->second; + auto posIt = pl.positions.find(doc_id); + if (posIt != pl.positions.end()) { + for (uint32_t pos : posIt->second) hits.push_back(Hit{pos, term}); + } + continue; } + // --- term not found in cache --- auto termIt = parent->term_to_offset.find(term); if (termIt == parent->term_to_offset.end()) continue; auto docIt = parent->term_to_docfreq.find(term); std::vector positions = scan_posting_list_for_doc(parent->postings_file, termIt->second, docIt->second, doc_id); - if (positions.empty()) continue; for (uint32_t pos : positions) hits.push_back(Hit{pos, term}); + // --------------------------------- } std::sort(hits.begin(), hits.end(), [](const Hit& a, const Hit& b) { return a.pos < b.pos; }); + std::vector> term_positions; // positions of the search terms in window i + // example: [[1, 3, 5], [2, 5]], in window 0 (index 0): term A and B at pos. 1, 3, 5, etc. // create first optimal snippet - SubsnippetResult first_snippet = find_subsnippet(hits, MAX_WINDOW_SIZE, unique_terms.size()); + SubsnippetResult first_snippet = + find_subsnippet(hits, MAX_WINDOW_SIZE, unique_terms.size(), term_positions); // can be one if all terms fit into MAX_WINDOW_SIZE, or at most 2 for remaining terms // not more than 2 for readability - std::vector> snippet_windows; - snippet_windows.push_back({first_snippet.start, first_snippet.end}); + std::vector> snippet_window_borders; + snippet_window_borders.push_back({first_snippet.start, first_snippet.end}); if (!first_snippet.remaining_hits.empty()) { - SubsnippetResult second_snippet = - find_subsnippet(first_snippet.remaining_hits, MAX_WINDOW_SIZE, unique_terms.size()); - snippet_windows.push_back({second_snippet.start, second_snippet.end}); + SubsnippetResult second_snippet = find_subsnippet( + first_snippet.remaining_hits, MAX_WINDOW_SIZE, unique_terms.size(), term_positions); + + if (second_snippet.end > 0) { + snippet_window_borders.push_back({second_snippet.start, second_snippet.end}); + } } - return load_snippet(doc_id, snippet_windows, tsv_offset); + // enhance context if windows are too small + int total_budget = MAX_WINDOW_SIZE * 2; + + if (snippet_window_borders.size() == 1) { + // only one window --> can consume 2x the size + auto& [start, end] = snippet_window_borders[0]; + int window_len = end - start; + int remaining = total_budget - window_len; + int left_context = remaining / 2; + int right_context = remaining - left_context; + + start = (start >= left_context) ? start - left_context : 0; + end = end + right_context; + + } else { + // 2 windows --> half each + int budget_per_window = total_budget / snippet_window_borders.size(); + + for (auto& [start, end] : snippet_window_borders) { + int window_len = end - start; + int remaining = budget_per_window - window_len; + int left_context = remaining / 2; + int right_context = remaining - left_context; + + start = (start >= left_context) ? start - left_context : 0; + end = end + right_context; + } + } + + return load_snippet(doc_id, snippet_window_borders, term_positions, tsv_offset); +} + +std::optional DocStore::get_tsv_offset(uint32_t doc_id) { + auto it = offsets.find(doc_id); + if (it == offsets.end()) return std::nullopt; + return it->second.tsv_offset; } std::optional DocStore::get( @@ -641,10 +713,10 @@ std::optional IndexAccessor::get(const std::string& term) { if (it == parent->term_to_offset.end()) return std::nullopt; uint32_t doc_freq = parent->term_to_docfreq.at(term); PostingList pl = read_posting_list(parent->postings_file, it->second, doc_freq); - + // Add to cache parent->cache[term] = std::make_shared(pl); - + return pl; } @@ -901,6 +973,7 @@ PYBIND11_MODULE(_core, m) { py::class_(m, "DocStore") .def("get", &DocStore::get, py::arg("doc_id")) + .def("get_tsv_offset", &DocStore::get_tsv_offset, py::arg("doc_id")) .def_readwrite("query_terms", &DocStore::query_terms); py::class_(m, "IndexAccessor").def("get", &IndexAccessor::get, py::arg("term")); diff --git a/src/backend/search_engine/index_builder/index_builder.cpp b/src/backend/search_engine/index_builder/index_builder.cpp index 43221ec..0c57d5e 100644 --- a/src/backend/search_engine/index_builder/index_builder.cpp +++ b/src/backend/search_engine/index_builder/index_builder.cpp @@ -127,13 +127,13 @@ class Tokenizer { size_t i = 0; while (i < len) { - while (i < len && !std::isalpha(static_cast(text[i]))) { + while (i < len && !std::isalnum(static_cast(text[i]))) { i++; } if (i >= len) break; tokenBuffer.clear(); - while (i < len && std::isalpha(static_cast(text[i]))) { + while (i < len && std::isalnum(static_cast(text[i]))) { tokenBuffer.push_back(std::tolower(static_cast(text[i]))); i++; } @@ -371,6 +371,7 @@ int main(int argc, char* argv[]) { const char* contentStart = line.data() + pos3 + 1; size_t contentLen = line.size() - pos3 - 1; + // do not store title positions as it would mix with body positions tokenizer.tokenize(titleStart, titleLen, [&](std::string&& term, int position) { docTermCount++; }); diff --git a/src/backend/search_engine/index_builder/test_data/msmarco-docs.tsv b/src/backend/search_engine/index_builder/test_data/msmarco-docs.tsv index 69243b4..0224892 100644 --- a/src/backend/search_engine/index_builder/test_data/msmarco-docs.tsv +++ b/src/backend/search_engine/index_builder/test_data/msmarco-docs.tsv @@ -7,4 +7,5 @@ D5 http://example.com/5 Title Six beta D6 http://example.com/6 Title Seven gamma delta D7 http://example.com/7 Title Eight alpha delta delta D8 http://example.com/8 Title Nine beta gamma -D9 http://example.com/9 Title Ten alpha beta gamma delta \ No newline at end of file +D9 http://example.com/9 Title Ten alpha beta gamma delta +D10 http://example.com/10 Title Eleven Hello - what is up? These are the first 15 words of the body. Now, we have a big gap until the next window. This gap should be very long to see the effect of 2 potential different windows. Here begins the second snippet. \ No newline at end of file diff --git a/src/backend/search_engine/query/query_engine.py b/src/backend/search_engine/query/query_engine.py index 8cbdc61..10e7212 100755 --- a/src/backend/search_engine/query/query_engine.py +++ b/src/backend/search_engine/query/query_engine.py @@ -208,5 +208,5 @@ def search_results(self, limit: int = 10) -> SearchResults: ) # clear cache to free memory self.inverted_index.clear_cache() - + return SearchResults(search_results=search_results, correction=correction) diff --git a/src/backend/tests/test_index.py b/src/backend/tests/test_index.py new file mode 100644 index 0000000..bff358c --- /dev/null +++ b/src/backend/tests/test_index.py @@ -0,0 +1,175 @@ +import pytest + +from cpp_utils import DocInfo, normalize_search_query +from backend.search_engine.index.index_loader import get_index + +EXPECTED = { + "alpha": { + "postings": [0, 1, 4, 7, 9], + "tf": { + 0: 2, + 1: 3, + 4: 1, + 7: 1, + 9: 1, + }, + "pos": { + 0: [0, 2], + 1: [0, 1, 2], + 4: [1], + 7: [0], + 9: [0], + }, + }, + "beta": { + "postings": [0, 2, 4, 5, 8, 9], + "tf": { + 0: 2, + 2: 2, + 4: 1, + 5: 1, + 8: 1, + 9: 1, + }, + "pos": { + 0: [1, 3], + 2: [0, 1], + 4: [2], + 5: [0], + 8: [0], + 9: [1], + }, + }, + "gamma": { + "postings": [2, 3, 6, 8, 9], + "tf": { + 2: 2, + 3: 3, + 6: 1, + 8: 1, + 9: 1, + }, + "pos": { + 2: [2, 3], + 3: [0, 1, 2], + 6: [0], + 8: [1], + 9: [2], + }, + }, + "delta": { + "postings": [3, 4, 6, 7, 9], + "tf": { + 3: 1, + 4: 1, + 6: 1, + 7: 2, + 9: 1, + }, + "pos": { + 3: [3], + 4: [0], + 6: [1], + 7: [1, 2], + 9: [3], + }, + }, +} + + +def test_index_metadata(): + inverted_index = get_index() + metadata = inverted_index.metadata + + expected_num_docs = 11 + # counting title + body, excluding stop words + expected_doc_lengths = { + 0: 6, + 1: 5, + 2: 6, + 3: 6, + 4: 5, + 5: 3, + 6: 4, + 7: 5, + 8: 4, + 9: 6, + 10: 31, + } + expected_avg_length = sum(expected_doc_lengths.values()) / expected_num_docs + assert metadata.num_docs == expected_num_docs + assert metadata.avg_doc_length == expected_avg_length + assert metadata.doc_lengths == expected_doc_lengths + + for doc_id, length in expected_doc_lengths.items(): + assert metadata.get_doc_length(doc_id) == length + + +@pytest.mark.parametrize("term", ["alpha", "beta", "gamma", "delta"]) +def test_real_index_postings(term): + inverted_index = get_index() + result = inverted_index.index.get(term) + + assert result.postings == EXPECTED[term]["postings"] + assert result.term_frequencies == EXPECTED[term]["tf"] + assert result.positions == EXPECTED[term]["pos"] + + assert result.postings == sorted(result.postings) + for doc_id, pos in result.positions.items(): + assert len(pos) == result.term_frequencies[doc_id] + + +class TestDocStore: + def test_docstore_metadata(self): + inverted_index = get_index() + doc_store = inverted_index.doc_store + + expected_docs = { + 0: DocInfo("http://example.com/0", "Title One", snippet=""), + 1: DocInfo("http://example.com/1", "Title Two", snippet=""), + 2: DocInfo("http://example.com/2", "Title Three", snippet=""), + 3: DocInfo("http://example.com/3", "Title Four", snippet=""), + 4: DocInfo("http://example.com/4", "Title Five", snippet=""), + 5: DocInfo("http://example.com/5", "Title Six", snippet=""), + 6: DocInfo("http://example.com/6", "Title Seven", snippet=""), + 7: DocInfo("http://example.com/7", "Title Eight", snippet=""), + 8: DocInfo("http://example.com/8", "Title Nine", snippet=""), + 9: DocInfo("http://example.com/9", "Title Ten", snippet=""), + 10: DocInfo("http://example.com/10", "Title Eleven", snippet=""), + } + + doc_store.query_terms = ["-1"] # not needed here, but must be set + for doc_id, doc_info in expected_docs.items(): + assert doc_store.get(doc_id).url == doc_info.url + assert doc_store.get(doc_id).title == doc_info.title + + @pytest.mark.parametrize( + "query, expected_snippet", + [ + ( + "hello", + "Hello - what is up? These are the first 15 words of the body. Now, we have ...", + ), + ( + "snippet", + "... very long to see the effect of 2 potential different windows. Here begins the second snippet", + ), + ( + "hello body", + "Hello - what is up? These are the first 15 words of the body. Now, we have a big gap until the next ...", + ), + ( + "hello snippet", + "Hello - what is up? These are the first 15 ... potential different windows. Here begins the second snippet", + ), + ], + ) + def test_docstore_snippets(self, query: str, expected_snippet: str): + inverted_index = get_index() + doc_store = inverted_index.doc_store + + DOC_ID = 10 + doc_store.query_terms = normalize_search_query(query) + snippet = doc_store.get(DOC_ID).snippet + + assert snippet == expected_snippet diff --git a/src/backend/tests/test_index_builder/test_index_builder.py b/src/backend/tests/test_index_builder/test_index_builder.py deleted file mode 100644 index 0715043..0000000 --- a/src/backend/tests/test_index_builder/test_index_builder.py +++ /dev/null @@ -1,141 +0,0 @@ -import pytest - -from cpp_utils import DocInfo -from backend.search_engine.index.index_loader import get_index - -EXPECTED = { - "alpha": { - "postings": [0, 1, 4, 7, 9], - "tf": { - 0: 2, - 1: 3, - 4: 1, - 7: 1, - 9: 1, - }, - "pos": { - 0: [0, 2], - 1: [0, 1, 2], - 4: [1], - 7: [0], - 9: [0], - }, - }, - "beta": { - "postings": [0, 2, 4, 5, 8, 9], - "tf": { - 0: 2, - 2: 2, - 4: 1, - 5: 1, - 8: 1, - 9: 1, - }, - "pos": { - 0: [1, 3], - 2: [0, 1], - 4: [2], - 5: [0], - 8: [0], - 9: [1], - }, - }, - "gamma": { - "postings": [2, 3, 6, 8, 9], - "tf": { - 2: 2, - 3: 3, - 6: 1, - 8: 1, - 9: 1, - }, - "pos": { - 2: [2, 3], - 3: [0, 1, 2], - 6: [0], - 8: [1], - 9: [2], - }, - }, - "delta": { - "postings": [3, 4, 6, 7, 9], - "tf": { - 3: 1, - 4: 1, - 6: 1, - 7: 2, - 9: 1, - }, - "pos": { - 3: [3], - 4: [0], - 6: [1], - 7: [1, 2], - 9: [3], - }, - }, -} - - -def test_real_index_metadata(): - inverted_index = get_index() - metadata = inverted_index.metadata - - expected_num_docs = 10 - # counting title + body - expected_doc_lengths = { - 0: 6, - 1: 5, - 2: 6, - 3: 6, - 4: 5, - 5: 3, - 6: 4, - 7: 5, - 8: 4, - 9: 6, - } - expected_avg_length = sum(expected_doc_lengths.values()) / expected_num_docs - - assert metadata.num_docs == expected_num_docs - assert metadata.avg_doc_length == expected_avg_length - assert metadata.doc_lengths == expected_doc_lengths - - for doc_id, length in expected_doc_lengths.items(): - assert metadata.get_doc_length(doc_id) == length - - -@pytest.mark.parametrize("term", ["alpha", "beta", "gamma", "delta"]) -def test_real_index_postings(term): - inverted_index = get_index() - result = inverted_index.index.get(term) - - assert result.postings == EXPECTED[term]["postings"] - assert result.term_frequencies == EXPECTED[term]["tf"] - assert result.positions == EXPECTED[term]["pos"] - - assert result.postings == sorted(result.postings) - for doc_id, pos in result.positions.items(): - assert len(pos) == result.term_frequencies[doc_id] - - -def test_real_index_docstore(): - inverted_index = get_index() - doc_store = inverted_index.doc_store - - expected_docs = { - 0: DocInfo("http://example.com/0", "Title One", snippet=""), - 1: DocInfo("http://example.com/1", "Title Two", snippet=""), - 2: DocInfo("http://example.com/2", "Title Three", snippet=""), - 3: DocInfo("http://example.com/3", "Title Four", snippet=""), - 4: DocInfo("http://example.com/4", "Title Five", snippet=""), - 5: DocInfo("http://example.com/5", "Title Six", snippet=""), - 6: DocInfo("http://example.com/6", "Title Seven", snippet=""), - 7: DocInfo("http://example.com/7", "Title Eight", snippet=""), - 8: DocInfo("http://example.com/8", "Title Nine", snippet=""), - 9: DocInfo("http://example.com/9", "Title Ten", snippet=""), - } - - for doc_id, doc_info in expected_docs.items(): - assert doc_store.get(doc_id).url == doc_info.url - assert doc_store.get(doc_id).title == doc_info.title From 6f8507b04ad53c902b9a771bc6167891a2b41d66 Mon Sep 17 00:00:00 2001 From: JanSkn Date: Mon, 29 Dec 2025 15:12:19 +0100 Subject: [PATCH 5/5] fix: mypy type hint --- src/backend/search_engine/query/query_preprocessing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/backend/search_engine/query/query_preprocessing.py b/src/backend/search_engine/query/query_preprocessing.py index 916e2fe..4c3b863 100644 --- a/src/backend/search_engine/query/query_preprocessing.py +++ b/src/backend/search_engine/query/query_preprocessing.py @@ -38,7 +38,9 @@ class QueryTree: def __init__(self) -> None: self._root: Node | None = None self._warnings_stack: list[ParenthesesWarning] = [] - self.unique_terms = set() # only positive (non-negated) terms for snippeting + self.unique_terms: set[str] = ( + set() + ) # only positive (non-negated) terms for snippeting @property def root(self) -> Node | None: