diff --git a/.gitmodules b/.gitmodules
index 5ead8fc9..79ef9dd4 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,3 @@
 [submodule "crates/burn/upstream"]
 	path = crates/burn/upstream
-	url = https://github.com/tracel-ai/burn.git
+	url = https://github.com/AdaWorldAPI/burn.git
diff --git a/Cargo.lock b/Cargo.lock
index 843f927f..518cd8d3 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -8,15 +8,6 @@ version = "0.3.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "415ed64958754dbe991900f3940677e6a7eefb4d7367afd70d642677b0c7d19d"
 
-[[package]]
-name = "addr2line"
-version = "0.25.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1b5d307320b3181d6d7954e663bd7c774a838b8220fe0593c86d9fb09f498b4b"
-dependencies = [
- "gimli 0.32.3",
-]
-
 [[package]]
 name = "adler2"
 version = "2.0.1"
@@ -38,15 +29,6 @@ version = "0.2.21"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
 
-[[package]]
-name = "android_system_properties"
-version = "0.1.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
-dependencies = [
- "libc",
-]
-
 [[package]]
 name = "anes"
 version = "0.1.6"
@@ -92,54 +74,12 @@ version = "0.7.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
 
-[[package]]
-name = "ash"
-version = "0.38.0+1.3.281"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0bb44936d800fea8f016d7f2311c6a4f97aebd5dc86f09906139ec848cf3a46f"
-dependencies = [
- "libloading 0.8.9",
-]
-
-[[package]]
-name = "async-channel"
-version = "2.5.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "924ed96dd52d1b75e9c1a3e6275715fd320f5f9439fb5a4a11fa51f4221158d2"
-dependencies = [
- "concurrent-queue",
- "event-listener-strategy",
- "futures-core",
- "pin-project-lite",
-]
-
-[[package]]
-name = "atomic_float"
-version = "1.1.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "628d228f918ac3b82fe590352cc719d30664a0c13ca3a60266fe02c7132d480a"
-
 [[package]]
 name = "autocfg"
 version = "1.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
 
-[[package]]
-name = "backtrace"
-version = "0.3.76"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bb531853791a215d7c62a30daf0dde835f381ab5de4589cfe7c649d2cbe92bd6"
-dependencies = [
- "addr2line",
- "cfg-if",
- "libc",
- "miniz_oxide",
- "object",
- "rustc-demangle",
- "windows-link",
-]
-
 [[package]]
 name = "base64"
 version = "0.21.7"
@@ -158,31 +98,6 @@ version = "1.8.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06"
 
-[[package]]
-name = "bincode"
-version = "2.0.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "36eaf5d7b090263e8150820482d5d93cd964a81e4019913c972f4edcc6edb740"
-dependencies = [
- "serde",
- "unty",
-]
-
-[[package]]
-name = "bit-set"
-version = "0.9.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "34ddef2995421ab6a5c779542c81ee77c115206f4ad9d5a8e05f4ff49716a3dd"
-dependencies = [
- "bit-vec",
-]
-
-[[package]]
-name = "bit-vec"
-version = "0.9.1"
-source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "b71798fca2c1fe1086445a7258a4bc81e6e49dcd24c8d0dd9a1e57395b603f51" - [[package]] name = "bitflags" version = "1.3.2" @@ -257,15 +172,6 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc119b6761ce8b063102502af49043051f81a9bdf242ae06d12e9ea0d92b727a" -[[package]] -name = "block2" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdeb9d870516001442e364c5220d3574d2da8dc765554b4a617230d33fa58ef5" -dependencies = [ - "objc2", -] - [[package]] name = "bumpalo" version = "3.20.2" @@ -275,97 +181,6 @@ dependencies = [ "allocator-api2", ] -[[package]] -name = "burn" -version = "0.1.0" -dependencies = [ - "atomic_float", - "blas-src", - "burn-backend", - "burn-ir", - "burn-std", - "bytemuck", - "bytes", - "const-random", - "itertools 0.14.0", - "libm", - "macerator", - "matrixmultiply", - "ndarray", - "num-traits", - "openblas-src", - "paste", - "rand 0.10.0", - "rayon", - "seq-macro", - "serde", -] - -[[package]] -name = "burn-backend" -version = "0.21.0-pre.2" -source = "git+https://github.com/tracel-ai/burn.git?rev=ed72d2b#ed72d2b125a364aff18aed2a53396c128e01cb42" -dependencies = [ - "burn-std", - "bytemuck", - "cubecl", - "derive-new", - "enumset", - "hashbrown 0.16.1", - "num-traits", - "portable-atomic-util", - "rand 0.10.0", - "rand_distr 0.6.0", - "serde", - "spin", - "thiserror", -] - -[[package]] -name = "burn-ir" -version = "0.21.0-pre.2" -source = "git+https://github.com/tracel-ai/burn.git?rev=ed72d2b#ed72d2b125a364aff18aed2a53396c128e01cb42" -dependencies = [ - "burn-backend", - "hashbrown 0.16.1", - "serde", -] - -[[package]] -name = "burn-std" -version = "0.21.0-pre.2" -source = "git+https://github.com/tracel-ai/burn.git?rev=ed72d2b#ed72d2b125a364aff18aed2a53396c128e01cb42" -dependencies = [ - "bytemuck", - "bytes", - "cubecl-common", - "cubecl-zspace", - "half", - "num-traits", - "serde", - "smallvec", -] - -[[package]] -name = "bytemuck" -version = "1.25.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" -dependencies = [ - "bytemuck_derive", -] - -[[package]] -name = "bytemuck_derive" -version = "1.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "byteorder" version = "1.5.0" @@ -377,9 +192,6 @@ name = "bytes" version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" -dependencies = [ - "portable-atomic", -] [[package]] name = "cast" @@ -412,23 +224,6 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" -[[package]] -name = "cfg_aliases" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" - -[[package]] -name = "chacha20" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601" -dependencies = [ - "cfg-if", - "cpufeatures", - "rand_core 0.10.0", -] - [[package]] name = "ciborium" version = "0.2.2" @@ -490,61 +285,12 @@ 
dependencies = [ "cc", ] -[[package]] -name = "codespan-reporting" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af491d569909a7e4dee0ad7db7f5341fef5c614d5b8ec8cf765732aba3cff681" -dependencies = [ - "serde", - "termcolor", - "unicode-width", -] - -[[package]] -name = "concurrent-queue" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" -dependencies = [ - "crossbeam-utils", -] - -[[package]] -name = "const-random" -version = "0.1.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" -dependencies = [ - "const-random-macro", -] - -[[package]] -name = "const-random-macro" -version = "0.1.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" -dependencies = [ - "getrandom 0.2.17", - "once_cell", - "tiny-keccak", -] - [[package]] name = "constant_time_eq" version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d52eff69cd5e647efe296129160853a42795992097e8af39800e1060caeea9b" -[[package]] -name = "convert_case" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9" -dependencies = [ - "unicode-segmentation", -] - [[package]] name = "core-foundation" version = "0.10.1" @@ -599,11 +345,11 @@ dependencies = [ "cranelift-control", "cranelift-entity", "cranelift-isle", - "gimli 0.31.1", + "gimli", "hashbrown 0.14.5", "log", "regalloc2", - "rustc-hash 2.1.2", + "rustc-hash", "serde", "smallvec", "target-lexicon", @@ -794,591 +540,81 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" [[package]] -name = "cubecl" -version = "0.10.0-pre.2" -source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586" -dependencies = [ - "cubecl-core", - "cubecl-cuda", - "cubecl-ir", - "cubecl-runtime", - "cubecl-wgpu", - "half", -] +name = "defmac" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5592fca31e96d8a748d03080b58be78c5383617aa4bd89e69f30607d8769891" [[package]] -name = "cubecl-common" -version = "0.10.0-pre.2" -source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586" +name = "der" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71fd89660b2dc699704064e59e9dba0147b903e85319429e131620d022be411b" dependencies = [ - "backtrace", - "bincode", - "bytemuck", - "bytes", - "cfg-if", - "cfg_aliases", - "derive-new", - "derive_more", - "dirs", - "embassy-futures", - "embassy-time", - "float4", - "float8", - "futures-lite", - "half", - "hashbrown 0.16.1", - "log", - "num-traits", - "oneshot", - "parking_lot", - "portable-atomic", - "portable-atomic-util", - "rand 0.10.0", - "sanitize-filename", - "serde", - "serde_bytes", - "serde_json", - "spin", - "tynm", - "wasm-bindgen-futures", - "web-time", - "xxhash-rust", + "pem-rfc7468", + "zeroize", ] [[package]] -name = "cubecl-core" -version = "0.10.0-pre.2" -source = 
"git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586" +name = "dirs" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3e8aa94d75141228480295a7d0e7feb620b1a5ad9f12bc40be62411e38cce4e" dependencies = [ - "bitflags 2.11.0", - "bytemuck", - "cubecl-common", - "cubecl-ir", - "cubecl-macros", - "cubecl-runtime", - "cubecl-zspace", - "derive-new", - "derive_more", - "enumset", - "float-ord", - "half", - "hashbrown 0.16.1", - "log", - "num-traits", - "paste", - "serde", - "serde_json", - "variadics_please", + "dirs-sys", ] [[package]] -name = "cubecl-cpp" -version = "0.10.0-pre.2" -source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586" +name = "dirs-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e01a3366d27ee9890022452ee61b2b63a67e6f13f58900b651ff5665f0bb1fab" dependencies = [ - "bytemuck", - "cubecl-common", - "cubecl-core", - "cubecl-opt", - "cubecl-runtime", - "derive-new", - "half", - "itertools 0.14.0", - "log", + "libc", + "option-ext", + "redox_users", + "windows-sys 0.61.2", ] [[package]] -name = "cubecl-cuda" -version = "0.10.0-pre.2" -source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586" -dependencies = [ - "bytemuck", - "cubecl-common", - "cubecl-core", - "cubecl-cpp", - "cubecl-runtime", - "cudarc", - "derive-new", - "half", - "log", - "serde", -] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" [[package]] -name = "cubecl-ir" -version = "0.10.0-pre.2" -source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586" -dependencies = [ - "cubecl-common", - "cubecl-macros-internal", - "derive-new", - "derive_more", - "enumset", - "float-ord", - "fnv", - "foldhash 0.2.0", - "half", - "hashbrown 0.16.1", - "num-traits", - "portable-atomic", - "serde", - "variadics_please", -] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] -name = "cubecl-macros" -version = "0.10.0-pre.2" -source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586" +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ - "cubecl-common", - "darling 0.23.0", - "derive-new", - "ident_case", - "inflections", - "prettyplease", - "proc-macro2", - "quote", - "syn", + "libc", + "windows-sys 0.61.2", ] [[package]] -name = "cubecl-macros-internal" -version = "0.10.0-pre.2" -source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586" -dependencies = [ - "darling 0.23.0", - "proc-macro2", - "quote", - "syn", -] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" [[package]] -name = "cubecl-opt" -version = "0.10.0-pre.2" 
-source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586" -dependencies = [ - "cubecl-common", - "cubecl-core", - "cubecl-ir", - "float-ord", - "log", - "num", - "petgraph", - "smallvec", - "stable-vec", - "type-map", -] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] -name = "cubecl-runtime" -version = "0.10.0-pre.2" -source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586" -dependencies = [ - "async-channel", - "bytemuck", - "cfg-if", - "cfg_aliases", - "cubecl-common", - "cubecl-ir", - "cubecl-zspace", - "derive-new", - "derive_more", - "dirs", - "enumset", - "hashbrown 0.16.1", - "log", - "md5", - "serde", - "serde_json", - "spin", - "thiserror", - "toml", - "variadics_please", - "wasm-bindgen-futures", - "web-time", -] - -[[package]] -name = "cubecl-wgpu" -version = "0.10.0-pre.2" -source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586" -dependencies = [ - "async-channel", - "bytemuck", - "cfg-if", - "cfg_aliases", - "cubecl-common", - "cubecl-core", - "cubecl-ir", - "cubecl-runtime", - "derive-new", - "derive_more", - "half", - "hashbrown 0.16.1", - "log", - "sanitize-filename", - "wgpu", -] - -[[package]] -name = "cubecl-zspace" -version = "0.10.0-pre.2" -source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586" -dependencies = [ - "derive-new", - "serde", - "smallvec", -] - -[[package]] -name = "cudarc" -version = "0.19.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f071cd6a7b5d51607df76aa2d426aaabc7a74bc6bdb885b8afa63a880572ad9b" -dependencies = [ - "libloading 0.9.0", -] - -[[package]] -name = "darling" -version = "0.20.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" -dependencies = [ - "darling_core 0.20.11", - "darling_macro 0.20.11", -] - -[[package]] -name = "darling" -version = "0.21.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" -dependencies = [ - "darling_core 0.21.3", - "darling_macro 0.21.3", -] - -[[package]] -name = "darling" -version = "0.23.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25ae13da2f202d56bd7f91c25fba009e7717a1e4a1cc98a76d844b65ae912e9d" -dependencies = [ - "darling_core 0.23.0", - "darling_macro 0.23.0", -] - -[[package]] -name = "darling_core" -version = "0.20.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" -dependencies = [ - "fnv", - "ident_case", - "proc-macro2", - "quote", - "strsim", - "syn", -] - -[[package]] -name = "darling_core" -version = "0.21.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4" -dependencies = [ - "fnv", - "ident_case", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "darling_core" -version = "0.23.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"9865a50f7c335f53564bb694ef660825eb8610e0a53d3e11bf1b0d3df31e03b0" -dependencies = [ - "ident_case", - "proc-macro2", - "quote", - "strsim", - "syn", -] - -[[package]] -name = "darling_macro" -version = "0.20.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" -dependencies = [ - "darling_core 0.20.11", - "quote", - "syn", -] - -[[package]] -name = "darling_macro" -version = "0.21.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" -dependencies = [ - "darling_core 0.21.3", - "quote", - "syn", -] - -[[package]] -name = "darling_macro" -version = "0.23.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3984ec7bd6cfa798e62b4a642426a5be0e68f9401cfc2a01e3fa9ea2fcdb8d" -dependencies = [ - "darling_core 0.23.0", - "quote", - "syn", -] - -[[package]] -name = "defmac" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5592fca31e96d8a748d03080b58be78c5383617aa4bd89e69f30607d8769891" - -[[package]] -name = "der" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71fd89660b2dc699704064e59e9dba0147b903e85319429e131620d022be411b" -dependencies = [ - "pem-rfc7468", - "zeroize", -] - -[[package]] -name = "derive-new" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "derive_more" -version = "2.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d751e9e49156b02b44f9c1815bcb94b984cdcc4396ecc32521c739452808b134" -dependencies = [ - "derive_more-impl", -] - -[[package]] -name = "derive_more-impl" -version = "2.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "799a97264921d8623a957f6c3b9011f3b5492f557bbb7a5a19b7fa6d06ba8dcb" -dependencies = [ - "convert_case", - "proc-macro2", - "quote", - "rustc_version", - "syn", - "unicode-xid", -] - -[[package]] -name = "dirs" -version = "6.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3e8aa94d75141228480295a7d0e7feb620b1a5ad9f12bc40be62411e38cce4e" -dependencies = [ - "dirs-sys", -] - -[[package]] -name = "dirs-sys" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e01a3366d27ee9890022452ee61b2b63a67e6f13f58900b651ff5665f0bb1fab" -dependencies = [ - "libc", - "option-ext", - "redox_users", - "windows-sys 0.61.2", -] - -[[package]] -name = "dispatch2" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e0e367e4e7da84520dedcac1901e4da967309406d1e51017ae1abfb97adbd38" -dependencies = [ - "bitflags 2.11.0", - "objc2", -] - -[[package]] -name = "dlib" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab8ecd87370524b461f8557c119c405552c396ed91fc0a8eec68679eab26f94a" -dependencies = [ - "libloading 0.8.9", -] - -[[package]] -name = "document-features" -version = "0.2.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4b8a88685455ed29a21542a33abd9cb6510b6b129abadabdcef0f4c55bc8f61" -dependencies = [ - "litrs", -] - -[[package]] -name = "either" -version = "1.15.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" - -[[package]] -name = "embassy-futures" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc2d050bdc5c21e0862a89256ed8029ae6c290a93aecefc73084b3002cdebb01" - -[[package]] -name = "embassy-time" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "592b0c143ec626e821d4d90da51a2bd91d559d6c442b7c74a47d368c9e23d97a" -dependencies = [ - "cfg-if", - "critical-section", - "document-features", - "embassy-time-driver", - "embedded-hal 0.2.7", - "embedded-hal 1.0.0", - "embedded-hal-async", - "futures-core", -] - -[[package]] -name = "embassy-time-driver" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ee71af1b3a0deaa53eaf2d39252f83504c853646e472400b763060389b9fcc9" -dependencies = [ - "document-features", -] - -[[package]] -name = "embedded-hal" -version = "0.2.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35949884794ad573cf46071e41c9b60efb0cb311e3ca01f7af807af1debc66ff" -dependencies = [ - "nb 0.1.3", - "void", -] - -[[package]] -name = "embedded-hal" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "361a90feb7004eca4019fb28352a9465666b24f840f5c3cddf0ff13920590b89" - -[[package]] -name = "embedded-hal-async" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c4c685bbef7fe13c3c6dd4da26841ed3980ef33e841cddfa15ce8a8fb3f1884" -dependencies = [ - "embedded-hal 1.0.0", -] - -[[package]] -name = "enumset" -version = "1.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25b07a8dfbbbfc0064c0a6bdf9edcf966de6b1c33ce344bdeca3b41615452634" -dependencies = [ - "enumset_derive", - "serde", -] - -[[package]] -name = "enumset_derive" -version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f43e744e4ea338060faee68ed933e46e722fb7f3617e722a5772d7e856d8b3ce" -dependencies = [ - "darling 0.21.3", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "equivalent" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" - -[[package]] -name = "errno" -version = "0.3.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" -dependencies = [ - "libc", - "windows-sys 0.61.2", -] - -[[package]] -name = "event-listener" -version = "5.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13b66accf52311f30a0db42147dadea9850cb48cd070028831ae5f5d4b856ab" -dependencies = [ - "concurrent-queue", - "parking", - "pin-project-lite", -] - -[[package]] -name = "event-listener-strategy" -version = "0.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" -dependencies = [ - "event-listener", - "pin-project-lite", -] - -[[package]] -name = "fallible-iterator" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" - -[[package]] -name = "fastrand" -version = "2.3.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" - -[[package]] -name = "filetime" -version = "0.2.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f98844151eee8917efc50bd9e8318cb963ae8b297431495d3f758616ea5c57db" +name = "filetime" +version = "0.2.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f98844151eee8917efc50bd9e8318cb963ae8b297431495d3f758616ea5c57db" dependencies = [ "cfg-if", "libc", @@ -1391,12 +627,6 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" -[[package]] -name = "fixedbitset" -version = "0.5.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" - [[package]] name = "flate2" version = "1.1.9" @@ -1407,45 +637,12 @@ dependencies = [ "miniz_oxide", ] -[[package]] -name = "float-ord" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d" - -[[package]] -name = "float4" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a5404bf31d22893d61cf24d4dda149d8e6b2ff07601c3cb3be651031f61a4ed" - -[[package]] -name = "float8" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2d1f04709a8ac06e8e8042875a3c466cc4832d3c1a18dbcb9dba3c6e83046bc" -dependencies = [ - "half", -] - -[[package]] -name = "fnv" -version = "1.0.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" - [[package]] name = "foldhash" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" -[[package]] -name = "foldhash" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" - [[package]] name = "foreign-types" version = "0.3.2" @@ -1469,49 +666,6 @@ dependencies = [ "libm", ] -[[package]] -name = "futures-core" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" - -[[package]] -name = "futures-io" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" - -[[package]] -name = "futures-lite" -version = "2.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad" -dependencies = [ - "fastrand", - "futures-core", - "futures-io", - "parking", - "pin-project-lite", -] - -[[package]] -name = "futures-task" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" - -[[package]] -name = "futures-util" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" -dependencies = [ - "futures-core", - "futures-task", - "pin-project-lite", - "slab", -] - 
[[package]] name = "getrandom" version = "0.2.17" @@ -1560,89 +714,14 @@ dependencies = [ "stable_deref_trait", ] -[[package]] -name = "gimli" -version = "0.32.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e629b9b98ef3dd8afe6ca2bd0f89306cec16d43d907889945bc5d6687f2f13c7" - -[[package]] -name = "gl_generator" -version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a95dfc23a2b4a9a2f5ab41d194f8bfda3cabec42af4e39f08c339eb2a0c124d" -dependencies = [ - "khronos_api", - "log", - "xml-rs", -] - -[[package]] -name = "glow" -version = "0.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29038e1c483364cc6bb3cf78feee1816002e127c331a1eec55a4d202b9e1adb5" -dependencies = [ - "js-sys", - "slotmap", - "wasm-bindgen", - "web-sys", -] - -[[package]] -name = "glutin_wgl_sys" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c4ee00b289aba7a9e5306d57c2d05499b2e5dc427f84ac708bd2c090212cf3e" -dependencies = [ - "gl_generator", -] - -[[package]] -name = "gpu-allocator" -version = "0.28.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51255ea7cfaadb6c5f1528d43e92a82acb2b96c43365989a28b2d44ee38f8795" -dependencies = [ - "ash", - "hashbrown 0.16.1", - "log", - "presser", - "thiserror", - "windows", -] - -[[package]] -name = "gpu-descriptor" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b89c83349105e3732062a895becfc71a8f921bb71ecbbdd8ff99263e3b53a0ca" -dependencies = [ - "bitflags 2.11.0", - "gpu-descriptor-types", - "hashbrown 0.15.5", -] - -[[package]] -name = "gpu-descriptor-types" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdf242682df893b86f33a73828fb09ca4b2d3bb6cc95249707fc684d27484b91" -dependencies = [ - "bitflags 2.11.0", -] - [[package]] name = "half" version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" dependencies = [ - "bytemuck", "cfg-if", "crunchy", - "num-traits", - "serde", "zerocopy", ] @@ -1658,7 +737,7 @@ version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" dependencies = [ - "foldhash 0.1.5", + "foldhash", ] [[package]] @@ -1666,13 +745,6 @@ name = "hashbrown" version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" -dependencies = [ - "allocator-api2", - "equivalent", - "foldhash 0.2.0", - "serde", - "serde_core", -] [[package]] name = "heck" @@ -1686,12 +758,6 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" -[[package]] -name = "hexf-parse" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfa686283ad6dd069f105e5ab091b04c62850d3e4cf5d67debad1933f55023df" - [[package]] name = "http" version = "1.4.0" @@ -1714,12 +780,6 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" -[[package]] -name = "ident_case" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" - [[package]] name = "indexmap" version = "2.13.0" @@ -1732,12 +792,6 @@ dependencies = [ "serde_core", ] -[[package]] -name = "inflections" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a257582fdcde896fd96463bf2d40eefea0580021c0712a0e2b028b60b47a837a" - [[package]] name = "is-terminal" version = "0.4.17" @@ -1767,78 +821,22 @@ dependencies = [ "either", ] -[[package]] -name = "itertools" -version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" -dependencies = [ - "either", -] - [[package]] name = "itoa" version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" -[[package]] -name = "jni-sys" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41a652e1f9b6e0275df1f15b32661cf0d4b78d4d87ddec5e0c3c20f097433258" -dependencies = [ - "jni-sys 0.4.1", -] - -[[package]] -name = "jni-sys" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2" -dependencies = [ - "jni-sys-macros", -] - -[[package]] -name = "jni-sys-macros" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264" -dependencies = [ - "quote", - "syn", -] - [[package]] name = "js-sys" version = "0.3.92" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cc4c90f45aa2e6eacbe8645f77fdea542ac97a494bcd117a67df9ff4d611f995" dependencies = [ - "cfg-if", - "futures-util", "once_cell", "wasm-bindgen", ] -[[package]] -name = "khronos-egl" -version = "6.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6aae1df220ece3c0ada96b8153459b67eebe9ae9212258bb0134ae60416fdf76" -dependencies = [ - "libc", - "libloading 0.8.9", - "pkg-config", -] - -[[package]] -name = "khronos_api" -version = "3.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2db585e1d738fc771bf08a151420d3ed193d9d895a36df7f6f8a9456b911ddc" - [[package]] name = "leb128fmt" version = "0.1.0" @@ -1851,26 +849,6 @@ version = "0.2.183" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" -[[package]] -name = "libloading" -version = "0.8.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" -dependencies = [ - "cfg-if", - "windows-link", -] - -[[package]] -name = "libloading" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "754ca22de805bb5744484a5b151a9e1a8e837d5dc232c2d7d8c2e3492edc8b60" -dependencies = [ - "cfg-if", - "windows-link", -] - [[package]] name = "libm" version = "0.2.16" @@ -1886,7 +864,7 @@ dependencies = [ "bitflags 2.11.0", "libc", "plain", - "redox_syscall 0.7.3", + "redox_syscall", ] [[package]] @@ -1895,55 +873,12 @@ version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" -[[package]] -name = "litrs" -version = "1.0.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092" - -[[package]] -name = "lock_api" -version = "0.4.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" -dependencies = [ - "scopeguard", -] - [[package]] name = "log" version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" -[[package]] -name = "macerator" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09e6046277c48f8a44bd6cfae65a1a261cab6622fb6d4a003f5597e4e4f4a661" -dependencies = [ - "bytemuck", - "cfg_aliases", - "half", - "macerator-macros", - "moddef", - "num-traits", - "paste", - "rustc_version", -] - -[[package]] -name = "macerator-macros" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23ee1819976b67f4d782390c55a75c13401c7a988517f7f8e60a33484dc2e00a" -dependencies = [ - "darling 0.20.11", - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "mach2" version = "0.4.3" @@ -1966,12 +901,6 @@ dependencies = [ "thread-tree", ] -[[package]] -name = "md5" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae960838283323069879657ca3de837e9f7bbb4c7bf6ea7f1b290d5e9476d2e0" - [[package]] name = "memchr" version = "2.8.0" @@ -1988,38 +917,6 @@ dependencies = [ "simd-adler32", ] -[[package]] -name = "moddef" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a0b3262dc837d2513fe2ef31ff8461352ef932dcca31ba0c0abe33547cf6b9b" - -[[package]] -name = "naga" -version = "29.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa2630921705b9b01dcdd0b6864b9562ca3c1951eecd0f0c4f5f04f61e412647" -dependencies = [ - "arrayvec", - "bit-set", - "bitflags 2.11.0", - "cfg-if", - "cfg_aliases", - "codespan-reporting", - "half", - "hashbrown 0.16.1", - "hexf-parse", - "indexmap", - "libm", - "log", - "num-traits", - "once_cell", - "rustc-hash 1.1.0", - "spirv", - "thiserror", - "unicode-ident", -] - [[package]] name = "native-tls" version = "0.2.18" @@ -2037,21 +934,6 @@ dependencies = [ "tempfile", ] -[[package]] -name = "nb" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "801d31da0513b6ec5214e9bf433a77966320625a37860f910be265be6e18d06f" -dependencies = [ - "nb 1.1.0", -] - -[[package]] -name = "nb" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d5439c4ad607c3c23abf66de8c8bf57ba8adcd1f129e699851a6e43935d339d" - [[package]] name = "ndarray" version = "0.17.2" @@ -2098,19 +980,10 @@ dependencies = [ "ndarray", "quickcheck", "rand 0.9.2", - "rand_distr 0.5.1", + "rand_distr", "rand_isaac", ] -[[package]] -name = "ndk-sys" -version = "0.6.0+11769913" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee6cda3051665f1fb8d9e08fc35c96d5a244fb1be711a03b71118828afc9a873" -dependencies = [ - "jni-sys 0.3.1", -] - [[package]] name = "netlib-src" version = "0.8.0" @@ -2120,39 +993,6 @@ dependencies = [ "cmake", ] -[[package]] -name = "nom" -version = "8.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405" -dependencies = [ - "memchr", -] - 
-[[package]] -name = "num" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" -dependencies = [ - "num-bigint", - "num-complex", - "num-integer", - "num-iter", - "num-rational", - "num-traits", -] - -[[package]] -name = "num-bigint" -version = "0.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" -dependencies = [ - "num-integer", - "num-traits", -] - [[package]] name = "num-complex" version = "0.4.6" @@ -2171,28 +1011,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "num-iter" -version = "0.1.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" -dependencies = [ - "autocfg", - "num-integer", - "num-traits", -] - -[[package]] -name = "num-rational" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" -dependencies = [ - "num-bigint", - "num-integer", - "num-traits", -] - [[package]] name = "num-traits" version = "0.2.19" @@ -2225,78 +1043,7 @@ dependencies = [ "num-traits", "openblas-src", "rand 0.9.2", - "rand_distr 0.5.1", -] - -[[package]] -name = "objc2" -version = "0.6.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a12a8ed07aefc768292f076dc3ac8c48f3781c8f2d5851dd3d98950e8c5a89f" -dependencies = [ - "objc2-encode", -] - -[[package]] -name = "objc2-core-foundation" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a180dd8642fa45cdb7dd721cd4c11b1cadd4929ce112ebd8b9f5803cc79d536" -dependencies = [ - "bitflags 2.11.0", - "dispatch2", - "objc2", -] - -[[package]] -name = "objc2-encode" -version = "4.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" - -[[package]] -name = "objc2-foundation" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3e0adef53c21f888deb4fa59fc59f7eb17404926ee8a6f59f5df0fd7f9f3272" -dependencies = [ - "bitflags 2.11.0", - "objc2", - "objc2-core-foundation", -] - -[[package]] -name = "objc2-metal" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0125f776a10d00af4152d74616409f0d4a2053a6f57fa5b7d6aa2854ac04794" -dependencies = [ - "bitflags 2.11.0", - "block2", - "objc2", - "objc2-foundation", -] - -[[package]] -name = "objc2-quartz-core" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96c1358452b371bf9f104e21ec536d37a650eb10f7ee379fff67d2e08d537f1f" -dependencies = [ - "bitflags 2.11.0", - "objc2", - "objc2-core-foundation", - "objc2-foundation", - "objc2-metal", -] - -[[package]] -name = "object" -version = "0.37.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff76201f031d8863c38aa7f905eca4f53abbfa15f609db4277d44cd8938f33fe" -dependencies = [ - "memchr", + "rand_distr", ] [[package]] @@ -2305,12 +1052,6 @@ version = "1.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" -[[package]] -name = "oneshot" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum 
= "cfe21416a02c693fb9f980befcb230ecc70b0b3d1cc4abf88b9675c4c1457f0c" - [[package]] name = "oorandom" version = "11.1.5" @@ -2393,15 +1134,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" -[[package]] -name = "ordered-float" -version = "5.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7d950ca161dc355eaf28f82b11345ed76c6e1f6eb1f4f4479e0323b9e2fbd0e" -dependencies = [ - "num-traits", -] - [[package]] name = "p64" version = "0.1.0" @@ -2410,35 +1142,6 @@ dependencies = [ "fractal", ] -[[package]] -name = "parking" -version = "2.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" - -[[package]] -name = "parking_lot" -version = "0.12.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall 0.5.18", - "smallvec", - "windows-link", -] - [[package]] name = "paste" version = "1.0.15" @@ -2460,24 +1163,6 @@ version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" -[[package]] -name = "petgraph" -version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" -dependencies = [ - "fixedbitset", - "hashbrown 0.15.5", - "indexmap", - "serde", -] - -[[package]] -name = "pin-project-lite" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" - [[package]] name = "pkg-config" version = "0.3.32" @@ -2525,7 +1210,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" dependencies = [ "critical-section", - "serde", ] [[package]] @@ -2546,12 +1230,6 @@ dependencies = [ "zerocopy", ] -[[package]] -name = "presser" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8cf8e6a8aa66ce33f63993ffc4ea4271eb5b0530a9002db8455ea6050c77bfa" - [[package]] name = "prettyplease" version = "0.2.37" @@ -2571,12 +1249,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "profiling" -version = "1.0.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3eb8486b569e12e2c32ad3e204dbaba5e4b5b216e9367044f25f1dba42341773" - [[package]] name = "quickcheck" version = "1.1.0" @@ -2623,7 +1295,6 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bc266eb313df6c5c09c1c7b1fbe2510961e5bcd3add930c1e31f7ed9da0feff8" dependencies = [ - "chacha20", "getrandom 0.4.2", "rand_core 0.10.0", ] @@ -2663,16 +1334,6 @@ dependencies = [ "rand 0.9.2", ] -[[package]] -name = "rand_distr" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d431c2703ccf129de4d45253c03f49ebb22b97d6ad79ee3ecfc7e3f4862c1d8" -dependencies = [ - 
"num-traits", - "rand 0.10.0", -] - [[package]] name = "rand_isaac" version = "0.4.0" @@ -2682,30 +1343,6 @@ dependencies = [ "rand_core 0.9.5", ] -[[package]] -name = "range-alloc" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca45419789ae5a7899559e9512e58ca889e41f04f1f2445e9f4b290ceccd1d08" - -[[package]] -name = "raw-window-handle" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539" - -[[package]] -name = "raw-window-metal" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40d213455a5f1dc59214213c7330e074ddf8114c9a42411eb890c767357ce135" -dependencies = [ - "objc2", - "objc2-core-foundation", - "objc2-foundation", - "objc2-quartz-core", -] - [[package]] name = "rawpointer" version = "0.2.1" @@ -2732,15 +1369,6 @@ dependencies = [ "crossbeam-utils", ] -[[package]] -name = "redox_syscall" -version = "0.5.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" -dependencies = [ - "bitflags 2.11.0", -] - [[package]] name = "redox_syscall" version = "0.7.3" @@ -2771,7 +1399,7 @@ dependencies = [ "bumpalo", "hashbrown 0.15.5", "log", - "rustc-hash 2.1.2", + "rustc-hash", "smallvec", ] @@ -2816,12 +1444,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "renderdoc-sys" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19b30a45b0cd0bcca8037f3d0dc3421eaf95327a17cad11964fb8179b4fc4832" - [[package]] name = "rmp" version = "0.8.13" @@ -2856,33 +1478,12 @@ dependencies = [ "serde_derive", ] -[[package]] -name = "rustc-demangle" -version = "0.1.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b50b8869d9fc858ce7266cce0194bd74df58b9d0e3f6df3a9fc8eb470d95c09d" - -[[package]] -name = "rustc-hash" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" - [[package]] name = "rustc-hash" version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe" -[[package]] -name = "rustc_version" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" -dependencies = [ - "semver", -] - [[package]] name = "rustix" version = "1.1.4" @@ -2920,15 +1521,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "sanitize-filename" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc984f4f9ceb736a7bb755c3e3bd17dc56370af2600c9780dcc48c66453da34d" -dependencies = [ - "regex", -] - [[package]] name = "schannel" version = "0.1.29" @@ -2938,12 +1530,6 @@ dependencies = [ "windows-sys 0.61.2", ] -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - [[package]] name = "security-framework" version = "3.7.0" @@ -2973,12 +1559,6 @@ version = "1.0.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" -[[package]] -name = "seq-macro" -version = 
"0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" - [[package]] name = "serde" version = "1.0.228" @@ -2989,16 +1569,6 @@ dependencies = [ "serde_derive", ] -[[package]] -name = "serde_bytes" -version = "0.11.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5d440709e79d88e51ac01c4b72fc6cb7314017bb7da9eeff678aa94c10e3ea8" -dependencies = [ - "serde", - "serde_core", -] - [[package]] name = "serde_core" version = "1.0.228" @@ -3032,15 +1602,6 @@ dependencies = [ "zmij", ] -[[package]] -name = "serde_spanned" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "876ac351060d4f882bb1032b6369eb0aef79ad9df1ea8bc404874d8cc3d0cd98" -dependencies = [ - "serde_core", -] - [[package]] name = "serialization-tests" version = "0.1.0" @@ -3065,72 +1626,17 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" -[[package]] -name = "slab" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" - -[[package]] -name = "slotmap" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdd58c3c93c3d278ca835519292445cb4b0d4dc59ccfdf7ceadaab3f8aeb4038" -dependencies = [ - "version_check", -] - [[package]] name = "smallvec" version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" -dependencies = [ - "serde", -] - -[[package]] -name = "spin" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5fe4ccb98d9c292d56fec89a5e07da7fc4cf0dc11e156b41793132775d3e591" -dependencies = [ - "lock_api", - "portable-atomic", -] - -[[package]] -name = "spirv" -version = "0.4.0+sdk-1.4.341.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9571ea910ebd84c86af4b3ed27f9dbdc6ad06f17c5f96146b2b671e2976744f" -dependencies = [ - "bitflags 2.11.0", -] - -[[package]] -name = "stable-vec" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dac7bc0f7d0d44329b200020effbc25a534d89fa142af95e3ddf76113412a5e" - -[[package]] -name = "stable_deref_trait" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" - -[[package]] -name = "static_assertions" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" [[package]] -name = "strsim" -version = "0.11.1" +name = "stable_deref_trait" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" [[package]] name = "syn" @@ -3173,15 +1679,6 @@ dependencies = [ "windows-sys 0.61.2", ] -[[package]] -name = "termcolor" -version = "1.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" -dependencies = [ - "winapi-util", -] - [[package]] name = "thiserror" version = "2.0.18" @@ -3211,15 +1708,6 @@ dependencies = [ "crossbeam-channel", ] -[[package]] -name = "tiny-keccak" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" -dependencies = [ - "crunchy", -] - [[package]] name = "tinytemplate" version = "1.2.1" @@ -3230,93 +1718,18 @@ dependencies = [ "serde_json", ] -[[package]] -name = "toml" -version = "1.1.0+spec-1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8195ca05e4eb728f4ba94f3e3291661320af739c4e43779cbdfae82ab239fcc" -dependencies = [ - "indexmap", - "serde_core", - "serde_spanned", - "toml_datetime", - "toml_parser", - "toml_writer", - "winnow", -] - -[[package]] -name = "toml_datetime" -version = "1.1.0+spec-1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97251a7c317e03ad83774a8752a7e81fb6067740609f75ea2b585b569a59198f" -dependencies = [ - "serde_core", -] - -[[package]] -name = "toml_parser" -version = "1.1.0+spec-1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2334f11ee363607eb04df9b8fc8a13ca1715a72ba8662a26ac285c98aabb4011" -dependencies = [ - "winnow", -] - -[[package]] -name = "toml_writer" -version = "1.1.0+spec-1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d282ade6016312faf3e41e57ebbba0c073e4056dab1232ab1cb624199648f8ed" - -[[package]] -name = "tynm" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a21cdb0fc8f85c98b1ec812bc4cd69faf6c0fa2fc17d44ea3c2cdd38dc08e999" -dependencies = [ - "nom", -] - -[[package]] -name = "type-map" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb30dbbd9036155e74adad6812e9898d03ec374946234fbcebd5dfc7b9187b90" -dependencies = [ - "rustc-hash 2.1.2", -] - [[package]] name = "unicode-ident" version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" -[[package]] -name = "unicode-segmentation" -version = "1.13.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c" - -[[package]] -name = "unicode-width" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" - [[package]] name = "unicode-xid" version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" -[[package]] -name = "unty" -version = "0.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae" - [[package]] name = "ureq" version = "3.3.0" @@ -3352,35 +1765,12 @@ version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8c0a043c9540bae7c578c88f91dda8bd82e59ae27c21baca69c8b191aaf5a6e" -[[package]] -name = "variadics_please" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41b6d82be61465f97d42bd1d15bf20f3b0a3a0905018f38f9d6f6962055b0b5c" -dependencies = [ - "proc-macro2", - "quote", 
- "syn", -] - [[package]] name = "vcpkg" version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" -[[package]] -name = "version_check" -version = "0.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" - -[[package]] -name = "void" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" - [[package]] name = "walkdir" version = "2.5.0" @@ -3428,16 +1818,6 @@ dependencies = [ "wasm-bindgen-shared", ] -[[package]] -name = "wasm-bindgen-futures" -version = "0.4.65" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d1faf851e778dfa54db7cd438b70758eba9755cb47403f3496edd7c8fc212f0" -dependencies = [ - "js-sys", - "wasm-bindgen", -] - [[package]] name = "wasm-bindgen-macro" version = "0.2.115" @@ -3516,18 +1896,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "wayland-sys" -version = "0.31.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "374f6b70e8e0d6bf9461a32988fd553b59ff630964924dad6e4a4eb6bd538d17" -dependencies = [ - "dlib", - "log", - "once_cell", - "pkg-config", -] - [[package]] name = "web-sys" version = "0.3.92" @@ -3538,16 +1906,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "web-time" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" -dependencies = [ - "js-sys", - "wasm-bindgen", -] - [[package]] name = "webpki-root-certs" version = "1.0.6" @@ -3557,173 +1915,6 @@ dependencies = [ "rustls-pki-types", ] -[[package]] -name = "wgpu" -version = "29.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72c239a9a747bbd379590985bac952c2e53cb19873f7072b3370c6a6a8e06837" -dependencies = [ - "arrayvec", - "bitflags 2.11.0", - "bytemuck", - "cfg-if", - "cfg_aliases", - "document-features", - "hashbrown 0.16.1", - "js-sys", - "log", - "naga", - "parking_lot", - "portable-atomic", - "profiling", - "raw-window-handle", - "smallvec", - "static_assertions", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", - "wgpu-core", - "wgpu-hal", - "wgpu-types", -] - -[[package]] -name = "wgpu-core" -version = "29.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e80ac6cf1895df6342f87d975162108f9d98772a0d74bc404ab7304ac29469e" -dependencies = [ - "arrayvec", - "bit-set", - "bit-vec", - "bitflags 2.11.0", - "bytemuck", - "cfg_aliases", - "document-features", - "hashbrown 0.16.1", - "indexmap", - "log", - "naga", - "once_cell", - "parking_lot", - "portable-atomic", - "profiling", - "raw-window-handle", - "rustc-hash 1.1.0", - "smallvec", - "thiserror", - "wgpu-core-deps-apple", - "wgpu-core-deps-emscripten", - "wgpu-core-deps-windows-linux-android", - "wgpu-hal", - "wgpu-naga-bridge", - "wgpu-types", -] - -[[package]] -name = "wgpu-core-deps-apple" -version = "29.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43acd053312501689cd92a01a9638d37f3e41a5fd9534875efa8917ee2d11ac0" -dependencies = [ - "wgpu-hal", -] - -[[package]] -name = "wgpu-core-deps-emscripten" -version = "29.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"ef043bf135cc68b6f667c55ff4e345ce2b5924d75bad36a47921b0287ca4b24a" -dependencies = [ - "wgpu-hal", -] - -[[package]] -name = "wgpu-core-deps-windows-linux-android" -version = "29.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "725d5c006a8c02967b6d93ef04f6537ec4593313e330cfe86d9d3f946eb90f28" -dependencies = [ - "wgpu-hal", -] - -[[package]] -name = "wgpu-hal" -version = "29.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89a47aef47636562f3937285af4c44b4b5b404b46577471411cc5313a921da7e" -dependencies = [ - "android_system_properties", - "arrayvec", - "ash", - "bit-set", - "bitflags 2.11.0", - "block2", - "bytemuck", - "cfg-if", - "cfg_aliases", - "glow", - "glutin_wgl_sys", - "gpu-allocator", - "gpu-descriptor", - "hashbrown 0.16.1", - "js-sys", - "khronos-egl", - "libc", - "libloading 0.8.9", - "log", - "naga", - "ndk-sys", - "objc2", - "objc2-core-foundation", - "objc2-foundation", - "objc2-metal", - "objc2-quartz-core", - "once_cell", - "ordered-float", - "parking_lot", - "portable-atomic", - "portable-atomic-util", - "profiling", - "range-alloc", - "raw-window-handle", - "raw-window-metal", - "renderdoc-sys", - "smallvec", - "thiserror", - "wasm-bindgen", - "wayland-sys", - "web-sys", - "wgpu-naga-bridge", - "wgpu-types", - "windows", - "windows-core", -] - -[[package]] -name = "wgpu-naga-bridge" -version = "29.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b4684f4410da0cf95a4cb63bb5edaac022461dedb6adf0b64d0d9b5f6890d51" -dependencies = [ - "naga", - "wgpu-types", -] - -[[package]] -name = "wgpu-types" -version = "29.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec2675540fb1a5cfa5ef122d3d5f390e2c75711a0b946410f2d6ac3a0f77d1f6" -dependencies = [ - "bitflags 2.11.0", - "bytemuck", - "js-sys", - "log", - "raw-window-handle", - "web-sys", -] - [[package]] name = "winapi-util" version = "0.1.11" @@ -3733,107 +1924,12 @@ dependencies = [ "windows-sys 0.61.2", ] -[[package]] -name = "windows" -version = "0.62.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "527fadee13e0c05939a6a05d5bd6eec6cd2e3dbd648b9f8e447c6518133d8580" -dependencies = [ - "windows-collections", - "windows-core", - "windows-future", - "windows-numerics", -] - -[[package]] -name = "windows-collections" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23b2d95af1a8a14a3c7367e1ed4fc9c20e0a26e79551b1454d72583c97cc6610" -dependencies = [ - "windows-core", -] - -[[package]] -name = "windows-core" -version = "0.62.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" -dependencies = [ - "windows-implement", - "windows-interface", - "windows-link", - "windows-result", - "windows-strings", -] - -[[package]] -name = "windows-future" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1d6f90251fe18a279739e78025bd6ddc52a7e22f921070ccdc67dde84c605cb" -dependencies = [ - "windows-core", - "windows-link", - "windows-threading", -] - -[[package]] -name = "windows-implement" -version = "0.60.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "windows-interface" -version = "0.59.3" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "windows-link" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" -[[package]] -name = "windows-numerics" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e2e40844ac143cdb44aead537bbf727de9b044e107a0f1220392177d15b0f26" -dependencies = [ - "windows-core", - "windows-link", -] - -[[package]] -name = "windows-result" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" -dependencies = [ - "windows-link", -] - -[[package]] -name = "windows-strings" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" -dependencies = [ - "windows-link", -] - [[package]] name = "windows-sys" version = "0.52.0" @@ -3877,15 +1973,6 @@ dependencies = [ "windows_x86_64_msvc", ] -[[package]] -name = "windows-threading" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3949bd5b99cafdf1c7ca86b43ca564028dfe27d66958f2470940f73d86d75b37" -dependencies = [ - "windows-link", -] - [[package]] name = "windows_aarch64_gnullvm" version = "0.52.6" @@ -3934,12 +2021,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" -[[package]] -name = "winnow" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09dac053f1cd375980747450bfc7250c264eaae0583872e845c0c7cd578872b5" - [[package]] name = "wit-bindgen" version = "0.51.0" @@ -4038,18 +2119,6 @@ dependencies = [ "rustix", ] -[[package]] -name = "xml-rs" -version = "0.8.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ae8337f8a065cfc972643663ea4279e04e7256de865aa66fe25cec5fb912d3f" - -[[package]] -name = "xxhash-rust" -version = "0.8.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" - [[package]] name = "zerocopy" version = "0.8.48" diff --git a/Cargo.toml b/Cargo.toml index 8d17bb6a..ea01bc3f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -144,7 +144,7 @@ members = [ "ndarray-rand", "crates/*", ] -exclude = [] +exclude = ["crates/burn"] default-members = [ ".", "ndarray-rand", diff --git a/benches/append.rs b/benches/append.rs index a37df256..b9ca99c6 100644 --- a/benches/append.rs +++ b/benches/append.rs @@ -6,24 +6,21 @@ use test::Bencher; use ndarray::prelude::*; #[bench] -fn select_axis0(bench: &mut Bencher) -{ +fn select_axis0(bench: &mut Bencher) { let a = Array::::zeros((256, 256)); let selectable = vec![0, 1, 2, 0, 1, 3, 0, 4, 16, 32, 128, 147, 149, 220, 221, 255, 221, 0, 1]; bench.iter(|| a.select(Axis(0), &selectable)); } #[bench] -fn select_axis1(bench: &mut Bencher) -{ +fn select_axis1(bench: &mut Bencher) { let a = Array::::zeros((256, 256)); let selectable = vec![0, 1, 2, 0, 1, 3, 0, 4, 16, 32, 128, 147, 149, 220, 221, 255, 221, 0, 1]; bench.iter(|| a.select(Axis(1), &selectable)); } #[bench] -fn select_1d(bench: &mut Bencher) -{ +fn select_1d(bench: &mut 
Bencher) { let a = Array::::zeros(1024); let mut selectable = (0..a.len()).step_by(17).collect::>(); selectable.extend(selectable.clone().iter().rev()); diff --git a/benches/bench1.rs b/benches/bench1.rs index 3b540532..e534c04d 100644 --- a/benches/bench1.rs +++ b/benches/bench1.rs @@ -14,8 +14,7 @@ use ndarray::{Ix1, Ix2, Ix3, Ix5, IxDyn}; use test::black_box; #[bench] -fn iter_sum_1d_regular(bench: &mut test::Bencher) -{ +fn iter_sum_1d_regular(bench: &mut test::Bencher) { let a = Array::::zeros(64 * 64); let a = black_box(a); bench.iter(|| { @@ -28,8 +27,7 @@ fn iter_sum_1d_regular(bench: &mut test::Bencher) } #[bench] -fn iter_sum_1d_raw(bench: &mut test::Bencher) -{ +fn iter_sum_1d_raw(bench: &mut test::Bencher) { // this is autovectorized to death (= great performance) let a = Array::::zeros(64 * 64); let a = black_box(a); @@ -43,8 +41,7 @@ fn iter_sum_1d_raw(bench: &mut test::Bencher) } #[bench] -fn iter_sum_2d_regular(bench: &mut test::Bencher) -{ +fn iter_sum_2d_regular(bench: &mut test::Bencher) { let a = Array::::zeros((64, 64)); let a = black_box(a); bench.iter(|| { @@ -57,8 +54,7 @@ fn iter_sum_2d_regular(bench: &mut test::Bencher) } #[bench] -fn iter_sum_2d_by_row(bench: &mut test::Bencher) -{ +fn iter_sum_2d_by_row(bench: &mut test::Bencher) { let a = Array::::zeros((64, 64)); let a = black_box(a); bench.iter(|| { @@ -73,8 +69,7 @@ fn iter_sum_2d_by_row(bench: &mut test::Bencher) } #[bench] -fn iter_sum_2d_raw(bench: &mut test::Bencher) -{ +fn iter_sum_2d_raw(bench: &mut test::Bencher) { // this is autovectorized to death (= great performance) let a = Array::::zeros((64, 64)); let a = black_box(a); @@ -88,8 +83,7 @@ fn iter_sum_2d_raw(bench: &mut test::Bencher) } #[bench] -fn iter_sum_2d_cutout(bench: &mut test::Bencher) -{ +fn iter_sum_2d_cutout(bench: &mut test::Bencher) { let a = Array::::zeros((66, 66)); let av = a.slice(s![1..-1, 1..-1]); let a = black_box(av); @@ -103,8 +97,7 @@ fn iter_sum_2d_cutout(bench: &mut test::Bencher) } #[bench] -fn iter_sum_2d_cutout_by_row(bench: &mut test::Bencher) -{ +fn iter_sum_2d_cutout_by_row(bench: &mut test::Bencher) { let a = Array::::zeros((66, 66)); let av = a.slice(s![1..-1, 1..-1]); let a = black_box(av); @@ -120,8 +113,7 @@ fn iter_sum_2d_cutout_by_row(bench: &mut test::Bencher) } #[bench] -fn iter_sum_2d_cutout_outer_iter(bench: &mut test::Bencher) -{ +fn iter_sum_2d_cutout_outer_iter(bench: &mut test::Bencher) { let a = Array::::zeros((66, 66)); let av = a.slice(s![1..-1, 1..-1]); let a = black_box(av); @@ -137,8 +129,7 @@ fn iter_sum_2d_cutout_outer_iter(bench: &mut test::Bencher) } #[bench] -fn iter_sum_2d_transpose_regular(bench: &mut test::Bencher) -{ +fn iter_sum_2d_transpose_regular(bench: &mut test::Bencher) { let mut a = Array::::zeros((64, 64)); a.swap_axes(0, 1); let a = black_box(a); @@ -152,8 +143,7 @@ fn iter_sum_2d_transpose_regular(bench: &mut test::Bencher) } #[bench] -fn iter_sum_2d_transpose_by_row(bench: &mut test::Bencher) -{ +fn iter_sum_2d_transpose_by_row(bench: &mut test::Bencher) { let mut a = Array::::zeros((64, 64)); a.swap_axes(0, 1); let a = black_box(a); @@ -169,16 +159,14 @@ fn iter_sum_2d_transpose_by_row(bench: &mut test::Bencher) } #[bench] -fn sum_2d_regular(bench: &mut test::Bencher) -{ +fn sum_2d_regular(bench: &mut test::Bencher) { let a = Array::::zeros((64, 64)); let a = black_box(a); bench.iter(|| a.sum()); } #[bench] -fn sum_2d_cutout(bench: &mut test::Bencher) -{ +fn sum_2d_cutout(bench: &mut test::Bencher) { let a = Array::::zeros((66, 66)); let av = a.slice(s![1..-1, 1..-1]); let 
a = black_box(av); @@ -186,16 +174,14 @@ fn sum_2d_cutout(bench: &mut test::Bencher) } #[bench] -fn sum_2d_float(bench: &mut test::Bencher) -{ +fn sum_2d_float(bench: &mut test::Bencher) { let a = Array::::zeros((64, 64)); let a = black_box(a.view()); bench.iter(|| a.sum()); } #[bench] -fn sum_2d_float_cutout(bench: &mut test::Bencher) -{ +fn sum_2d_float_cutout(bench: &mut test::Bencher) { let a = Array::::zeros((66, 66)); let av = a.slice(s![1..-1, 1..-1]); let a = black_box(av); @@ -203,8 +189,7 @@ fn sum_2d_float_cutout(bench: &mut test::Bencher) } #[bench] -fn sum_2d_float_t_cutout(bench: &mut test::Bencher) -{ +fn sum_2d_float_t_cutout(bench: &mut test::Bencher) { let a = Array::::zeros((66, 66)); let av = a.slice(s![1..-1, 1..-1]).reversed_axes(); let a = black_box(av); @@ -212,15 +197,13 @@ fn sum_2d_float_t_cutout(bench: &mut test::Bencher) } #[bench] -fn fold_sum_i32_2d_regular(bench: &mut test::Bencher) -{ +fn fold_sum_i32_2d_regular(bench: &mut test::Bencher) { let a = Array::::zeros((64, 64)); bench.iter(|| a.fold(0, |acc, &x| acc + x)); } #[bench] -fn fold_sum_i32_2d_cutout(bench: &mut test::Bencher) -{ +fn fold_sum_i32_2d_cutout(bench: &mut test::Bencher) { let a = Array::::zeros((66, 66)); let av = a.slice(s![1..-1, 1..-1]); let a = black_box(av); @@ -228,8 +211,7 @@ fn fold_sum_i32_2d_cutout(bench: &mut test::Bencher) } #[bench] -fn fold_sum_i32_2d_stride(bench: &mut test::Bencher) -{ +fn fold_sum_i32_2d_stride(bench: &mut test::Bencher) { let a = Array::::zeros((64, 128)); let av = a.slice(s![.., ..;2]); let a = black_box(av); @@ -237,16 +219,14 @@ fn fold_sum_i32_2d_stride(bench: &mut test::Bencher) } #[bench] -fn fold_sum_i32_2d_transpose(bench: &mut test::Bencher) -{ +fn fold_sum_i32_2d_transpose(bench: &mut test::Bencher) { let a = Array::::zeros((64, 64)); let a = a.t(); bench.iter(|| a.fold(0, |acc, &x| acc + x)); } #[bench] -fn fold_sum_i32_2d_cutout_transpose(bench: &mut test::Bencher) -{ +fn fold_sum_i32_2d_cutout_transpose(bench: &mut test::Bencher) { let a = Array::::zeros((66, 66)); let mut av = a.slice(s![1..-1, 1..-1]); av.swap_axes(0, 1); @@ -257,8 +237,7 @@ fn fold_sum_i32_2d_cutout_transpose(bench: &mut test::Bencher) const ADD2DSZ: usize = 64; #[bench] -fn add_2d_regular(bench: &mut test::Bencher) -{ +fn add_2d_regular(bench: &mut test::Bencher) { let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); let bv = b.view(); @@ -268,8 +247,7 @@ fn add_2d_regular(bench: &mut test::Bencher) } #[bench] -fn add_2d_zip(bench: &mut test::Bencher) -{ +fn add_2d_zip(bench: &mut test::Bencher) { let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); bench.iter(|| { @@ -278,16 +256,14 @@ fn add_2d_zip(bench: &mut test::Bencher) } #[bench] -fn add_2d_alloc_plus(bench: &mut test::Bencher) -{ +fn add_2d_alloc_plus(bench: &mut test::Bencher) { let a = Array::::zeros((ADD2DSZ, ADD2DSZ)); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); bench.iter(|| &a + &b); } #[bench] -fn add_2d_alloc_zip_uninit(bench: &mut test::Bencher) -{ +fn add_2d_alloc_zip_uninit(bench: &mut test::Bencher) { let a = Array::::zeros((ADD2DSZ, ADD2DSZ)); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); bench.iter(|| unsafe { @@ -300,44 +276,38 @@ fn add_2d_alloc_zip_uninit(bench: &mut test::Bencher) } #[bench] -fn add_2d_alloc_zip_collect(bench: &mut test::Bencher) -{ +fn add_2d_alloc_zip_collect(bench: &mut test::Bencher) { let a = Array::::zeros((ADD2DSZ, ADD2DSZ)); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); bench.iter(|| 
Zip::from(&a).and(&b).map_collect(|&x, &y| x + y)); } #[bench] -fn vec_string_collect(bench: &mut test::Bencher) -{ +fn vec_string_collect(bench: &mut test::Bencher) { let v = vec![""; 10240]; bench.iter(|| v.iter().map(|s| s.to_owned()).collect::>()); } #[bench] -fn array_string_collect(bench: &mut test::Bencher) -{ +fn array_string_collect(bench: &mut test::Bencher) { let v = Array::from(vec![""; 10240]); bench.iter(|| Zip::from(&v).map_collect(|s| s.to_owned())); } #[bench] -fn vec_f64_collect(bench: &mut test::Bencher) -{ +fn vec_f64_collect(bench: &mut test::Bencher) { let v = vec![1.; 10240]; bench.iter(|| v.iter().map(|s| s + 1.).collect::>()); } #[bench] -fn array_f64_collect(bench: &mut test::Bencher) -{ +fn array_f64_collect(bench: &mut test::Bencher) { let v = Array::from(vec![1.; 10240]); bench.iter(|| Zip::from(&v).map_collect(|s| s + 1.)); } #[bench] -fn add_2d_assign_ops(bench: &mut test::Bencher) -{ +fn add_2d_assign_ops(bench: &mut test::Bencher) { let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); let bv = b.view(); @@ -349,8 +319,7 @@ fn add_2d_assign_ops(bench: &mut test::Bencher) } #[bench] -fn add_2d_cutout(bench: &mut test::Bencher) -{ +fn add_2d_cutout(bench: &mut test::Bencher) { let mut a = Array::::zeros((ADD2DSZ + 2, ADD2DSZ + 2)); let mut acut = a.slice_mut(s![1..-1, 1..-1]); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); @@ -361,8 +330,7 @@ fn add_2d_cutout(bench: &mut test::Bencher) } #[bench] -fn add_2d_zip_cutout(bench: &mut test::Bencher) -{ +fn add_2d_zip_cutout(bench: &mut test::Bencher) { let mut a = Array::::zeros((ADD2DSZ + 2, ADD2DSZ + 2)); let mut acut = a.slice_mut(s![1..-1, 1..-1]); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); @@ -373,8 +341,7 @@ fn add_2d_zip_cutout(bench: &mut test::Bencher) #[bench] #[allow(clippy::identity_op)] -fn add_2d_cutouts_by_4(bench: &mut test::Bencher) -{ +fn add_2d_cutouts_by_4(bench: &mut test::Bencher) { let mut a = Array::::zeros((64 * 1, 64 * 1)); let b = Array::::zeros((64 * 1, 64 * 1)); let chunksz = (4, 4); @@ -387,8 +354,7 @@ fn add_2d_cutouts_by_4(bench: &mut test::Bencher) #[bench] #[allow(clippy::identity_op)] -fn add_2d_cutouts_by_16(bench: &mut test::Bencher) -{ +fn add_2d_cutouts_by_16(bench: &mut test::Bencher) { let mut a = Array::::zeros((64 * 1, 64 * 1)); let b = Array::::zeros((64 * 1, 64 * 1)); let chunksz = (16, 16); @@ -401,8 +367,7 @@ fn add_2d_cutouts_by_16(bench: &mut test::Bencher) #[bench] #[allow(clippy::identity_op)] -fn add_2d_cutouts_by_32(bench: &mut test::Bencher) -{ +fn add_2d_cutouts_by_32(bench: &mut test::Bencher) { let mut a = Array::::zeros((64 * 1, 64 * 1)); let b = Array::::zeros((64 * 1, 64 * 1)); let chunksz = (32, 32); @@ -414,8 +379,7 @@ fn add_2d_cutouts_by_32(bench: &mut test::Bencher) } #[bench] -fn add_2d_broadcast_1_to_2(bench: &mut test::Bencher) -{ +fn add_2d_broadcast_1_to_2(bench: &mut test::Bencher) { let mut a = Array2::::zeros((ADD2DSZ, ADD2DSZ)); let b = Array1::::zeros(ADD2DSZ); let bv = b.view(); @@ -425,8 +389,7 @@ fn add_2d_broadcast_1_to_2(bench: &mut test::Bencher) } #[bench] -fn add_2d_broadcast_0_to_2(bench: &mut test::Bencher) -{ +fn add_2d_broadcast_0_to_2(bench: &mut test::Bencher) { let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); let b = Array::::zeros(()); let bv = b.view(); @@ -436,55 +399,48 @@ fn add_2d_broadcast_0_to_2(bench: &mut test::Bencher) } #[bench] -fn scalar_toowned(bench: &mut test::Bencher) -{ +fn scalar_toowned(bench: &mut test::Bencher) { let a = Array::::zeros((64, 64)); bench.iter(|| 
a.to_owned()); } #[bench] -fn scalar_add_1(bench: &mut test::Bencher) -{ +fn scalar_add_1(bench: &mut test::Bencher) { let a = Array::::zeros((64, 64)); let n = 1.; bench.iter(|| &a + n); } #[bench] -fn scalar_add_2(bench: &mut test::Bencher) -{ +fn scalar_add_2(bench: &mut test::Bencher) { let a = Array::::zeros((64, 64)); let n = 1.; bench.iter(|| n + &a); } #[bench] -fn scalar_add_strided_1(bench: &mut test::Bencher) -{ +fn scalar_add_strided_1(bench: &mut test::Bencher) { let a = Array::from_shape_fn((64, 64 * 2), |(i, j)| (i * 64 + j) as f32).slice_move(s![.., ..;2]); let n = 1.; bench.iter(|| &a + n); } #[bench] -fn scalar_add_strided_2(bench: &mut test::Bencher) -{ +fn scalar_add_strided_2(bench: &mut test::Bencher) { let a = Array::from_shape_fn((64, 64 * 2), |(i, j)| (i * 64 + j) as f32).slice_move(s![.., ..;2]); let n = 1.; bench.iter(|| n + &a); } #[bench] -fn scalar_sub_1(bench: &mut test::Bencher) -{ +fn scalar_sub_1(bench: &mut test::Bencher) { let a = Array::::zeros((64, 64)); let n = 1.; bench.iter(|| &a - n); } #[bench] -fn scalar_sub_2(bench: &mut test::Bencher) -{ +fn scalar_sub_2(bench: &mut test::Bencher) { let a = Array::::zeros((64, 64)); let n = 1.; bench.iter(|| n - &a); @@ -492,8 +448,7 @@ fn scalar_sub_2(bench: &mut test::Bencher) // This is for comparison with add_2d_broadcast_0_to_2 #[bench] -fn add_2d_0_to_2_iadd_scalar(bench: &mut test::Bencher) -{ +fn add_2d_0_to_2_iadd_scalar(bench: &mut test::Bencher) { let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); let n = black_box(0); bench.iter(|| { @@ -502,8 +457,7 @@ fn add_2d_0_to_2_iadd_scalar(bench: &mut test::Bencher) } #[bench] -fn add_2d_strided(bench: &mut test::Bencher) -{ +fn add_2d_strided(bench: &mut test::Bencher) { let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ * 2)); let mut a = a.slice_mut(s![.., ..;2]); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); @@ -514,8 +468,7 @@ fn add_2d_strided(bench: &mut test::Bencher) } #[bench] -fn add_2d_regular_dyn(bench: &mut test::Bencher) -{ +fn add_2d_regular_dyn(bench: &mut test::Bencher) { let mut a = Array::::zeros(&[ADD2DSZ, ADD2DSZ][..]); let b = Array::::zeros(&[ADD2DSZ, ADD2DSZ][..]); let bv = b.view(); @@ -525,8 +478,7 @@ fn add_2d_regular_dyn(bench: &mut test::Bencher) } #[bench] -fn add_2d_strided_dyn(bench: &mut test::Bencher) -{ +fn add_2d_strided_dyn(bench: &mut test::Bencher) { let mut a = Array::::zeros(&[ADD2DSZ, ADD2DSZ * 2][..]); let mut a = a.slice_mut(s![.., ..;2]); let b = Array::::zeros(&[ADD2DSZ, ADD2DSZ][..]); @@ -537,8 +489,7 @@ fn add_2d_strided_dyn(bench: &mut test::Bencher) } #[bench] -fn add_2d_zip_strided(bench: &mut test::Bencher) -{ +fn add_2d_zip_strided(bench: &mut test::Bencher) { let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ * 2)); let mut a = a.slice_mut(s![.., ..;2]); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); @@ -548,8 +499,7 @@ fn add_2d_zip_strided(bench: &mut test::Bencher) } #[bench] -fn add_2d_one_transposed(bench: &mut test::Bencher) -{ +fn add_2d_one_transposed(bench: &mut test::Bencher) { let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); a.swap_axes(0, 1); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); @@ -559,8 +509,7 @@ fn add_2d_one_transposed(bench: &mut test::Bencher) } #[bench] -fn add_2d_zip_one_transposed(bench: &mut test::Bencher) -{ +fn add_2d_zip_one_transposed(bench: &mut test::Bencher) { let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); a.swap_axes(0, 1); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); @@ -570,8 +519,7 @@ fn add_2d_zip_one_transposed(bench: &mut test::Bencher) } #[bench] -fn add_2d_both_transposed(bench: 
&mut test::Bencher) -{ +fn add_2d_both_transposed(bench: &mut test::Bencher) { let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); a.swap_axes(0, 1); let mut b = Array::::zeros((ADD2DSZ, ADD2DSZ)); @@ -582,8 +530,7 @@ fn add_2d_both_transposed(bench: &mut test::Bencher) } #[bench] -fn add_2d_zip_both_transposed(bench: &mut test::Bencher) -{ +fn add_2d_zip_both_transposed(bench: &mut test::Bencher) { let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); a.swap_axes(0, 1); let mut b = Array::::zeros((ADD2DSZ, ADD2DSZ)); @@ -594,8 +541,7 @@ fn add_2d_zip_both_transposed(bench: &mut test::Bencher) } #[bench] -fn add_2d_f32_regular(bench: &mut test::Bencher) -{ +fn add_2d_f32_regular(bench: &mut test::Bencher) { let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); let bv = b.view(); @@ -607,8 +553,7 @@ fn add_2d_f32_regular(bench: &mut test::Bencher) const ADD3DSZ: usize = 16; #[bench] -fn add_3d_strided(bench: &mut test::Bencher) -{ +fn add_3d_strided(bench: &mut test::Bencher) { let mut a = Array::::zeros((ADD3DSZ, ADD3DSZ, ADD3DSZ * 2)); let mut a = a.slice_mut(s![.., .., ..;2]); let b = Array::::zeros(a.dim()); @@ -619,8 +564,7 @@ fn add_3d_strided(bench: &mut test::Bencher) } #[bench] -fn add_3d_strided_dyn(bench: &mut test::Bencher) -{ +fn add_3d_strided_dyn(bench: &mut test::Bencher) { let mut a = Array::::zeros(&[ADD3DSZ, ADD3DSZ, ADD3DSZ * 2][..]); let mut a = a.slice_mut(s![.., .., ..;2]); let b = Array::::zeros(a.dim()); @@ -633,8 +577,7 @@ fn add_3d_strided_dyn(bench: &mut test::Bencher) const ADD1D_SIZE: usize = 64 * 64; #[bench] -fn add_1d_regular(bench: &mut test::Bencher) -{ +fn add_1d_regular(bench: &mut test::Bencher) { let mut a = Array::::zeros(ADD1D_SIZE); let b = Array::::zeros(a.dim()); bench.iter(|| { @@ -643,8 +586,7 @@ fn add_1d_regular(bench: &mut test::Bencher) } #[bench] -fn add_1d_strided(bench: &mut test::Bencher) -{ +fn add_1d_strided(bench: &mut test::Bencher) { let mut a = Array::::zeros(ADD1D_SIZE * 2); let mut av = a.slice_mut(s![..;2]); let b = Array::::zeros(av.dim()); @@ -654,8 +596,7 @@ fn add_1d_strided(bench: &mut test::Bencher) } #[bench] -fn iadd_scalar_2d_regular(bench: &mut test::Bencher) -{ +fn iadd_scalar_2d_regular(bench: &mut test::Bencher) { let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); bench.iter(|| { a += 1.; @@ -663,8 +604,7 @@ fn iadd_scalar_2d_regular(bench: &mut test::Bencher) } #[bench] -fn iadd_scalar_2d_strided(bench: &mut test::Bencher) -{ +fn iadd_scalar_2d_strided(bench: &mut test::Bencher) { let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ * 2)); let mut a = a.slice_mut(s![.., ..;2]); bench.iter(|| { @@ -673,8 +613,7 @@ fn iadd_scalar_2d_strided(bench: &mut test::Bencher) } #[bench] -fn iadd_scalar_2d_regular_dyn(bench: &mut test::Bencher) -{ +fn iadd_scalar_2d_regular_dyn(bench: &mut test::Bencher) { let mut a = Array::::zeros(vec![ADD2DSZ, ADD2DSZ]); bench.iter(|| { a += 1.; @@ -682,8 +621,7 @@ fn iadd_scalar_2d_regular_dyn(bench: &mut test::Bencher) } #[bench] -fn iadd_scalar_2d_strided_dyn(bench: &mut test::Bencher) -{ +fn iadd_scalar_2d_strided_dyn(bench: &mut test::Bencher) { let mut a = Array::::zeros(vec![ADD2DSZ, ADD2DSZ * 2]); let mut a = a.slice_mut(s![.., ..;2]); bench.iter(|| { @@ -692,8 +630,7 @@ fn iadd_scalar_2d_strided_dyn(bench: &mut test::Bencher) } #[bench] -fn scaled_add_2d_f32_regular(bench: &mut test::Bencher) -{ +fn scaled_add_2d_f32_regular(bench: &mut test::Bencher) { let mut av = Array::::zeros((ADD2DSZ, ADD2DSZ)); let bv = Array::::zeros((ADD2DSZ, ADD2DSZ)); let scalar = 
std::f32::consts::PI; @@ -703,8 +640,7 @@ fn scaled_add_2d_f32_regular(bench: &mut test::Bencher) } #[bench] -fn assign_scalar_2d_corder(bench: &mut test::Bencher) -{ +fn assign_scalar_2d_corder(bench: &mut test::Bencher) { let a = Array::zeros((ADD2DSZ, ADD2DSZ)); let mut a = black_box(a); let s = 3.; @@ -712,8 +648,7 @@ fn assign_scalar_2d_corder(bench: &mut test::Bencher) } #[bench] -fn assign_scalar_2d_cutout(bench: &mut test::Bencher) -{ +fn assign_scalar_2d_cutout(bench: &mut test::Bencher) { let mut a = Array::zeros((66, 66)); let a = a.slice_mut(s![1..-1, 1..-1]); let mut a = black_box(a); @@ -722,8 +657,7 @@ fn assign_scalar_2d_cutout(bench: &mut test::Bencher) } #[bench] -fn assign_scalar_2d_forder(bench: &mut test::Bencher) -{ +fn assign_scalar_2d_forder(bench: &mut test::Bencher) { let mut a = Array::zeros((ADD2DSZ, ADD2DSZ)); a.swap_axes(0, 1); let mut a = black_box(a); @@ -732,16 +666,14 @@ fn assign_scalar_2d_forder(bench: &mut test::Bencher) } #[bench] -fn assign_zero_2d_corder(bench: &mut test::Bencher) -{ +fn assign_zero_2d_corder(bench: &mut test::Bencher) { let a = Array::zeros((ADD2DSZ, ADD2DSZ)); let mut a = black_box(a); bench.iter(|| a.fill(0.)) } #[bench] -fn assign_zero_2d_cutout(bench: &mut test::Bencher) -{ +fn assign_zero_2d_cutout(bench: &mut test::Bencher) { let mut a = Array::zeros((66, 66)); let a = a.slice_mut(s![1..-1, 1..-1]); let mut a = black_box(a); @@ -749,8 +681,7 @@ fn assign_zero_2d_cutout(bench: &mut test::Bencher) } #[bench] -fn assign_zero_2d_forder(bench: &mut test::Bencher) -{ +fn assign_zero_2d_forder(bench: &mut test::Bencher) { let mut a = Array::zeros((ADD2DSZ, ADD2DSZ)); a.swap_axes(0, 1); let mut a = black_box(a); @@ -758,8 +689,7 @@ fn assign_zero_2d_forder(bench: &mut test::Bencher) } #[bench] -fn bench_iter_diag(bench: &mut test::Bencher) -{ +fn bench_iter_diag(bench: &mut test::Bencher) { let a = Array::::zeros((1024, 1024)); bench.iter(|| { for elt in a.diag() { @@ -769,8 +699,7 @@ fn bench_iter_diag(bench: &mut test::Bencher) } #[bench] -fn bench_row_iter(bench: &mut test::Bencher) -{ +fn bench_row_iter(bench: &mut test::Bencher) { let a = Array::::zeros((1024, 1024)); let it = a.row(17); bench.iter(|| { @@ -781,8 +710,7 @@ fn bench_row_iter(bench: &mut test::Bencher) } #[bench] -fn bench_col_iter(bench: &mut test::Bencher) -{ +fn bench_col_iter(bench: &mut test::Bencher) { let a = Array::::zeros((1024, 1024)); let it = a.column(17); bench.iter(|| { @@ -852,8 +780,7 @@ mat_mul! 
{mat_mul_i32, i32, } #[bench] -fn create_iter_4d(bench: &mut test::Bencher) -{ +fn create_iter_4d(bench: &mut test::Bencher) { let mut a = Array::from_elem((4, 5, 3, 2), 1.0); a.swap_axes(0, 1); a.swap_axes(2, 1); @@ -863,94 +790,82 @@ fn create_iter_4d(bench: &mut test::Bencher) } #[bench] -fn bench_to_owned_n(bench: &mut test::Bencher) -{ +fn bench_to_owned_n(bench: &mut test::Bencher) { let a = Array::::zeros((32, 32)); bench.iter(|| a.to_owned()); } #[bench] -fn bench_to_owned_t(bench: &mut test::Bencher) -{ +fn bench_to_owned_t(bench: &mut test::Bencher) { let mut a = Array::::zeros((32, 32)); a.swap_axes(0, 1); bench.iter(|| a.to_owned()); } #[bench] -fn bench_to_owned_strided(bench: &mut test::Bencher) -{ +fn bench_to_owned_strided(bench: &mut test::Bencher) { let a = Array::::zeros((32, 64)); let a = a.slice(s![.., ..;2]); bench.iter(|| a.to_owned()); } #[bench] -fn equality_i32(bench: &mut test::Bencher) -{ +fn equality_i32(bench: &mut test::Bencher) { let a = Array::::zeros((64, 64)); let b = Array::::zeros((64, 64)); bench.iter(|| a == b); } #[bench] -fn equality_f32(bench: &mut test::Bencher) -{ +fn equality_f32(bench: &mut test::Bencher) { let a = Array::::zeros((64, 64)); let b = Array::::zeros((64, 64)); bench.iter(|| a == b); } #[bench] -fn equality_f32_mixorder(bench: &mut test::Bencher) -{ +fn equality_f32_mixorder(bench: &mut test::Bencher) { let a = Array::::zeros((64, 64)); let b = Array::::zeros((64, 64).f()); bench.iter(|| a == b); } #[bench] -fn dot_f32_16(bench: &mut test::Bencher) -{ +fn dot_f32_16(bench: &mut test::Bencher) { let a = Array::::zeros(16); let b = Array::::zeros(16); bench.iter(|| a.dot(&b)); } #[bench] -fn dot_f32_20(bench: &mut test::Bencher) -{ +fn dot_f32_20(bench: &mut test::Bencher) { let a = Array::::zeros(20); let b = Array::::zeros(20); bench.iter(|| a.dot(&b)); } #[bench] -fn dot_f32_32(bench: &mut test::Bencher) -{ +fn dot_f32_32(bench: &mut test::Bencher) { let a = Array::::zeros(32); let b = Array::::zeros(32); bench.iter(|| a.dot(&b)); } #[bench] -fn dot_f32_256(bench: &mut test::Bencher) -{ +fn dot_f32_256(bench: &mut test::Bencher) { let a = Array::::zeros(256); let b = Array::::zeros(256); bench.iter(|| a.dot(&b)); } #[bench] -fn dot_f32_1024(bench: &mut test::Bencher) -{ +fn dot_f32_1024(bench: &mut test::Bencher) { let av = Array::::zeros(1024); let bv = Array::::zeros(1024); bench.iter(|| av.dot(&bv)); } #[bench] -fn dot_f32_10e6(bench: &mut test::Bencher) -{ +fn dot_f32_10e6(bench: &mut test::Bencher) { let n = 1_000_000; let av = Array::::zeros(n); let bv = Array::::zeros(n); @@ -958,8 +873,7 @@ fn dot_f32_10e6(bench: &mut test::Bencher) } #[bench] -fn dot_extended(bench: &mut test::Bencher) -{ +fn dot_extended(bench: &mut test::Bencher) { let m = 10; let n = 33; let k = 10; @@ -981,8 +895,7 @@ fn dot_extended(bench: &mut test::Bencher) const MEAN_SUM_N: usize = 127; #[cfg(feature = "std")] -fn range_mat(m: Ix, n: Ix) -> Array2 -{ +fn range_mat(m: Ix, n: Ix) -> Array2 { assert!(m * n != 0); Array::linspace(0.0..=(m * n - 1) as f32, m * n) .into_shape_with_order((m, n)) @@ -991,103 +904,90 @@ fn range_mat(m: Ix, n: Ix) -> Array2 #[cfg(feature = "std")] #[bench] -fn mean_axis0(bench: &mut test::Bencher) -{ +fn mean_axis0(bench: &mut test::Bencher) { let a = range_mat(MEAN_SUM_N, MEAN_SUM_N); bench.iter(|| a.mean_axis(Axis(0))); } #[cfg(feature = "std")] #[bench] -fn mean_axis1(bench: &mut test::Bencher) -{ +fn mean_axis1(bench: &mut test::Bencher) { let a = range_mat(MEAN_SUM_N, MEAN_SUM_N); bench.iter(|| a.mean_axis(Axis(1))); } 
#[cfg(feature = "std")] #[bench] -fn sum_axis0(bench: &mut test::Bencher) -{ +fn sum_axis0(bench: &mut test::Bencher) { let a = range_mat(MEAN_SUM_N, MEAN_SUM_N); bench.iter(|| a.sum_axis(Axis(0))); } #[cfg(feature = "std")] #[bench] -fn sum_axis1(bench: &mut test::Bencher) -{ +fn sum_axis1(bench: &mut test::Bencher) { let a = range_mat(MEAN_SUM_N, MEAN_SUM_N); bench.iter(|| a.sum_axis(Axis(1))); } #[bench] -fn into_dimensionality_ix1_ok(bench: &mut test::Bencher) -{ +fn into_dimensionality_ix1_ok(bench: &mut test::Bencher) { let a = Array::::zeros(Ix1(10)); let a = a.view(); bench.iter(|| a.into_dimensionality::()); } #[bench] -fn into_dimensionality_ix3_ok(bench: &mut test::Bencher) -{ +fn into_dimensionality_ix3_ok(bench: &mut test::Bencher) { let a = Array::::zeros(Ix3(10, 10, 10)); let a = a.view(); bench.iter(|| a.into_dimensionality::()); } #[bench] -fn into_dimensionality_ix3_err(bench: &mut test::Bencher) -{ +fn into_dimensionality_ix3_err(bench: &mut test::Bencher) { let a = Array::::zeros(Ix3(10, 10, 10)); let a = a.view(); bench.iter(|| a.into_dimensionality::()); } #[bench] -fn into_dimensionality_dyn_to_ix3(bench: &mut test::Bencher) -{ +fn into_dimensionality_dyn_to_ix3(bench: &mut test::Bencher) { let a = Array::::zeros(IxDyn(&[10, 10, 10])); let a = a.view(); bench.iter(|| a.clone().into_dimensionality::()); } #[bench] -fn into_dimensionality_dyn_to_dyn(bench: &mut test::Bencher) -{ +fn into_dimensionality_dyn_to_dyn(bench: &mut test::Bencher) { let a = Array::::zeros(IxDyn(&[10, 10, 10])); let a = a.view(); bench.iter(|| a.clone().into_dimensionality::()); } #[bench] -fn into_dyn_ix3(bench: &mut test::Bencher) -{ +fn into_dyn_ix3(bench: &mut test::Bencher) { let a = Array::::zeros(Ix3(10, 10, 10)); let a = a.view(); bench.iter(|| a.into_dyn()); } #[bench] -fn into_dyn_ix5(bench: &mut test::Bencher) -{ +fn into_dyn_ix5(bench: &mut test::Bencher) { let a = Array::::zeros(Ix5(2, 2, 2, 2, 2)); let a = a.view(); bench.iter(|| a.into_dyn()); } #[bench] -fn into_dyn_dyn(bench: &mut test::Bencher) -{ +fn into_dyn_dyn(bench: &mut test::Bencher) { let a = Array::::zeros(IxDyn(&[10, 10, 10])); let a = a.view(); bench.iter(|| a.clone().into_dyn()); } #[bench] -fn broadcast_same_dim(bench: &mut test::Bencher) -{ +fn broadcast_same_dim(bench: &mut test::Bencher) { let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; let s = Array4::from_shape_vec((2, 2, 3, 2), s.to_vec()).unwrap(); let a = s.slice(s![.., ..;-1, ..;2, ..]); @@ -1096,8 +996,7 @@ fn broadcast_same_dim(bench: &mut test::Bencher) } #[bench] -fn broadcast_one_side(bench: &mut test::Bencher) -{ +fn broadcast_one_side(bench: &mut test::Bencher) { let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; let s2 = [1, 2, 3, 4, 5, 6]; let a = Array4::from_shape_vec((4, 1, 3, 2), s.to_vec()).unwrap(); diff --git a/benches/chunks.rs b/benches/chunks.rs index 46780492..5ea9ba46 100644 --- a/benches/chunks.rs +++ b/benches/chunks.rs @@ -7,8 +7,7 @@ use ndarray::prelude::*; use ndarray::NdProducer; #[bench] -fn chunk2x2_iter_sum(bench: &mut Bencher) -{ +fn chunk2x2_iter_sum(bench: &mut Bencher) { let a = Array::::zeros((256, 256)); let chunksz = (2, 2); let mut sum = Array::zeros(a.exact_chunks(chunksz).raw_dim()); @@ -20,8 +19,7 @@ fn chunk2x2_iter_sum(bench: &mut Bencher) } #[bench] -fn chunk2x2_sum(bench: &mut Bencher) -{ +fn chunk2x2_sum(bench: &mut Bencher) { let a = Array::::zeros((256, 256)); let chunksz = (2, 2); let mut sum = 
Array::zeros(a.exact_chunks(chunksz).raw_dim()); @@ -33,8 +31,7 @@ fn chunk2x2_sum(bench: &mut Bencher) } #[bench] -fn chunk2x2_sum_get1(bench: &mut Bencher) -{ +fn chunk2x2_sum_get1(bench: &mut Bencher) { let a = Array::::zeros((256, 256)); let chunksz = (2, 2); let mut sum = Array::::zeros(a.exact_chunks(chunksz).raw_dim()); @@ -49,8 +46,7 @@ fn chunk2x2_sum_get1(bench: &mut Bencher) } #[bench] -fn chunk2x2_sum_uget1(bench: &mut Bencher) -{ +fn chunk2x2_sum_uget1(bench: &mut Bencher) { let a = Array::::zeros((256, 256)); let chunksz = (2, 2); let mut sum = Array::::zeros(a.exact_chunks(chunksz).raw_dim()); @@ -68,8 +64,7 @@ fn chunk2x2_sum_uget1(bench: &mut Bencher) #[bench] #[allow(clippy::identity_op)] -fn chunk2x2_sum_get2(bench: &mut Bencher) -{ +fn chunk2x2_sum_get2(bench: &mut Bencher) { let a = Array::::zeros((256, 256)); let chunksz = (2, 2); let mut sum = Array::::zeros(a.exact_chunks(chunksz).raw_dim()); diff --git a/benches/construct.rs b/benches/construct.rs index 958eaa3b..b06983b2 100644 --- a/benches/construct.rs +++ b/benches/construct.rs @@ -6,21 +6,18 @@ use test::Bencher; use ndarray::prelude::*; #[bench] -fn default_f64(bench: &mut Bencher) -{ +fn default_f64(bench: &mut Bencher) { bench.iter(|| Array::::default((128, 128))) } #[bench] -fn zeros_f64(bench: &mut Bencher) -{ +fn zeros_f64(bench: &mut Bencher) { bench.iter(|| Array::::zeros((128, 128))) } #[cfg(feature = "std")] #[bench] -fn map_regular(bench: &mut test::Bencher) -{ +fn map_regular(bench: &mut test::Bencher) { let a = Array::linspace(0.0..=127.0, 128) .into_shape_with_order((8, 16)) .unwrap(); @@ -29,8 +26,7 @@ fn map_regular(bench: &mut test::Bencher) #[cfg(feature = "std")] #[bench] -fn map_stride(bench: &mut test::Bencher) -{ +fn map_stride(bench: &mut test::Bencher) { let a = Array::linspace(0.0..=127.0, 256) .into_shape_with_order((8, 32)) .unwrap(); diff --git a/benches/gemv_gemm.rs b/benches/gemv_gemm.rs index ccd98725..0569c2ec 100644 --- a/benches/gemv_gemm.rs +++ b/benches/gemv_gemm.rs @@ -14,8 +14,7 @@ use ndarray::linalg::general_mat_vec_mul; use ndarray::LinalgScalar; #[bench] -fn gemv_64_64c(bench: &mut Bencher) -{ +fn gemv_64_64c(bench: &mut Bencher) { let a = Array::zeros((64, 64)); let (m, n) = a.dim(); let x = Array::zeros(n); @@ -26,8 +25,7 @@ fn gemv_64_64c(bench: &mut Bencher) } #[bench] -fn gemv_64_64f(bench: &mut Bencher) -{ +fn gemv_64_64f(bench: &mut Bencher) { let a = Array::zeros((64, 64).f()); let (m, n) = a.dim(); let x = Array::zeros(n); @@ -38,8 +36,7 @@ fn gemv_64_64f(bench: &mut Bencher) } #[bench] -fn gemv_64_32(bench: &mut Bencher) -{ +fn gemv_64_32(bench: &mut Bencher) { let a = Array::zeros((64, 32)); let (m, n) = a.dim(); let x = Array::zeros(n); @@ -50,19 +47,18 @@ fn gemv_64_32(bench: &mut Bencher) } #[bench] -fn cgemm_100(bench: &mut Bencher) -{ +fn cgemm_100(bench: &mut Bencher) { cgemm_bench::(100, bench); } #[bench] -fn zgemm_100(bench: &mut Bencher) -{ +fn zgemm_100(bench: &mut Bencher) { cgemm_bench::(100, bench); } fn cgemm_bench(size: usize, bench: &mut Bencher) -where A: LinalgScalar + Float +where + A: LinalgScalar + Float, { let (m, k, n) = (size, size, size); let a = Array::, _>::zeros((m, k)); diff --git a/benches/higher-order.rs b/benches/higher-order.rs index 6356687f..f2743101 100644 --- a/benches/higher-order.rs +++ b/benches/higher-order.rs @@ -12,23 +12,20 @@ const Y: usize = 16; #[cfg(feature = "std")] #[bench] -fn map_regular(bench: &mut Bencher) -{ +fn map_regular(bench: &mut Bencher) { let a = Array::linspace(0.0..=127.0, N) 
.into_shape_with_order((X, Y)) .unwrap(); bench.iter(|| a.map(|&x| 2. * x)); } -pub fn double_array(mut a: ArrayViewMut2<'_, f64>) -{ +pub fn double_array(mut a: ArrayViewMut2<'_, f64>) { a *= 2.0; } #[cfg(feature = "std")] #[bench] -fn map_stride_double_f64(bench: &mut Bencher) -{ +fn map_stride_double_f64(bench: &mut Bencher) { let mut a = Array::linspace(0.0..=127.0, N * 2) .into_shape_with_order([X, Y * 2]) .unwrap(); @@ -40,8 +37,7 @@ fn map_stride_double_f64(bench: &mut Bencher) #[cfg(feature = "std")] #[bench] -fn map_stride_f64(bench: &mut Bencher) -{ +fn map_stride_f64(bench: &mut Bencher) { let a = Array::linspace(0.0..=127.0, N * 2) .into_shape_with_order([X, Y * 2]) .unwrap(); @@ -51,8 +47,7 @@ fn map_stride_f64(bench: &mut Bencher) #[cfg(feature = "std")] #[bench] -fn map_stride_u32(bench: &mut Bencher) -{ +fn map_stride_u32(bench: &mut Bencher) { let a = Array::linspace(0.0..=127.0, N * 2) .into_shape_with_order([X, Y * 2]) .unwrap(); @@ -63,8 +58,7 @@ fn map_stride_u32(bench: &mut Bencher) #[cfg(feature = "std")] #[bench] -fn fold_axis(bench: &mut Bencher) -{ +fn fold_axis(bench: &mut Bencher) { let a = Array::linspace(0.0..=127.0, N * 2) .into_shape_with_order([X, Y * 2]) .unwrap(); @@ -75,8 +69,7 @@ const MA: usize = 64; const MASZ: usize = MA * MA; #[bench] -fn map_axis_0(bench: &mut Bencher) -{ +fn map_axis_0(bench: &mut Bencher) { let a = Array::from_iter(0..MASZ as i32) .into_shape_with_order([MA, MA]) .unwrap(); @@ -84,8 +77,7 @@ fn map_axis_0(bench: &mut Bencher) } #[bench] -fn map_axis_1(bench: &mut Bencher) -{ +fn map_axis_1(bench: &mut Bencher) { let a = Array::from_iter(0..MASZ as i32) .into_shape_with_order([MA, MA]) .unwrap(); diff --git a/benches/iter.rs b/benches/iter.rs index 0e18f123..da18c0d6 100644 --- a/benches/iter.rs +++ b/benches/iter.rs @@ -11,15 +11,13 @@ use ndarray::Slice; use ndarray::{FoldWhile, Zip}; #[bench] -fn iter_sum_2d_regular(bench: &mut Bencher) -{ +fn iter_sum_2d_regular(bench: &mut Bencher) { let a = Array::::zeros((64, 64)); bench.iter(|| a.iter().sum::()); } #[bench] -fn iter_sum_2d_cutout(bench: &mut Bencher) -{ +fn iter_sum_2d_cutout(bench: &mut Bencher) { let a = Array::::zeros((66, 66)); let av = a.slice(s![1..-1, 1..-1]); let a = av; @@ -27,8 +25,7 @@ fn iter_sum_2d_cutout(bench: &mut Bencher) } #[bench] -fn iter_all_2d_cutout(bench: &mut Bencher) -{ +fn iter_all_2d_cutout(bench: &mut Bencher) { let a = Array::::zeros((66, 66)); let av = a.slice(s![1..-1, 1..-1]); let a = av; @@ -36,8 +33,7 @@ fn iter_all_2d_cutout(bench: &mut Bencher) } #[bench] -fn iter_sum_2d_transpose(bench: &mut Bencher) -{ +fn iter_sum_2d_transpose(bench: &mut Bencher) { let a = Array::::zeros((66, 66)); let a = a.t(); bench.iter(|| a.iter().sum::()); @@ -45,8 +41,7 @@ fn iter_sum_2d_transpose(bench: &mut Bencher) #[cfg(feature = "std")] #[bench] -fn iter_filter_sum_2d_u32(bench: &mut Bencher) -{ +fn iter_filter_sum_2d_u32(bench: &mut Bencher) { let a = Array::linspace(0.0..=1.0, 256) .into_shape_with_order((16, 16)) .unwrap(); @@ -56,8 +51,7 @@ fn iter_filter_sum_2d_u32(bench: &mut Bencher) #[cfg(feature = "std")] #[bench] -fn iter_filter_sum_2d_f32(bench: &mut Bencher) -{ +fn iter_filter_sum_2d_f32(bench: &mut Bencher) { let a = Array::linspace(0.0..=1.0, 256) .into_shape_with_order((16, 16)) .unwrap(); @@ -67,8 +61,7 @@ fn iter_filter_sum_2d_f32(bench: &mut Bencher) #[cfg(feature = "std")] #[bench] -fn iter_filter_sum_2d_stride_u32(bench: &mut Bencher) -{ +fn iter_filter_sum_2d_stride_u32(bench: &mut Bencher) { let a = Array::linspace(0.0..=1.0, 256) 
.into_shape_with_order((16, 16)) .unwrap(); @@ -79,8 +72,7 @@ fn iter_filter_sum_2d_stride_u32(bench: &mut Bencher) #[cfg(feature = "std")] #[bench] -fn iter_filter_sum_2d_stride_f32(bench: &mut Bencher) -{ +fn iter_filter_sum_2d_stride_f32(bench: &mut Bencher) { let a = Array::linspace(0.0..=1.0, 256) .into_shape_with_order((16, 16)) .unwrap(); @@ -91,8 +83,7 @@ fn iter_filter_sum_2d_stride_f32(bench: &mut Bencher) #[cfg(feature = "std")] #[bench] -fn iter_rev_step_by_contiguous(bench: &mut Bencher) -{ +fn iter_rev_step_by_contiguous(bench: &mut Bencher) { let a = Array::linspace(0.0..=1.0, 512); bench.iter(|| { a.iter().rev().step_by(2).for_each(|x| { @@ -103,8 +94,7 @@ fn iter_rev_step_by_contiguous(bench: &mut Bencher) #[cfg(feature = "std")] #[bench] -fn iter_rev_step_by_discontiguous(bench: &mut Bencher) -{ +fn iter_rev_step_by_discontiguous(bench: &mut Bencher) { let mut a = Array::linspace(0.0..=1.0, 1024); a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); bench.iter(|| { @@ -117,8 +107,7 @@ fn iter_rev_step_by_discontiguous(bench: &mut Bencher) const ZIPSZ: usize = 10_000; #[bench] -fn sum_3_std_zip1(bench: &mut Bencher) -{ +fn sum_3_std_zip1(bench: &mut Bencher) { let a = vec![1; ZIPSZ]; let b = vec![1; ZIPSZ]; let c = vec![1; ZIPSZ]; @@ -130,8 +119,7 @@ fn sum_3_std_zip1(bench: &mut Bencher) } #[bench] -fn sum_3_std_zip2(bench: &mut Bencher) -{ +fn sum_3_std_zip2(bench: &mut Bencher) { let a = vec![1; ZIPSZ]; let b = vec![1; ZIPSZ]; let c = vec![1; ZIPSZ]; @@ -144,8 +132,7 @@ fn sum_3_std_zip2(bench: &mut Bencher) } #[bench] -fn sum_3_std_zip3(bench: &mut Bencher) -{ +fn sum_3_std_zip3(bench: &mut Bencher) { let a = vec![1; ZIPSZ]; let b = vec![1; ZIPSZ]; let c = vec![1; ZIPSZ]; @@ -159,8 +146,7 @@ fn sum_3_std_zip3(bench: &mut Bencher) } #[bench] -fn vector_sum_3_std_zip(bench: &mut Bencher) -{ +fn vector_sum_3_std_zip(bench: &mut Bencher) { let a = vec![1.; ZIPSZ]; let b = vec![1.; ZIPSZ]; let mut c = vec![1.; ZIPSZ]; @@ -172,8 +158,7 @@ fn vector_sum_3_std_zip(bench: &mut Bencher) } #[bench] -fn sum_3_azip(bench: &mut Bencher) -{ +fn sum_3_azip(bench: &mut Bencher) { let a = vec![1; ZIPSZ]; let b = vec![1; ZIPSZ]; let c = vec![1; ZIPSZ]; @@ -187,8 +172,7 @@ fn sum_3_azip(bench: &mut Bencher) } #[bench] -fn sum_3_azip_fold(bench: &mut Bencher) -{ +fn sum_3_azip_fold(bench: &mut Bencher) { let a = vec![1; ZIPSZ]; let b = vec![1; ZIPSZ]; let c = vec![1; ZIPSZ]; @@ -202,8 +186,7 @@ fn sum_3_azip_fold(bench: &mut Bencher) } #[bench] -fn vector_sum_3_azip(bench: &mut Bencher) -{ +fn vector_sum_3_azip(bench: &mut Bencher) { let a = vec![1.; ZIPSZ]; let b = vec![1.; ZIPSZ]; let mut c = vec![1.; ZIPSZ]; @@ -214,8 +197,7 @@ fn vector_sum_3_azip(bench: &mut Bencher) }); } -fn vector_sum3_unchecked(a: &[f64], b: &[f64], c: &mut [f64]) -{ +fn vector_sum3_unchecked(a: &[f64], b: &[f64], c: &mut [f64]) { for i in 0..c.len() { unsafe { *c.get_unchecked_mut(i) += *a.get_unchecked(i) + *b.get_unchecked(i); @@ -224,8 +206,7 @@ fn vector_sum3_unchecked(a: &[f64], b: &[f64], c: &mut [f64]) } #[bench] -fn vector_sum_3_zip_unchecked(bench: &mut Bencher) -{ +fn vector_sum_3_zip_unchecked(bench: &mut Bencher) { let a = vec![1.; ZIPSZ]; let b = vec![1.; ZIPSZ]; let mut c = vec![1.; ZIPSZ]; @@ -235,8 +216,7 @@ fn vector_sum_3_zip_unchecked(bench: &mut Bencher) } #[bench] -fn vector_sum_3_zip_unchecked_manual(bench: &mut Bencher) -{ +fn vector_sum_3_zip_unchecked_manual(bench: &mut Bencher) { let a = vec![1.; ZIPSZ]; let b = vec![1.; ZIPSZ]; let mut c = vec![1.; ZIPSZ]; @@ -256,8 +236,7 @@ const 
ISZ: usize = 16; const I2DSZ: usize = 64; #[bench] -fn indexed_iter_1d_ix1(bench: &mut Bencher) -{ +fn indexed_iter_1d_ix1(bench: &mut Bencher) { let mut a = Array::::zeros(I2DSZ * I2DSZ); for (i, elt) in a.indexed_iter_mut() { *elt = i as _; @@ -272,8 +251,7 @@ fn indexed_iter_1d_ix1(bench: &mut Bencher) } #[bench] -fn indexed_zip_1d_ix1(bench: &mut Bencher) -{ +fn indexed_zip_1d_ix1(bench: &mut Bencher) { let mut a = Array::::zeros(I2DSZ * I2DSZ); for (i, elt) in a.indexed_iter_mut() { *elt = i as _; @@ -288,8 +266,7 @@ fn indexed_zip_1d_ix1(bench: &mut Bencher) } #[bench] -fn indexed_iter_2d_ix2(bench: &mut Bencher) -{ +fn indexed_iter_2d_ix2(bench: &mut Bencher) { let mut a = Array::::zeros((I2DSZ, I2DSZ)); for ((i, j), elt) in a.indexed_iter_mut() { *elt = (i + 100 * j) as _; @@ -303,8 +280,7 @@ fn indexed_iter_2d_ix2(bench: &mut Bencher) }) } #[bench] -fn indexed_zip_2d_ix2(bench: &mut Bencher) -{ +fn indexed_zip_2d_ix2(bench: &mut Bencher) { let mut a = Array::::zeros((I2DSZ, I2DSZ)); for ((i, j), elt) in a.indexed_iter_mut() { *elt = (i + 100 * j) as _; @@ -319,8 +295,7 @@ fn indexed_zip_2d_ix2(bench: &mut Bencher) } #[bench] -fn indexed_iter_3d_ix3(bench: &mut Bencher) -{ +fn indexed_iter_3d_ix3(bench: &mut Bencher) { let mut a = Array::::zeros((ISZ, ISZ, ISZ)); for ((i, j, k), elt) in a.indexed_iter_mut() { *elt = (i + 100 * j + 10000 * k) as _; @@ -335,8 +310,7 @@ fn indexed_iter_3d_ix3(bench: &mut Bencher) } #[bench] -fn indexed_zip_3d_ix3(bench: &mut Bencher) -{ +fn indexed_zip_3d_ix3(bench: &mut Bencher) { let mut a = Array::::zeros((ISZ, ISZ, ISZ)); for ((i, j, k), elt) in a.indexed_iter_mut() { *elt = (i + 100 * j + 10000 * k) as _; @@ -351,8 +325,7 @@ fn indexed_zip_3d_ix3(bench: &mut Bencher) } #[bench] -fn indexed_iter_3d_dyn(bench: &mut Bencher) -{ +fn indexed_iter_3d_dyn(bench: &mut Bencher) { let mut a = Array::::zeros((ISZ, ISZ, ISZ)); for ((i, j, k), elt) in a.indexed_iter_mut() { *elt = (i + 100 * j + 10000 * k) as _; @@ -368,31 +341,27 @@ fn indexed_iter_3d_dyn(bench: &mut Bencher) } #[bench] -fn iter_sum_1d_strided_fold(bench: &mut Bencher) -{ +fn iter_sum_1d_strided_fold(bench: &mut Bencher) { let mut a = Array::::ones(10240); a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); bench.iter(|| a.iter().sum::()); } #[bench] -fn iter_sum_1d_strided_rfold(bench: &mut Bencher) -{ +fn iter_sum_1d_strided_rfold(bench: &mut Bencher) { let mut a = Array::::ones(10240); a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); bench.iter(|| a.iter().rfold(0, |acc, &x| acc + x)); } #[bench] -fn iter_axis_iter_sum(bench: &mut Bencher) -{ +fn iter_axis_iter_sum(bench: &mut Bencher) { let a = Array::::zeros((64, 64)); bench.iter(|| a.axis_iter(Axis(0)).map(|plane| plane.sum()).sum::()); } #[bench] -fn iter_axis_chunks_1_iter_sum(bench: &mut Bencher) -{ +fn iter_axis_chunks_1_iter_sum(bench: &mut Bencher) { let a = Array::::zeros((64, 64)); bench.iter(|| { a.axis_chunks_iter(Axis(0), 1) @@ -402,8 +371,7 @@ fn iter_axis_chunks_1_iter_sum(bench: &mut Bencher) } #[bench] -fn iter_axis_chunks_5_iter_sum(bench: &mut Bencher) -{ +fn iter_axis_chunks_5_iter_sum(bench: &mut Bencher) { let a = Array::::zeros((64, 64)); bench.iter(|| { a.axis_chunks_iter(Axis(0), 5) @@ -412,24 +380,21 @@ fn iter_axis_chunks_5_iter_sum(bench: &mut Bencher) }); } -pub fn zip_mut_with(data: &Array3, out: &mut Array3) -{ +pub fn zip_mut_with(data: &Array3, out: &mut Array3) { out.zip_mut_with(data, |o, &i| { *o = i; }); } #[bench] -fn zip_mut_with_cc(b: &mut Bencher) -{ +fn zip_mut_with_cc(b: &mut Bencher) { 
let data: Array3 = Array3::zeros((ISZ, ISZ, ISZ)); let mut out = Array3::zeros(data.dim()); b.iter(|| zip_mut_with(&data, &mut out)); } #[bench] -fn zip_mut_with_ff(b: &mut Bencher) -{ +fn zip_mut_with_ff(b: &mut Bencher) { let data: Array3 = Array3::zeros((ISZ, ISZ, ISZ).f()); let mut out = Array3::zeros(data.dim().f()); b.iter(|| zip_mut_with(&data, &mut out)); diff --git a/benches/numeric.rs b/benches/numeric.rs index 5dcde52d..ea6e58e6 100644 --- a/benches/numeric.rs +++ b/benches/numeric.rs @@ -11,8 +11,7 @@ const Y: usize = 16; #[cfg(feature = "std")] #[bench] -fn clip(bench: &mut Bencher) -{ +fn clip(bench: &mut Bencher) { let mut a = Array::linspace(0.0..=127.0, N * 2) .into_shape_with_order([X, Y * 2]) .unwrap(); diff --git a/benches/par_rayon.rs b/benches/par_rayon.rs index 95b51427..393bc215 100644 --- a/benches/par_rayon.rs +++ b/benches/par_rayon.rs @@ -12,8 +12,7 @@ use ndarray::Zip; const EXP_N: usize = 256; const ADDN: usize = 512; -fn set_threads() -{ +fn set_threads() { // Consider setting a fixed number of threads here, for example to avoid // oversubscribing on hyperthreaded cores. // let n = 4; @@ -21,8 +20,7 @@ fn set_threads() } #[bench] -fn map_exp_regular(bench: &mut Bencher) -{ +fn map_exp_regular(bench: &mut Bencher) { let mut a = Array2::::zeros((EXP_N, EXP_N)); a.swap_axes(0, 1); bench.iter(|| { @@ -31,8 +29,7 @@ fn map_exp_regular(bench: &mut Bencher) } #[bench] -fn rayon_exp_regular(bench: &mut Bencher) -{ +fn rayon_exp_regular(bench: &mut Bencher) { set_threads(); let mut a = Array2::::zeros((EXP_N, EXP_N)); a.swap_axes(0, 1); @@ -44,22 +41,19 @@ fn rayon_exp_regular(bench: &mut Bencher) const FASTEXP: usize = EXP_N; #[inline] -fn fastexp(x: f64) -> f64 -{ +fn fastexp(x: f64) -> f64 { let x = 1. + x / 1024.; x.powi(1024) } #[bench] -fn map_fastexp_regular(bench: &mut Bencher) -{ +fn map_fastexp_regular(bench: &mut Bencher) { let mut a = Array2::::zeros((FASTEXP, FASTEXP)); bench.iter(|| a.mapv_inplace(fastexp)); } #[bench] -fn rayon_fastexp_regular(bench: &mut Bencher) -{ +fn rayon_fastexp_regular(bench: &mut Bencher) { set_threads(); let mut a = Array2::::zeros((FASTEXP, FASTEXP)); bench.iter(|| { @@ -68,16 +62,14 @@ fn rayon_fastexp_regular(bench: &mut Bencher) } #[bench] -fn map_fastexp_cut(bench: &mut Bencher) -{ +fn map_fastexp_cut(bench: &mut Bencher) { let mut a = Array2::::zeros((FASTEXP, FASTEXP)); let mut a = a.slice_mut(s![.., ..-1]); bench.iter(|| a.mapv_inplace(fastexp)); } #[bench] -fn rayon_fastexp_cut(bench: &mut Bencher) -{ +fn rayon_fastexp_cut(bench: &mut Bencher) { set_threads(); let mut a = Array2::::zeros((FASTEXP, FASTEXP)); let mut a = a.slice_mut(s![.., ..-1]); @@ -87,8 +79,7 @@ fn rayon_fastexp_cut(bench: &mut Bencher) } #[bench] -fn map_fastexp_by_axis(bench: &mut Bencher) -{ +fn map_fastexp_by_axis(bench: &mut Bencher) { let mut a = Array2::::zeros((FASTEXP, FASTEXP)); bench.iter(|| { for mut sheet in a.axis_iter_mut(Axis(0)) { @@ -98,8 +89,7 @@ fn map_fastexp_by_axis(bench: &mut Bencher) } #[bench] -fn rayon_fastexp_by_axis(bench: &mut Bencher) -{ +fn rayon_fastexp_by_axis(bench: &mut Bencher) { set_threads(); let mut a = Array2::::zeros((FASTEXP, FASTEXP)); bench.iter(|| { @@ -110,8 +100,7 @@ fn rayon_fastexp_by_axis(bench: &mut Bencher) } #[bench] -fn rayon_fastexp_zip(bench: &mut Bencher) -{ +fn rayon_fastexp_zip(bench: &mut Bencher) { set_threads(); let mut a = Array2::::zeros((FASTEXP, FASTEXP)); bench.iter(|| { @@ -122,8 +111,7 @@ fn rayon_fastexp_zip(bench: &mut Bencher) } #[bench] -fn add(bench: &mut Bencher) -{ +fn 
add(bench: &mut Bencher) { let mut a = Array2::::zeros((ADDN, ADDN)); let b = Array2::::zeros((ADDN, ADDN)); let c = Array2::::zeros((ADDN, ADDN)); @@ -136,8 +124,7 @@ fn add(bench: &mut Bencher) } #[bench] -fn rayon_add(bench: &mut Bencher) -{ +fn rayon_add(bench: &mut Bencher) { set_threads(); let mut a = Array2::::zeros((ADDN, ADDN)); let b = Array2::::zeros((ADDN, ADDN)); @@ -154,29 +141,25 @@ const COLL_STRING_N: usize = 64; const COLL_F64_N: usize = 128; #[bench] -fn vec_string_collect(bench: &mut test::Bencher) -{ +fn vec_string_collect(bench: &mut test::Bencher) { let v = vec![""; COLL_STRING_N * COLL_STRING_N]; bench.iter(|| v.iter().map(|s| s.to_owned()).collect::>()); } #[bench] -fn array_string_collect(bench: &mut test::Bencher) -{ +fn array_string_collect(bench: &mut test::Bencher) { let v = Array::from_elem((COLL_STRING_N, COLL_STRING_N), ""); bench.iter(|| Zip::from(&v).par_map_collect(|s| s.to_owned())); } #[bench] -fn vec_f64_collect(bench: &mut test::Bencher) -{ +fn vec_f64_collect(bench: &mut test::Bencher) { let v = vec![1.; COLL_F64_N * COLL_F64_N]; bench.iter(|| v.iter().map(|s| s + 1.).collect::>()); } #[bench] -fn array_f64_collect(bench: &mut test::Bencher) -{ +fn array_f64_collect(bench: &mut test::Bencher) { let v = Array::from_elem((COLL_F64_N, COLL_F64_N), 1.); bench.iter(|| Zip::from(&v).par_map_collect(|s| s + 1.)); } diff --git a/benches/reserve.rs b/benches/reserve.rs index 14ebf9f1..422d1283 100644 --- a/benches/reserve.rs +++ b/benches/reserve.rs @@ -6,8 +6,7 @@ use test::Bencher; use ndarray::prelude::*; #[bench] -fn push_reserve(bench: &mut Bencher) -{ +fn push_reserve(bench: &mut Bencher) { let ones: Array = array![1f32]; bench.iter(|| { let mut a: Array = array![]; @@ -19,8 +18,7 @@ fn push_reserve(bench: &mut Bencher) } #[bench] -fn push_no_reserve(bench: &mut Bencher) -{ +fn push_no_reserve(bench: &mut Bencher) { let ones: Array = array![1f32]; bench.iter(|| { let mut a: Array = array![]; diff --git a/benches/to_shape.rs b/benches/to_shape.rs index f056a985..7c9f9144 100644 --- a/benches/to_shape.rs +++ b/benches/to_shape.rs @@ -7,88 +7,77 @@ use ndarray::prelude::*; use ndarray::Order; #[bench] -fn to_shape2_1(bench: &mut Bencher) -{ +fn to_shape2_1(bench: &mut Bencher) { let a = Array::::zeros((4, 5)); let view = a.view(); bench.iter(|| view.to_shape(4 * 5).unwrap()); } #[bench] -fn to_shape2_2_same(bench: &mut Bencher) -{ +fn to_shape2_2_same(bench: &mut Bencher) { let a = Array::::zeros((4, 5)); let view = a.view(); bench.iter(|| view.to_shape((4, 5)).unwrap()); } #[bench] -fn to_shape2_2_flip(bench: &mut Bencher) -{ +fn to_shape2_2_flip(bench: &mut Bencher) { let a = Array::::zeros((4, 5)); let view = a.view(); bench.iter(|| view.to_shape((5, 4)).unwrap()); } #[bench] -fn to_shape2_3(bench: &mut Bencher) -{ +fn to_shape2_3(bench: &mut Bencher) { let a = Array::::zeros((4, 5)); let view = a.view(); bench.iter(|| view.to_shape((2, 5, 2)).unwrap()); } #[bench] -fn to_shape3_1(bench: &mut Bencher) -{ +fn to_shape3_1(bench: &mut Bencher) { let a = Array::::zeros((3, 4, 5)); let view = a.view(); bench.iter(|| view.to_shape(3 * 4 * 5).unwrap()); } #[bench] -fn to_shape3_2_order(bench: &mut Bencher) -{ +fn to_shape3_2_order(bench: &mut Bencher) { let a = Array::::zeros((3, 4, 5)); let view = a.view(); bench.iter(|| view.to_shape((12, 5)).unwrap()); } #[bench] -fn to_shape3_2_outoforder(bench: &mut Bencher) -{ +fn to_shape3_2_outoforder(bench: &mut Bencher) { let a = Array::::zeros((3, 4, 5)); let view = a.view(); bench.iter(|| view.to_shape((4, 
15)).unwrap()); } #[bench] -fn to_shape3_3c(bench: &mut Bencher) -{ +fn to_shape3_3c(bench: &mut Bencher) { let a = Array::::zeros((3, 4, 5)); let view = a.view(); bench.iter(|| view.to_shape((3, 4, 5)).unwrap()); } #[bench] -fn to_shape3_3f(bench: &mut Bencher) -{ +fn to_shape3_3f(bench: &mut Bencher) { let a = Array::::zeros((3, 4, 5).f()); let view = a.view(); bench.iter(|| view.to_shape(((3, 4, 5), Order::F)).unwrap()); } #[bench] -fn to_shape3_4c(bench: &mut Bencher) -{ +fn to_shape3_4c(bench: &mut Bencher) { let a = Array::::zeros((3, 4, 5)); let view = a.view(); bench.iter(|| view.to_shape(((2, 3, 2, 5), Order::C)).unwrap()); } #[bench] -fn to_shape3_4f(bench: &mut Bencher) -{ +fn to_shape3_4f(bench: &mut Bencher) { let a = Array::::zeros((3, 4, 5).f()); let view = a.view(); bench.iter(|| view.to_shape(((2, 3, 2, 5), Order::F)).unwrap()); diff --git a/benches/zip.rs b/benches/zip.rs index 46149731..1194e450 100644 --- a/benches/zip.rs +++ b/benches/zip.rs @@ -33,8 +33,7 @@ where z22.for_each(f); } -pub fn zip_indexed(data: &Array3, out: &mut Array3) -{ +pub fn zip_indexed(data: &Array3, out: &mut Array3) { Zip::indexed(data).and(out).for_each(|idx, &i, o| { let _ = black_box(idx); *o = i; @@ -45,56 +44,49 @@ pub fn zip_indexed(data: &Array3, out: &mut Array3) const SZ3: (usize, usize, usize) = (100, 110, 100); #[bench] -fn zip_cc(b: &mut Bencher) -{ +fn zip_cc(b: &mut Bencher) { let data: Array3 = Array3::zeros(SZ3); let mut out = Array3::zeros(data.dim()); b.iter(|| zip_copy(&data, &mut out)); } #[bench] -fn zip_cf(b: &mut Bencher) -{ +fn zip_cf(b: &mut Bencher) { let data: Array3 = Array3::zeros(SZ3); let mut out = Array3::zeros(data.dim().f()); b.iter(|| zip_copy(&data, &mut out)); } #[bench] -fn zip_fc(b: &mut Bencher) -{ +fn zip_fc(b: &mut Bencher) { let data: Array3 = Array3::zeros(SZ3.f()); let mut out = Array3::zeros(data.dim()); b.iter(|| zip_copy(&data, &mut out)); } #[bench] -fn zip_ff(b: &mut Bencher) -{ +fn zip_ff(b: &mut Bencher) { let data: Array3 = Array3::zeros(SZ3.f()); let mut out = Array3::zeros(data.dim().f()); b.iter(|| zip_copy(&data, &mut out)); } #[bench] -fn zip_indexed_cc(b: &mut Bencher) -{ +fn zip_indexed_cc(b: &mut Bencher) { let data: Array3 = Array3::zeros(SZ3); let mut out = Array3::zeros(data.dim()); b.iter(|| zip_indexed(&data, &mut out)); } #[bench] -fn zip_indexed_ff(b: &mut Bencher) -{ +fn zip_indexed_ff(b: &mut Bencher) { let data: Array3 = Array3::zeros(SZ3.f()); let mut out = Array3::zeros(data.dim().f()); b.iter(|| zip_indexed(&data, &mut out)); } #[bench] -fn slice_zip_cc(b: &mut Bencher) -{ +fn slice_zip_cc(b: &mut Bencher) { let data: Array3 = Array3::zeros(SZ3); let mut out = Array3::zeros(data.dim()); let data = data.slice(s![1.., 1.., 1..]); @@ -103,8 +95,7 @@ fn slice_zip_cc(b: &mut Bencher) } #[bench] -fn slice_zip_ff(b: &mut Bencher) -{ +fn slice_zip_ff(b: &mut Bencher) { let data: Array3 = Array3::zeros(SZ3.f()); let mut out = Array3::zeros(data.dim().f()); let data = data.slice(s![1.., 1.., 1..]); @@ -113,8 +104,7 @@ fn slice_zip_ff(b: &mut Bencher) } #[bench] -fn slice_split_zip_cc(b: &mut Bencher) -{ +fn slice_split_zip_cc(b: &mut Bencher) { let data: Array3 = Array3::zeros(SZ3); let mut out = Array3::zeros(data.dim()); let data = data.slice(s![1.., 1.., 1..]); @@ -123,8 +113,7 @@ fn slice_split_zip_cc(b: &mut Bencher) } #[bench] -fn slice_split_zip_ff(b: &mut Bencher) -{ +fn slice_split_zip_ff(b: &mut Bencher) { let data: Array3 = Array3::zeros(SZ3.f()); let mut out = Array3::zeros(data.dim().f()); let data = data.slice(s![1.., 
1.., 1..]); diff --git a/crates/blas-mock-tests/tests/use-blas.rs b/crates/blas-mock-tests/tests/use-blas.rs index 217508af..a259c515 100644 --- a/crates/blas-mock-tests/tests/use-blas.rs +++ b/crates/blas-mock-tests/tests/use-blas.rs @@ -10,8 +10,7 @@ use ndarray_gen::array_builder::ArrayBuilder; use itertools::iproduct; #[test] -fn test_gen_mat_mul_uses_blas() -{ +fn test_gen_mat_mul_uses_blas() { let alpha = 1.0; let beta = 0.0; diff --git a/crates/blas-tests/src/lib.rs b/crates/blas-tests/src/lib.rs index fc031eed..ad453c8e 100644 --- a/crates/blas-tests/src/lib.rs +++ b/crates/blas-tests/src/lib.rs @@ -1,4 +1,6 @@ #[cfg(not(feature = "blas-src"))] -compile_error!("Missing backend: could not compile. +compile_error!( + "Missing backend: could not compile. Help: For this testing crate, select one of the blas backend features, for example \ - openblas-system"); + openblas-system" +); diff --git a/crates/blas-tests/tests/dyn.rs b/crates/blas-tests/tests/dyn.rs index 6c0fd975..f7f5b00b 100644 --- a/crates/blas-tests/tests/dyn.rs +++ b/crates/blas-tests/tests/dyn.rs @@ -2,8 +2,7 @@ extern crate blas_src; use ndarray::{linalg::Dot, Array1, Array2, ArrayD, Ix1, Ix2}; #[test] -fn test_arrayd_dot_2d() -{ +fn test_arrayd_dot_2d() { let mat1 = ArrayD::from_shape_vec(vec![3, 2], vec![3.0; 6]).unwrap(); let mat2 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap(); @@ -22,8 +21,7 @@ fn test_arrayd_dot_2d() } #[test] -fn test_arrayd_dot_1d() -{ +fn test_arrayd_dot_1d() { // Test 1D array dot product let vec1 = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap(); let vec2 = ArrayD::from_shape_vec(vec![3], vec![4.0, 5.0, 6.0]).unwrap(); @@ -38,8 +36,7 @@ fn test_arrayd_dot_1d() #[test] #[should_panic(expected = "Dot product for ArrayD is only supported for 1D and 2D arrays")] -fn test_arrayd_dot_3d() -{ +fn test_arrayd_dot_3d() { // Test that 3D arrays are not supported let arr1 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap(); let arr2 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap(); @@ -49,8 +46,7 @@ fn test_arrayd_dot_3d() #[test] #[should_panic(expected = "ndarray: inputs 2 × 3 and 4 × 5 are not compatible for matrix multiplication")] -fn test_arrayd_dot_incompatible_dims() -{ +fn test_arrayd_dot_incompatible_dims() { // Test arrays with incompatible dimensions let arr1 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap(); let arr2 = ArrayD::from_shape_vec(vec![4, 5], vec![1.0; 20]).unwrap(); @@ -59,8 +55,7 @@ fn test_arrayd_dot_incompatible_dims() } #[test] -fn test_arrayd_dot_matrix_vector() -{ +fn test_arrayd_dot_matrix_vector() { // Test matrix-vector multiplication let mat = ArrayD::from_shape_vec(vec![3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(); let vec = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap(); diff --git a/crates/blas-tests/tests/oper.rs b/crates/blas-tests/tests/oper.rs index f604ae09..1ec33d12 100644 --- a/crates/blas-tests/tests/oper.rs +++ b/crates/blas-tests/tests/oper.rs @@ -22,8 +22,7 @@ use num_complex::Complex64; use num_traits::Num; #[test] -fn mat_vec_product_1d() -{ +fn mat_vec_product_1d() { let a = arr2(&[[1.], [2.]]); let b = arr1(&[1., 2.]); let ans = arr1(&[5.]); @@ -31,8 +30,7 @@ fn mat_vec_product_1d() } #[test] -fn mat_vec_product_1d_broadcast() -{ +fn mat_vec_product_1d_broadcast() { let a = arr2(&[[1.], [2.], [3.]]); let b = arr1(&[1.]); let b = b.broadcast(3).unwrap(); @@ -41,8 +39,7 @@ fn mat_vec_product_1d_broadcast() } #[test] -fn mat_vec_product_1d_inverted_axis() -{ +fn 
mat_vec_product_1d_inverted_axis() { let a = arr2(&[[1.], [2.], [3.]]); let mut b = arr1(&[1., 2., 3.]); b.invert_axis(Axis(0)); @@ -51,28 +48,23 @@ fn mat_vec_product_1d_inverted_axis() assert_eq!(a.t().dot(&b), ans); } -fn range_mat(m: Ix, n: Ix) -> Array2 -{ +fn range_mat(m: Ix, n: Ix) -> Array2 { ArrayBuilder::new((m, n)).build() } -fn range_mat_complex(m: Ix, n: Ix) -> Array2 -{ +fn range_mat_complex(m: Ix, n: Ix) -> Array2 { ArrayBuilder::new((m, n)).build() } -fn range_mat_complex64(m: Ix, n: Ix) -> Array2 -{ +fn range_mat_complex64(m: Ix, n: Ix) -> Array2 { ArrayBuilder::new((m, n)).build() } -fn range1_mat64(m: Ix) -> Array1 -{ +fn range1_mat64(m: Ix) -> Array1 { ArrayBuilder::new(m).build() } -fn range_i32(m: Ix, n: Ix) -> Array2 -{ +fn range_i32(m: Ix, n: Ix) -> Array2 { ArrayBuilder::new((m, n)).build() } @@ -145,8 +137,7 @@ where // Check that matrix multiplication of contiguous matrices returns a // matrix with the same order #[test] -fn mat_mul_order() -{ +fn mat_mul_order() { let (m, n, k) = (50, 50, 50); let a = range_mat::(m, n); let b = range_mat::(n, k); @@ -165,8 +156,7 @@ fn mat_mul_order() // Check that matrix multiplication // supports broadcast arrays. #[test] -fn mat_mul_broadcast() -{ +fn mat_mul_broadcast() { let (m, n, k) = (16, 16, 16); let a = range_mat::(m, n); let x1 = 1.; @@ -185,8 +175,7 @@ fn mat_mul_broadcast() // Check that matrix multiplication supports reversed axes #[test] -fn mat_mul_rev() -{ +fn mat_mul_rev() { let (m, n, k) = (16, 16, 16); let a = range_mat::(m, n); let b = range_mat::(n, k); @@ -202,8 +191,7 @@ fn mat_mul_rev() // Check that matrix multiplication supports arrays with zero rows or columns #[test] -fn mat_mut_zero_len() -{ +fn mat_mut_zero_len() { defmac!(mat_mul_zero_len range_mat_fn => { for n in 0..4 { for m in 0..4 { @@ -224,8 +212,7 @@ fn mat_mut_zero_len() } #[test] -fn gen_mat_mul() -{ +fn gen_mat_mul() { let alpha = -2.3; let beta = 3.14; let sizes = vec![ @@ -293,8 +280,7 @@ fn gen_mat_mul() // Test y = A x where A is f-order #[test] -fn gemm_64_1_f() -{ +fn gemm_64_1_f() { let a = range_mat::(64, 64).reversed_axes(); let (m, n) = a.dim(); // m x n times n x 1 == m x 1 @@ -306,8 +292,7 @@ fn gemm_64_1_f() } #[test] -fn gemm_c64_1_f() -{ +fn gemm_c64_1_f() { let a = range_mat_complex64(64, 64).reversed_axes(); let (m, n) = a.dim(); // m x n times n x 1 == m x 1 @@ -315,17 +300,11 @@ fn gemm_c64_1_f() let mut y = range_mat_complex64(m, 1); let answer = reference_mat_mul(&a, &x) + &y; general_mat_mul(Complex64::new(1.0, 0.), &a, &x, Complex64::new(1.0, 0.), &mut y); - assert_relative_eq!( - y.mapv(|i| i.norm_sqr()), - answer.mapv(|i| i.norm_sqr()), - epsilon = 1e-12, - max_relative = 1e-7 - ); + assert_relative_eq!(y.mapv(|i| i.norm_sqr()), answer.mapv(|i| i.norm_sqr()), epsilon = 1e-12, max_relative = 1e-7); } #[test] -fn gemm_c32_1_f() -{ +fn gemm_c32_1_f() { let a = range_mat_complex(64, 64).reversed_axes(); let (m, n) = a.dim(); // m x n times n x 1 == m x 1 @@ -333,17 +312,11 @@ fn gemm_c32_1_f() let mut y = range_mat_complex(m, 1); let answer = reference_mat_mul(&a, &x) + &y; general_mat_mul(Complex32::new(1.0, 0.), &a, &x, Complex32::new(1.0, 0.), &mut y); - assert_relative_eq!( - y.mapv(|i| i.norm_sqr()), - answer.mapv(|i| i.norm_sqr()), - epsilon = 1e-12, - max_relative = 1e-7 - ); + assert_relative_eq!(y.mapv(|i| i.norm_sqr()), answer.mapv(|i| i.norm_sqr()), epsilon = 1e-12, max_relative = 1e-7); } #[test] -fn gemm_c64_actually_complex() -{ +fn gemm_c64_actually_complex() { let mut a = range_mat_complex64(4, 4); a 
= a.map(|&i| if i.re > 8. { i.conj() } else { i }); let mut b = range_mat_complex64(4, 6); @@ -353,30 +326,14 @@ fn gemm_c64_actually_complex() let beta = Complex64::new(1.0, 1.0); let answer = alpha * reference_mat_mul(&a, &b) + beta * &y; general_mat_mul(alpha.clone(), &a, &b, beta.clone(), &mut y); - assert_relative_eq!( - y.mapv(|i| i.norm_sqr()), - answer.mapv(|i| i.norm_sqr()), - epsilon = 1e-12, - max_relative = 1e-7 - ); + assert_relative_eq!(y.mapv(|i| i.norm_sqr()), answer.mapv(|i| i.norm_sqr()), epsilon = 1e-12, max_relative = 1e-7); } #[test] -fn gen_mat_vec_mul() -{ +fn gen_mat_vec_mul() { let alpha = -2.3; let beta = 3.14; - let sizes = vec![ - (4, 4), - (8, 8), - (17, 15), - (4, 17), - (17, 3), - (19, 18), - (16, 17), - (15, 16), - (67, 63), - ]; + let sizes = vec![(4, 4), (8, 8), (17, 15), (4, 17), (17, 3), (19, 18), (16, 17), (15, 16), (67, 63)]; // test different strides for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { @@ -406,19 +363,8 @@ fn gen_mat_vec_mul() } #[test] -fn vec_mat_mul() -{ - let sizes = vec![ - (4, 4), - (8, 8), - (17, 15), - (4, 17), - (17, 3), - (19, 18), - (16, 17), - (15, 16), - (67, 63), - ]; +fn vec_mat_mul() { + let sizes = vec![(4, 4), (8, 8), (17, 15), (4, 17), (17, 3), (19, 18), (16, 17), (15, 16), (67, 63)]; // test different strides for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index 49d05fbb..0fd1f3ad 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -40,9 +40,9 @@ export_tests = [] # Upstream burn crates — vendored at pinned commit, we only override our additions. # Our changes: crates/burn/src/ops/tensor.rs (try_vml_unary + 4 SIMD wires) # crates/burn/src/ops/activation.rs (fused sigmoid) -burn-backend = { git = "https://github.com/tracel-ai/burn.git", rev = "ed72d2b", default-features = false } -burn-std = { git = "https://github.com/tracel-ai/burn.git", rev = "ed72d2b", default-features = false } -burn-ir = { git = "https://github.com/tracel-ai/burn.git", rev = "ed72d2b", default-features = false } +burn-backend = { git = "https://github.com/AdaWorldAPI/burn.git", rev = "9b2b671", default-features = false } +burn-std = { git = "https://github.com/AdaWorldAPI/burn.git", rev = "9b2b671", default-features = false } +burn-ir = { git = "https://github.com/AdaWorldAPI/burn.git", rev = "9b2b671", default-features = false } # ndarray — uses our workspace root (adaworldapi/ndarray with SIMD + HPC) ndarray = { path = "../..", default-features = false } diff --git a/crates/burn/upstream b/crates/burn/upstream index 76299209..9b2b6712 160000 --- a/crates/burn/upstream +++ b/crates/burn/upstream @@ -1 +1 @@ -Subproject commit 76299209e63b03236b5bb9d51ae45a22404cacaf +Subproject commit 9b2b67127b0fbb5387021faf540b7b12b9c4e943 diff --git a/crates/ndarray-gen/src/array_builder.rs b/crates/ndarray-gen/src/array_builder.rs index 9351aadc..9c5e9ce7 100644 --- a/crates/ndarray-gen/src/array_builder.rs +++ b/crates/ndarray-gen/src/array_builder.rs @@ -14,8 +14,7 @@ use ndarray::Order; use num_traits::Num; #[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub struct ArrayBuilder -{ +pub struct ArrayBuilder { dim: D, memory_order: Order, generator: ElementGenerator, @@ -23,26 +22,23 @@ pub struct ArrayBuilder /// How to generate elements #[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum ElementGenerator -{ +pub enum ElementGenerator { Sequential, Checkerboard, Zero, } -impl Default for ArrayBuilder -{ - fn default() -> Self - { +impl Default for ArrayBuilder { 
+ fn default() -> Self { Self::new(D::zeros(D::NDIM.unwrap_or(1))) } } impl ArrayBuilder -where D: Dimension +where + D: Dimension, { - pub fn new(dim: impl IntoDimension) -> Self - { + pub fn new(dim: impl IntoDimension) -> Self { ArrayBuilder { dim: dim.into_dimension(), memory_order: Order::C, @@ -50,26 +46,26 @@ where D: Dimension } } - pub fn memory_order(mut self, order: Order) -> Self - { + pub fn memory_order(mut self, order: Order) -> Self { self.memory_order = order; self } - pub fn generator(mut self, generator: ElementGenerator) -> Self - { + pub fn generator(mut self, generator: ElementGenerator) -> Self { self.generator = generator; self } pub fn build(self) -> Array - where T: Num + Clone + where + T: Num + Clone, { let zero = T::zero(); let size = self.dim.size(); (match self.generator { - ElementGenerator::Sequential => - Array::from_iter(core::iter::successors(Some(zero), |elt| Some(elt.clone() + T::one())).take(size)), + ElementGenerator::Sequential => { + Array::from_iter(core::iter::successors(Some(zero), |elt| Some(elt.clone() + T::one())).take(size)) + } ElementGenerator::Checkerboard => Array::from_iter([T::one(), zero].iter().cycle().take(size).cloned()), ElementGenerator::Zero => Array::zeros(size), }) @@ -79,8 +75,7 @@ where D: Dimension } #[test] -fn test_order() -{ +fn test_order() { let (m, n) = (12, 13); let c = ArrayBuilder::new((m, n)) .memory_order(Order::C) diff --git a/crates/numeric-tests/tests/accuracy.rs b/crates/numeric-tests/tests/accuracy.rs index db10d57c..e5172580 100644 --- a/crates/numeric-tests/tests/accuracy.rs +++ b/crates/numeric-tests/tests/accuracy.rs @@ -23,7 +23,8 @@ use rand_distr::{Distribution, Normal, StandardNormal}; use approx::{assert_abs_diff_eq, assert_relative_eq}; fn kahan_sum(iter: impl Iterator) -> A -where A: LinalgScalar +where + A: LinalgScalar, { let mut sum = A::zero(); let mut compensation = A::zero(); @@ -84,8 +85,7 @@ where } #[test] -fn accurate_eye_f32() -{ +fn accurate_eye_f32() { let rng = &mut SmallRng::from_os_rng(); for i in 0..20 { let eye = Array::eye(i); @@ -112,8 +112,7 @@ fn accurate_eye_f32() } #[test] -fn accurate_eye_f64() -{ +fn accurate_eye_f64() { let rng = &mut SmallRng::from_os_rng(); let abs_tol = 1e-15; for i in 0..20 { @@ -141,26 +140,22 @@ fn accurate_eye_f64() } #[test] -fn accurate_mul_f32_dot() -{ +fn accurate_mul_f32_dot() { accurate_mul_float_general::(1e-5, false); } #[test] -fn accurate_mul_f32_general() -{ +fn accurate_mul_f32_general() { accurate_mul_float_general::(1e-5, true); } #[test] -fn accurate_mul_f64_dot() -{ +fn accurate_mul_f64_dot() { accurate_mul_float_general::(1e-14, false); } #[test] -fn accurate_mul_f64_general() -{ +fn accurate_mul_f64_general() { accurate_mul_float_general::(1e-14, true); } @@ -170,7 +165,8 @@ fn accurate_mul_f64_general() fn random_matrix_mul( rng: &mut SmallRng, use_stride: bool, use_general: bool, generator: fn(Ix2, &mut SmallRng) -> Array2, ) -> (Array2, Array2) -where A: LinalgScalar +where + A: LinalgScalar, { let m = rng.random_range(15..128); let k = rng.random_range(15..128); @@ -216,21 +212,24 @@ where let diff = &c - &reference; let max_diff = diff.iter().copied().fold(A::zero(), A::max); let max_elt = reference.iter().copied().fold(A::zero(), A::max); - println!("Max elt diff={:?}, max={:?}, ratio={:.4e}", max_diff, max_elt, (max_diff/max_elt).as_()); - assert!((max_diff / max_elt).as_() < limit, - "Expected relative norm diff < {:e}, found {:?} / {:?}", limit, max_diff, max_elt); + println!("Max elt diff={:?}, max={:?}, ratio={:.4e}", 
max_diff, max_elt, (max_diff / max_elt).as_()); + assert!( + (max_diff / max_elt).as_() < limit, + "Expected relative norm diff < {:e}, found {:?} / {:?}", + limit, + max_diff, + max_elt + ); } } #[test] -fn accurate_mul_complex32() -{ +fn accurate_mul_complex32() { accurate_mul_complex_general::(1e-5); } #[test] -fn accurate_mul_complex64() -{ +fn accurate_mul_complex64() { accurate_mul_complex_general::(1e-14); } @@ -249,15 +248,19 @@ where let max_elt = |elt: &Complex<_>| A::max(A::abs(elt.re), A::abs(elt.im)); let max_diff = diff.iter().map(max_elt).fold(A::zero(), A::max); let max_elt = reference.iter().map(max_elt).fold(A::zero(), A::max); - println!("Max elt diff={:?}, max={:?}, ratio={:.4e}", max_diff, max_elt, (max_diff/max_elt).as_()); - assert!((max_diff / max_elt).as_() < limit, - "Expected relative norm diff < {:e}, found {:?} / {:?}", limit, max_diff, max_elt); + println!("Max elt diff={:?}, max={:?}, ratio={:.4e}", max_diff, max_elt, (max_diff / max_elt).as_()); + assert!( + (max_diff / max_elt).as_() < limit, + "Expected relative norm diff < {:e}, found {:?} / {:?}", + limit, + max_diff, + max_elt + ); } } #[test] -fn accurate_mul_with_column_f64() -{ +fn accurate_mul_with_column_f64() { // pick a few random sizes let rng = &mut SmallRng::from_os_rng(); for i in 0..10 { diff --git a/crates/p64/benches/p64_bench.rs b/crates/p64/benches/p64_bench.rs index a4e76b91..3a6be5eb 100644 --- a/crates/p64/benches/p64_bench.rs +++ b/crates/p64/benches/p64_bench.rs @@ -3,31 +3,21 @@ use p64::*; fn make_heels() -> HeelPlanes { HeelPlanes::new([ - 0xAAAA_AAAA_AAAA_AAAA, - 0x5555_5555_5555_5555, - 0xFFFF_0000_FFFF_0000, - 0x0000_FFFF_0000_FFFF, - 0xFF00_FF00_FF00_FF00, - 0x00FF_00FF_00FF_00FF, - 0xF0F0_F0F0_F0F0_F0F0, - 0x0F0F_0F0F_0F0F_0F0F, + 0xAAAA_AAAA_AAAA_AAAA, 0x5555_5555_5555_5555, 0xFFFF_0000_FFFF_0000, 0x0000_FFFF_0000_FFFF, + 0xFF00_FF00_FF00_FF00, 0x00FF_00FF_00FF_00FF, 0xF0F0_F0F0_F0F0_F0F0, 0x0F0F_0F0F_0F0F_0F0F, ]) } fn bench_expand(c: &mut Criterion) { let heels = make_heels(); - c.bench_function("expand_8_to_64", |b| { - b.iter(|| black_box(black_box(&heels).expand())) - }); + c.bench_function("expand_8_to_64", |b| b.iter(|| black_box(black_box(&heels).expand()))); } fn bench_attend(c: &mut Criterion) { let palette = make_heels().expand(); let query = 0xDEAD_BEEF_CAFE_BABEu64; - c.bench_function("attend_single", |b| { - b.iter(|| black_box(black_box(&palette).attend(black_box(query), 16))) - }); + c.bench_function("attend_single", |b| b.iter(|| black_box(black_box(&palette).attend(black_box(query), 16)))); } fn bench_attend_batch(c: &mut Criterion) { @@ -47,18 +37,14 @@ fn bench_moe_gate(c: &mut Criterion) { let heels = make_heels(); let query = 0xDEAD_BEEF_CAFE_BABEu64; - c.bench_function("moe_gate", |b| { - b.iter(|| black_box(black_box(&heels).moe_gate(black_box(query), 20))) - }); + c.bench_function("moe_gate", |b| b.iter(|| black_box(black_box(&heels).moe_gate(black_box(query), 20)))); } fn bench_soft_moe(c: &mut Criterion) { let heels = make_heels(); let query = 0xDEAD_BEEF_CAFE_BABEu64; - c.bench_function("soft_moe", |b| { - b.iter(|| black_box(black_box(&heels).soft_moe(black_box(query), 20))) - }); + c.bench_function("soft_moe", |b| b.iter(|| black_box(black_box(&heels).soft_moe(black_box(query), 20)))); } fn bench_denoise(c: &mut Criterion) { @@ -74,19 +60,11 @@ fn bench_nearest_k(c: &mut Criterion) { let palette = make_heels().expand(); let query = 0xDEAD_BEEF_CAFE_BABEu64; - c.bench_function("nearest_k_8", |b| { - b.iter(|| 
black_box(black_box(&palette).nearest_k(black_box(query), 8))) - }); + c.bench_function("nearest_k_8", |b| b.iter(|| black_box(black_box(&palette).nearest_k(black_box(query), 8)))); } criterion_group!( - benches, - bench_expand, - bench_attend, - bench_attend_batch, - bench_moe_gate, - bench_soft_moe, - bench_denoise, + benches, bench_expand, bench_attend, bench_attend_batch, bench_moe_gate, bench_soft_moe, bench_denoise, bench_nearest_k, ); criterion_main!(benches); diff --git a/crates/p64/src/lib.rs b/crates/p64/src/lib.rs index 2857db63..784bafd3 100644 --- a/crates/p64/src/lib.rs +++ b/crates/p64/src/lib.rs @@ -115,12 +115,7 @@ impl HeelPlanes { // 7 payload planes: 4 bytes each → expand to 64 bits via golden stepping for i in 0..7 { let start = 1 + (i * 4); - let bytes = [ - data[start] as u8, - data[start + 1] as u8, - data[start + 2] as u8, - data[start + 3] as u8, - ]; + let bytes = [data[start] as u8, data[start + 1] as u8, data[start + 2] as u8, data[start + 3] as u8]; let val = u32::from_le_bytes(bytes); // Spread 32 bits across 64 bits: each input bit controls 2 output bits // via golden-step interleave @@ -128,12 +123,7 @@ impl HeelPlanes { } // Plane 7: contradiction (4 bytes from data[29..33]) - let contra_bytes = [ - data[29] as u8, - data[30] as u8, - data[31] as u8, - data[32] as u8, - ]; + let contra_bytes = [data[29] as u8, data[30] as u8, data[31] as u8, data[32] as u8]; let contra_val = u32::from_le_bytes(contra_bytes); planes[7] = spread_32_to_64(contra_val); @@ -443,8 +433,7 @@ impl Palette64 { #[inline] pub fn attend(&self, query: u64, gamma: u8) -> AttentionResult { // SAFETY: LazyLock guarantees the selected kernel matches CPU features. - let (best_idx, distance, scores, fires) = - unsafe { ATTEND_KERNEL(&self.rows, query, gamma) }; + let (best_idx, distance, scores, fires) = unsafe { ATTEND_KERNEL(&self.rows, query, gamma) }; AttentionResult { best_idx, distance, @@ -526,8 +515,7 @@ impl HeelPlanes { #[inline] pub fn moe_gate(&self, query: u64, threshold: u8) -> MoeGate { // SAFETY: LazyLock guarantees the selected kernel matches CPU features. - let (active, strength, combined) = - unsafe { MOE_GATE_KERNEL(&self.planes, query, threshold) }; + let (active, strength, combined) = unsafe { MOE_GATE_KERNEL(&self.planes, query, threshold) }; MoeGate { active, strength, @@ -805,11 +793,7 @@ pub struct InferenceResult { impl Palette3D { /// Create from 8 individual palette layers. pub fn new(layers: [Palette64; 8], style: ThinkingStyle) -> Self { - Self { - layers, - style, - step: 0, - } + Self { layers, style, step: 0 } } /// Create from HeelPlanes: same expansion for all layers (then differentiate). 
@@ -911,7 +895,11 @@ impl Palette3D { } } } - if first { 0 } else { result } + if first { + 0 + } else { + result + } } CombineMode::Majority => { // Per-bit vote: set if >50% of active layers agree @@ -920,9 +908,7 @@ impl Palette3D { for bit in 0..64 { let mut votes = 0u8; for z in 0..8 { - if self.style.layer_mask & (1 << z) != 0 - && per_layer[z] & (1u64 << bit) != 0 - { + if self.style.layer_mask & (1 << z) != 0 && per_layer[z] & (1u64 << bit) != 0 { votes += 1; } } @@ -1019,8 +1005,7 @@ impl Palette3D { bits &= bits - 1; if j < 64 { // If j also supports things the query supports → grounded - mutual_support |= - self.layers[predicate::SUPPORTS].rows[j] & supported; + mutual_support |= self.layers[predicate::SUPPORTS].rows[j] & supported; } } let fresh = mutual_support & !self.layers[predicate::GROUNDS].rows[block_row]; @@ -1076,11 +1061,7 @@ impl Palette3D { /// Useful for: analytical → creative (when stuck), /// creative → focused (when solution emerges). pub fn transition( - &mut self, - from: ThinkingStyle, - to: ThinkingStyle, - query_row: usize, - steps: usize, + &mut self, from: ThinkingStyle, to: ThinkingStyle, query_row: usize, steps: usize, ) -> Vec { let mut results = Vec::with_capacity(steps); @@ -1097,8 +1078,7 @@ impl Palette3D { // Combine mode snaps at midpoint let combine = if t < 0.5 { from.combine } else { to.combine }; let contra = if t < 0.5 { from.contra } else { to.contra }; - let density_target = - from.density_target * (1.0 - t) + to.density_target * t; + let density_target = from.density_target * (1.0 - t) + to.density_target * t; self.style = ThinkingStyle { layer_mask: mask, @@ -1237,10 +1217,7 @@ pub mod sparse256 { /// The 256 leaves are grouped into 64 blocks of 4. /// Block (I, J) is set to 1 if ANY leaf-pair across blocks can interact. /// This is conservative: zero false negatives, possible false positives. - pub fn from_clam_leaves( - leaves: &[LeafCluster], - distances: &PairwiseDistances, - ) -> (Palette64, SparsityStats) { + pub fn from_clam_leaves(leaves: &[LeafCluster], distances: &PairwiseDistances) -> (Palette64, SparsityStats) { let n = leaves.len().min(256); let n_blocks = (n + 3) / 4; // ceil(n/4), max 64 @@ -1289,7 +1266,11 @@ pub mod sparse256 { active_blocks: interactions, pruned_blocks: pruned, density: if total > 0 { - palette.rows.iter().map(|r| r.count_ones() as u64).sum::() as f64 + palette + .rows + .iter() + .map(|r| r.count_ones() as u64) + .sum::() as f64 / (n_blocks * n_blocks) as f64 } else { 0.0 @@ -1348,12 +1329,7 @@ pub mod sparse256 { /// `block_weights[block_row][block_col]` scales the 4×4 contribution. /// Weights typically come from inverse LFD (local fractal dimension): /// high LFD clusters get lower weight (they're spread out). - pub fn spmv_256_weighted( - &self, - x: &[f32; 256], - y: &mut [f32; 256], - block_weights: &[[f32; 64]; 64], - ) { + pub fn spmv_256_weighted(&self, x: &[f32; 256], y: &mut [f32; 256], block_weights: &[[f32; 64]; 64]) { y.fill(0.0); for block_row in 0..64 { let mask = self.rows[block_row]; @@ -1453,10 +1429,7 @@ pub mod sparse256 { /// This is where the LEAF level lives — LanceDB vector search, DistanceMatrix /// lookup, or BF16 dot product. The cascade doesn't know or care which. 
pub fn hhtl_cascade_search f32>( - palette: &Palette64, - query_row: u8, - scores: &mut [f32; 256], - score_fn: F, + palette: &Palette64, query_row: u8, scores: &mut [f32; 256], score_fn: F, ) -> usize { let heel_row = query_row / 32; let hip_row = (query_row / 4) % 8; @@ -1523,10 +1496,7 @@ mod tests { // Row 0 = HEEL[0] rotated by 0 assert_eq!(palette.rows[0], heels.planes[0]); // Row 8 = HEEL[0] rotated by GOLDEN_SHIFT_64 - assert_eq!( - palette.rows[8], - heels.planes[0].rotate_left(GOLDEN_SHIFT_64) - ); + assert_eq!(palette.rows[8], heels.planes[0].rotate_left(GOLDEN_SHIFT_64)); eprintln!("Palette constructed: 64 rows × 64 bits = {} bytes", 64 * 8); } @@ -1581,14 +1551,13 @@ mod tests { let query = 0xAAAA_AAAA_AAAA_AAAA ^ 0xFF; let result = palette.attend(query, 16); - eprintln!("Noisy query (8 bits flipped): score={}, distance={}", - result.scores[result.best_idx as usize], result.distance); + eprintln!( + "Noisy query (8 bits flipped): score={}, distance={}", + result.scores[result.best_idx as usize], result.distance + ); // Original: 32 matching bits. Flip 8 bits: ~4 matches lost, ~4 false matches gained. // Row 0 should still be best or near-best. - assert!( - result.scores[0] >= 24, - "Row 0 should still score high despite noise (score={})", result.scores[0] - ); + assert!(result.scores[0] >= 24, "Row 0 should still score high despite noise (score={})", result.scores[0]); } // ── MoE Fanout ───────────────────────────────────────────────────── @@ -1627,17 +1596,10 @@ mod tests { eprintln!("Hard MoE (OR): {:016X}", hard); eprintln!("Soft MoE (vote): {:016X}", soft); - eprintln!( - "Hard density: {}, Soft density: {}", - hard.count_ones(), - soft.count_ones() - ); + eprintln!("Hard density: {}, Soft density: {}", hard.count_ones(), soft.count_ones()); // Soft should be sparser than hard (majority vs OR) - assert!( - soft.count_ones() <= hard.count_ones(), - "Soft MoE should be sparser than hard MoE" - ); + assert!(soft.count_ones() <= hard.count_ones(), "Soft MoE should be sparser than hard MoE"); } // ── Expert diversity ─────────────────────────────────────────────── @@ -1653,10 +1615,7 @@ mod tests { eprintln!(" Max: {}", dists.iter().max().unwrap()); // With our orthogonal test patterns, diversity should be high - assert!( - mean > 20.0, - "Expert diversity should be high for orthogonal patterns" - ); + assert!(mean > 20.0, "Expert diversity should be high for orthogonal patterns"); } // ── Palette statistics ───────────────────────────────────────────── @@ -1672,10 +1631,7 @@ mod tests { eprintln!(" Max Hamming distance: {max_dist}"); // Ideal: mean ≈ 32 (half of 64 = maximally dispersed) - assert!( - mean_dist > 20.0, - "Palette rows should be well-dispersed (mean={mean_dist})" - ); + assert!(mean_dist > 20.0, "Palette rows should be well-dispersed (mean={mean_dist})"); } #[test] @@ -1717,10 +1673,7 @@ mod tests { // Final state should be exactly a palette row let is_palette_row = palette.rows.contains(&final_state); - assert!( - is_palette_row, - "Final state should be a palette entry" - ); + assert!(is_palette_row, "Final state should be a palette entry"); } #[test] @@ -1805,10 +1758,7 @@ mod tests { // Intra-group: distance=5, r+r=20 → 5<=20 → interact (1) // Inter-group: distance=100, r+r=20 → 100>20 → pruned (0) // So only 8 diagonal blocks of 8×8 should be active = 512 bits of 4096 - assert!( - stats.density < 0.20, - "Density should be low due to triangle inequality pruning" - ); + assert!(stats.density < 0.20, "Density should be low due to triangle inequality 
pruning"); } #[test] @@ -1840,11 +1790,7 @@ mod tests { for col in [0u8, 15, 63, 128, 200, 255] { let addr = HhtlAddress::from_256(row, col); let (r, c) = addr.to_256(); - assert_eq!( - (r, c), - (row, col), - "HHTL roundtrip failed for ({row}, {col})" - ); + assert_eq!((r, c), (row, col), "HHTL roundtrip failed for ({row}, {col})"); } } } @@ -1971,22 +1917,14 @@ mod tests { eprintln!("Steps to convergence: {}", p3d.step); // Deduction should either grow or stay same (never shrink) - assert!( - after_density >= before_density, - "Deduction should grow or maintain density" - ); + assert!(after_density >= before_density, "Deduction should grow or maintain density"); } #[test] fn style_transition() { let mut p3d = make_test_palette3d(); - let results = p3d.transition( - ThinkingStyle::ANALYTICAL, - ThinkingStyle::CREATIVE, - 0, - 8, - ); + let results = p3d.transition(ThinkingStyle::ANALYTICAL, ThinkingStyle::CREATIVE, 0, 8); eprintln!("\n=== Style Transition: Analytical → Creative ==="); for (i, r) in results.iter().enumerate() { @@ -2012,10 +1950,7 @@ mod tests { let densities = p3d.layer_densities(); eprintln!("\n=== Layer Densities ==="); - let names = [ - "CAUSES", "ENABLES", "SUPPORTS", "CONTRADICTS", - "REFINES", "ABSTRACTS", "GROUNDS", "BECOMES", - ]; + let names = ["CAUSES", "ENABLES", "SUPPORTS", "CONTRADICTS", "REFINES", "ABSTRACTS", "GROUNDS", "BECOMES"]; for (i, name) in names.iter().enumerate() { eprintln!(" {:<12} {:.4}", name, densities[i]); } diff --git a/crates/serialization-tests/tests/serialize.rs b/crates/serialization-tests/tests/serialize.rs index 478eb20e..fc595b9f 100644 --- a/crates/serialization-tests/tests/serialize.rs +++ b/crates/serialization-tests/tests/serialize.rs @@ -11,8 +11,7 @@ extern crate ron; use ndarray::{arr0, arr1, arr2, s, ArcArray, ArcArray2, ArrayD, IxDyn}; #[test] -fn serial_many_dim_serde() -{ +fn serial_many_dim_serde() { { let a = arr0::(2.72); let serial = serde_json::to_string(&a).unwrap(); @@ -58,8 +57,7 @@ fn serial_many_dim_serde() } #[test] -fn serial_ixdyn_serde() -{ +fn serial_ixdyn_serde() { { let a = arr0::(2.72).into_dyn(); let serial = serde_json::to_string(&a).unwrap(); @@ -98,8 +96,7 @@ fn serial_ixdyn_serde() } #[test] -fn serial_wrong_count_serde() -{ +fn serial_wrong_count_serde() { // one element too few let text = r##"{"v":1,"dim":[2,3],"data":[3,1,2.2,3.1,4]}"##; let arr = serde_json::from_str::>(text); @@ -114,8 +111,7 @@ fn serial_wrong_count_serde() } #[test] -fn serial_many_dim_serde_msgpack() -{ +fn serial_many_dim_serde_msgpack() { { let a = arr0::(2.72); @@ -178,8 +174,7 @@ fn serial_many_dim_serde_msgpack() } #[test] -fn serial_many_dim_ron() -{ +fn serial_many_dim_ron() { use ron::de::from_str as ron_deserialize; use ron::ser::to_string as ron_serialize; diff --git a/examples/axis_ops.rs b/examples/axis_ops.rs index 0469747f..b430090a 100644 --- a/examples/axis_ops.rs +++ b/examples/axis_ops.rs @@ -53,8 +53,7 @@ where Ok(()) } -fn main() -{ +fn main() { let mut a = Array::::zeros((2, 3, 4)); for (i, elt) in (0..).zip(&mut a) { *elt = i; diff --git a/examples/bounds_check_elim.rs b/examples/bounds_check_elim.rs index ef20e9ad..593b5136 100644 --- a/examples/bounds_check_elim.rs +++ b/examples/bounds_check_elim.rs @@ -33,8 +33,7 @@ pub fn testvec_as_slice(a: &Vec) -> f64 { */ #[no_mangle] -pub fn test1d_single(a: &Array1, i: usize) -> f64 -{ +pub fn test1d_single(a: &Array1, i: usize) -> f64 { if i < a.len() { a[i] } else { @@ -43,8 +42,7 @@ pub fn test1d_single(a: &Array1, i: usize) -> f64 } #[no_mangle] -pub fn 
test1d_single_mut(a: &mut Array1, i: usize) -> f64 -{ +pub fn test1d_single_mut(a: &mut Array1, i: usize) -> f64 { if i < a.len() { *&mut a[i] } else { @@ -53,8 +51,7 @@ pub fn test1d_single_mut(a: &mut Array1, i: usize) -> f64 } #[no_mangle] -pub fn test1d_len_of(a: &Array1) -> f64 -{ +pub fn test1d_len_of(a: &Array1) -> f64 { let mut sum = 0.; for i in 0..a.len_of(Axis(0)) { sum += a[i]; @@ -63,8 +60,7 @@ pub fn test1d_len_of(a: &Array1) -> f64 } #[no_mangle] -pub fn test1d_range(a: &Array1) -> f64 -{ +pub fn test1d_range(a: &Array1) -> f64 { let mut sum = 0.; for i in 0..a.len() { sum += a[i]; @@ -73,8 +69,7 @@ pub fn test1d_range(a: &Array1) -> f64 } #[no_mangle] -pub fn test1d_while(a: &Array1) -> f64 -{ +pub fn test1d_while(a: &Array1) -> f64 { let mut sum = 0.; let mut i = 0; while i < a.len() { @@ -85,8 +80,7 @@ pub fn test1d_while(a: &Array1) -> f64 } #[no_mangle] -pub fn test2d_ranges(a: &Array2) -> f64 -{ +pub fn test2d_ranges(a: &Array2) -> f64 { let mut sum = 0.; for i in 0..a.nrows() { for j in 0..a.ncols() { @@ -97,8 +91,7 @@ pub fn test2d_ranges(a: &Array2) -> f64 } #[no_mangle] -pub fn test2d_whiles(a: &Array2) -> f64 -{ +pub fn test2d_whiles(a: &Array2) -> f64 { let mut sum = 0.; let mut i = 0; while i < a.nrows() { diff --git a/examples/column_standardize.rs b/examples/column_standardize.rs index 329ad2cc..6a1840f0 100644 --- a/examples/column_standardize.rs +++ b/examples/column_standardize.rs @@ -2,8 +2,7 @@ use ndarray::prelude::*; #[cfg(feature = "std")] -fn main() -{ +fn main() { // This example recreates the following from python/numpy // counts -= np.mean(counts, axis=0) // counts /= np.std(counts, axis=0) diff --git a/examples/convo.rs b/examples/convo.rs index 79e8ab6b..f7b48fd4 100644 --- a/examples/convo.rs +++ b/examples/convo.rs @@ -15,7 +15,8 @@ type Kernel3x3 = [[A; 3]; 3]; #[inline(never)] #[cfg(feature = "std")] fn conv_3x3(a: &ArrayRef2, out: &mut ArrayRef2, kernel: &Kernel3x3) -where F: Float +where + F: Float, { let (n, m) = a.dim(); let (np, mp) = out.dim(); @@ -42,8 +43,7 @@ where F: Float } #[cfg(feature = "std")] -fn main() -{ +fn main() { let n = 16; let mut a = Array::zeros((n, n)); // make a circle diff --git a/examples/functions_and_traits.rs b/examples/functions_and_traits.rs index 7091a5e1..01e3be4a 100644 --- a/examples/functions_and_traits.rs +++ b/examples/functions_and_traits.rs @@ -18,8 +18,7 @@ use ndarray::{ArrayBase, ArrayRef, Data, DataMut, Dimension, LayoutRef, RawRef}; /// /// This is probably the most common pattern for users. /// Once we have an array reference, we can go to [`RawRef`] and [`LayoutRef`] very easily. -fn takes_arrref(arr: &ArrayRef) -{ +fn takes_arrref(arr: &ArrayRef) { // Since `ArrayRef` implements `Deref` to `RawRef`, we can pass `arr` directly to a function // that takes `RawRef`. Similarly, since `RawRef` implements `Deref` to `LayoutRef`, we can pass // `arr` directly to a function that takes `LayoutRef`. @@ -40,8 +39,7 @@ fn takes_arrref(arr: &ArrayRef) /// So, ***users should only accept `&mut ArrayRef` when they want to mutate data***. /// If they just want to mutate shape and strides, use `&mut LayoutRef` or `&AsMut`. 
#[allow(dead_code)] -fn takes_arrref_mut(arr: &mut ArrayRef) -{ +fn takes_arrref_mut(arr: &mut ArrayRef) { // We can do everything we did with a `&ArrayRef` takes_arrref(arr); @@ -64,8 +62,7 @@ fn takes_arrref_mut(arr: &mut ArrayRef) /// /// Let's see what we can do with this array: #[allow(dead_code)] -fn takes_base(arr: &ArrayBase) -{ +fn takes_base(arr: &ArrayBase) { // First off: we can pass it to functions that accept `&ArrayRef`. // // This is always "cheap", in the sense that even if `arr` is an @@ -83,8 +80,7 @@ fn takes_base(arr: &ArrayBase) /// Now, let's take a mutable reference to an `ArrayBase` - but let's keep `S: Data`, such /// that we are allowed to change the _layout_ of the array, but not its data. -fn takes_base_mut(arr: &mut ArrayBase) -{ +fn takes_base_mut(arr: &mut ArrayBase) { // Of course we can call everything we did with an immutable reference: takes_base(arr); @@ -108,8 +104,7 @@ fn takes_base_mut(arr: &mut ArrayBase) /// /// Note that we require a constraint of `D: Dimension` to dereference to `&mut ArrayRef`. #[allow(dead_code)] -fn takes_base_data_mut(arr: &mut ArrayBase) -{ +fn takes_base_data_mut(arr: &mut ArrayBase) { // Of course, everything we can do with just `S: Data`: takes_base_mut(arr); @@ -143,7 +138,8 @@ fn takes_layout_mut(_arr: &mut LayoutRef) {} /// without having to call `.as_ref` or `.as_layout_ref`. #[allow(dead_code)] fn takes_layout_asref(_arr: &T) -where T: AsRef> + ?Sized +where + T: AsRef> + ?Sized, { } @@ -153,7 +149,8 @@ where T: AsRef> + ?Sized /// `&mut ArcArray --(unshare)--> &mut ArrayRef -> &mut RawRef -> &mut LayoutRef`. #[allow(dead_code)] fn takes_layout_asmut(_arr: &mut T) -where T: AsMut> + ?Sized +where + T: AsMut> + ?Sized, { } @@ -167,16 +164,14 @@ where T: AsMut> + ?Sized /// Like `LayoutRef`, writing functions with `RawRef` can be done in a few ways. /// We start with a direct, immutable reference #[allow(dead_code)] -fn takes_rawref(arr: &RawRef) -{ +fn takes_rawref(arr: &RawRef) { takes_layout(arr); takes_layout_asref(arr); } /// We can also directly take a mutable reference. 
#[allow(dead_code)] -fn takes_rawref_mut(arr: &mut RawRef) -{ +fn takes_rawref_mut(arr: &mut RawRef) { takes_layout(arr); takes_layout_asmut(arr); } @@ -185,7 +180,8 @@ fn takes_rawref_mut(arr: &mut RawRef) /// for the same reasons as for `LayoutRef`: #[allow(dead_code)] fn takes_rawref_asref(_arr: &T) -where T: AsRef> + ?Sized +where + T: AsRef> + ?Sized, { takes_layout(_arr.as_ref()); takes_layout_asref(_arr.as_ref()); @@ -194,7 +190,8 @@ where T: AsRef> + ?Sized /// Finally, mutably: #[allow(dead_code)] fn takes_rawref_asmut(_arr: &mut T) -where T: AsMut> + ?Sized +where + T: AsMut> + ?Sized, { takes_layout_mut(_arr.as_mut()); takes_layout_asmut(_arr.as_mut()); diff --git a/examples/life.rs b/examples/life.rs index a521f34c..d430858b 100644 --- a/examples/life.rs +++ b/examples/life.rs @@ -8,8 +8,7 @@ const N: usize = 100; type Board = Array2; -fn parse(x: &[u8]) -> Board -{ +fn parse(x: &[u8]) -> Board { // make a border of 0 cells let mut map = Board::from_elem(((N + 2), (N + 2)), 0); let a = Array::from_iter(x.iter().filter_map(|&b| match b { @@ -29,8 +28,7 @@ fn parse(x: &[u8]) -> Board // 3 neighbors: birth // otherwise: death -fn iterate(z: &mut Board, scratch: &mut Board) -{ +fn iterate(z: &mut Board, scratch: &mut Board) { // compute number of neighbors let mut neigh = scratch.view_mut(); neigh.fill(0); @@ -53,8 +51,7 @@ fn iterate(z: &mut Board, scratch: &mut Board) zv.zip_mut_with(&neigh, |y, &n| *y = ((n == 3) || (n == 2 && *y > 0)) as u8); } -fn turn_on_corners(z: &mut Board) -{ +fn turn_on_corners(z: &mut Board) { let n = z.nrows(); let m = z.ncols(); z[[1, 1]] = 1; @@ -63,8 +60,7 @@ fn turn_on_corners(z: &mut Board) z[[n - 2, m - 2]] = 1; } -fn render(a: &Board) -{ +fn render(a: &Board) { for row in a.rows() { for &x in row { if x > 0 { @@ -77,8 +73,7 @@ fn render(a: &Board) } } -fn main() -{ +fn main() { let mut a = parse(INPUT); let mut scratch = Board::zeros((N, N)); let steps = 100; diff --git a/examples/ocr_benchmark.rs b/examples/ocr_benchmark.rs index fe74b019..6062273d 100644 --- a/examples/ocr_benchmark.rs +++ b/examples/ocr_benchmark.rs @@ -14,17 +14,9 @@ fn main() { eprintln!(" OCR Benchmark: ndarray SIMD vs tesseract"); eprintln!("═══════════════════════════════════════════════════════════\n"); - let pages = vec![ - "/tmp/ocr_bench/page-01.raw", - "/tmp/ocr_bench/page-02.raw", - "/tmp/ocr_bench/page-03.raw", - ]; - - let png_pages = vec![ - "/tmp/ocr_bench/page-01.png", - "/tmp/ocr_bench/page-02.png", - "/tmp/ocr_bench/page-03.png", - ]; + let pages = vec!["/tmp/ocr_bench/page-01.raw", "/tmp/ocr_bench/page-02.raw", "/tmp/ocr_bench/page-03.raw"]; + + let png_pages = vec!["/tmp/ocr_bench/page-01.png", "/tmp/ocr_bench/page-02.png", "/tmp/ocr_bench/page-03.png"]; // ── ndarray SIMD preprocessing ──────────────────────────────────── eprintln!("=== ndarray SIMD preprocessing ===\n"); @@ -33,18 +25,26 @@ fn main() { for (i, path) in pages.iter().enumerate() { let data = match std::fs::read(path) { Ok(d) => d, - Err(e) => { eprintln!(" skip {}: {}", path, e); continue; } + Err(e) => { + eprintln!(" skip {}: {}", path, e); + continue; + } }; - if data.len() < 8 { continue; } + if data.len() < 8 { + continue; + } let width = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize; let height = u32::from_le_bytes([data[4], data[5], data[6], data[7]]) as usize; let pixels = &data[8..]; - eprintln!(" Page {}: {}×{} ({:.1}M pixels)", i + 1, width, height, - (width * height) as f64 / 1_000_000.0); + eprintln!(" Page {}: {}×{} ({:.1}M pixels)", i + 1, width, height, 
(width * height) as f64 / 1_000_000.0); - let img = GrayImage { data: pixels, width, height }; + let img = GrayImage { + data: pixels, + width, + height, + }; // Warm up let _ = otsu_threshold(&img); @@ -113,8 +113,7 @@ fn main() { Ok(o) if o.status.success() => { let text = String::from_utf8_lossy(&o.stdout); let words = text.split_whitespace().count(); - eprintln!(" Page {}: {} words, {:.3}ms", - i + 1, words, elapsed.as_secs_f64() * 1000.0); + eprintln!(" Page {}: {} words, {:.3}ms", i + 1, words, elapsed.as_secs_f64() * 1000.0); // Show first 100 chars let preview: String = text.chars().take(100).collect(); eprintln!(" Preview: {}", preview.replace('\n', " ")); diff --git a/examples/rollaxis.rs b/examples/rollaxis.rs index 82c38129..8efdd0ce 100644 --- a/examples/rollaxis.rs +++ b/examples/rollaxis.rs @@ -22,8 +22,7 @@ where a } -fn main() -{ +fn main() { let mut data = array![ [[-1., 0., -2.], [1., 7., -3.]], [[1., 0., -3.], [1., 7., 5.]], diff --git a/examples/sort-axis.rs b/examples/sort-axis.rs index 112abfc7..ff5e7de3 100644 --- a/examples/sort-axis.rs +++ b/examples/sort-axis.rs @@ -12,16 +12,13 @@ use std::ptr::copy_nonoverlapping; // Type invariant: Each index appears exactly once #[derive(Clone, Debug)] -pub struct Permutation -{ +pub struct Permutation { indices: Vec, } -impl Permutation -{ +impl Permutation { /// Checks if the permutation is correct - pub fn from_indices(v: Vec) -> Result - { + pub fn from_indices(v: Vec) -> Result { let perm = Permutation { indices: v }; if perm.correct() { Ok(perm) @@ -30,35 +27,34 @@ impl Permutation } } - fn correct(&self) -> bool - { + fn correct(&self) -> bool { let axis_len = self.indices.len(); let mut seen = vec![false; axis_len]; for &i in &self.indices { match seen.get_mut(i) { None => return false, - Some(s) => + Some(s) => { if *s { return false; } else { *s = true; - }, + } + } } } true } } -pub trait SortArray -{ +pub trait SortArray { /// ***Panics*** if `axis` is out of bounds. 
fn identity(&self, axis: Axis) -> Permutation; fn sort_axis_by(&self, axis: Axis, less_than: F) -> Permutation - where F: FnMut(usize, usize) -> bool; + where + F: FnMut(usize, usize) -> bool; } -pub trait PermuteArray -{ +pub trait PermuteArray { type Elem; type Dim; fn permute_axis(self, axis: Axis, perm: &Permutation) -> Array @@ -72,15 +68,15 @@ where S: Data, D: Dimension, { - fn identity(&self, axis: Axis) -> Permutation - { + fn identity(&self, axis: Axis) -> Permutation { Permutation { indices: (0..self.len_of(axis)).collect(), } } fn sort_axis_by(&self, axis: Axis, mut less_than: F) -> Permutation - where F: FnMut(usize, usize) -> bool + where + F: FnMut(usize, usize) -> bool, { let mut perm = self.identity(axis); perm.indices.sort_by(move |&a, &b| { @@ -97,13 +93,15 @@ where } impl PermuteArray for Array -where D: Dimension +where + D: Dimension, { type Elem = A; type Dim = D; fn permute_axis(self, axis: Axis, perm: &Permutation) -> Array - where D: RemoveAxis + where + D: RemoveAxis, { let axis_len = self.len_of(axis); let axis_stride = self.stride_of(axis); @@ -167,8 +165,7 @@ where D: Dimension } #[cfg(feature = "std")] -fn main() -{ +fn main() { let a = Array::linspace(0.0..=63.0, 64) .into_shape_with_order((8, 8)) .unwrap(); @@ -188,12 +185,10 @@ fn main() fn main() {} #[cfg(test)] -mod tests -{ +mod tests { use super::*; #[test] - fn test_permute_axis() - { + fn test_permute_axis() { let a = array![ [107998.96, 1.], [107999.08, 2.], diff --git a/examples/type_conversion.rs b/examples/type_conversion.rs index 722991d4..4e403576 100644 --- a/examples/type_conversion.rs +++ b/examples/type_conversion.rs @@ -7,8 +7,7 @@ use approx::assert_abs_diff_eq; use ndarray::prelude::*; #[cfg(feature = "approx")] -fn main() -{ +fn main() { // Converting an array from one datatype to another is implemented with the // `ArrayBase::mapv()` function. We pass a closure that is applied to each // element independently. This allows for more control and flexibility in diff --git a/examples/zip_many.rs b/examples/zip_many.rs index 94419413..bdaee556 100644 --- a/examples/zip_many.rs +++ b/examples/zip_many.rs @@ -3,8 +3,7 @@ use ndarray::prelude::*; use ndarray::Zip; -fn main() -{ +fn main() { let n = 6; let mut a = Array::::zeros((n, n)); diff --git a/ndarray-rand/benches/bench.rs b/ndarray-rand/benches/bench.rs index 364eca9f..e28eb620 100644 --- a/ndarray-rand/benches/bench.rs +++ b/ndarray-rand/benches/bench.rs @@ -10,22 +10,19 @@ use rand_distr::Uniform; use test::Bencher; #[bench] -fn uniform_f32(b: &mut Bencher) -{ +fn uniform_f32(b: &mut Bencher) { let m = 100; b.iter(|| Array::random((m, m), Uniform::new(-1f32, 1.).unwrap())); } #[bench] -fn norm_f32(b: &mut Bencher) -{ +fn norm_f32(b: &mut Bencher) { let m = 100; b.iter(|| Array::random((m, m), Normal::new(0f32, 1.).unwrap())); } #[bench] -fn norm_f64(b: &mut Bencher) -{ +fn norm_f64(b: &mut Bencher) { let m = 100; b.iter(|| Array::random((m, m), Normal::new(0f64, 1.).unwrap())); } diff --git a/ndarray-rand/src/lib.rs b/ndarray-rand/src/lib.rs index d155695a..15945075 100644 --- a/ndarray-rand/src/lib.rs +++ b/ndarray-rand/src/lib.rs @@ -42,14 +42,12 @@ use ndarray::{ArrayBase, Data, DataOwned, Dimension, RawData}; use quickcheck::{Arbitrary, Gen}; /// `rand`, re-exported for convenience and version-compatibility. -pub mod rand -{ +pub mod rand { pub use rand::*; } /// `rand-distr`, re-exported for convenience and version-compatibility. 
-pub mod rand_distr -{ +pub mod rand_distr { pub use rand_distr::*; } @@ -155,7 +153,8 @@ where /// documentation for information. You can select a different RNG with /// [`.sample_axis_using()`](RandomRefExt::sample_axis_using). pub trait RandomRefExt -where D: Dimension +where + D: Dimension, { /// Sample `n_samples` lanes slicing along `axis` using the default RNG. /// @@ -305,7 +304,8 @@ where } impl RandomRefExt for ArrayRef -where D: Dimension +where + D: Dimension, { fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array where @@ -340,18 +340,15 @@ where D: Dimension /// [`sample_axis_using`]: RandomRefExt::sample_axis_using #[derive(Debug, Clone)] #[allow(missing_docs)] -pub enum SamplingStrategy -{ +pub enum SamplingStrategy { WithReplacement, WithoutReplacement, } // `Arbitrary` enables `quickcheck` to generate random `SamplingStrategy` values for testing. #[cfg(feature = "quickcheck")] -impl Arbitrary for SamplingStrategy -{ - fn arbitrary(g: &mut Gen) -> Self - { +impl Arbitrary for SamplingStrategy { + fn arbitrary(g: &mut Gen) -> Self { if bool::arbitrary(g) { SamplingStrategy::WithReplacement } else { @@ -360,7 +357,6 @@ impl Arbitrary for SamplingStrategy } } -fn get_rng() -> SmallRng -{ +fn get_rng() -> SmallRng { SmallRng::from_rng(&mut rng()) } diff --git a/ndarray-rand/tests/tests.rs b/ndarray-rand/tests/tests.rs index b1a80e5e..fcd28fb1 100644 --- a/ndarray-rand/tests/tests.rs +++ b/ndarray-rand/tests/tests.rs @@ -8,8 +8,7 @@ use ndarray_rand::{RandomExt, SamplingStrategy}; use quickcheck::{quickcheck, TestResult}; #[test] -fn test_dim() -{ +fn test_dim() { let (mm, nn) = (5, 5); for m in 0..mm { for n in 0..nn { @@ -23,8 +22,7 @@ fn test_dim() } #[test] -fn test_dim_f() -{ +fn test_dim_f() { let (mm, nn) = (5, 5); for m in 0..mm { for n in 0..nn { @@ -38,8 +36,7 @@ fn test_dim_f() } #[test] -fn sample_axis_on_view() -{ +fn sample_axis_on_view() { let m = 5; let a = Array::random((m, 4), Uniform::new(0., 2.).unwrap()); let _samples = a @@ -49,8 +46,7 @@ fn sample_axis_on_view() #[test] #[should_panic] -fn oversampling_without_replacement_should_panic() -{ +fn oversampling_without_replacement_should_panic() { let m = 5; let a = Array::random((m, 4), Uniform::new(0., 2.).unwrap()); let _samples = a.sample_axis(Axis(0), m + 1, SamplingStrategy::WithoutReplacement); @@ -117,8 +113,7 @@ quickcheck! 
{ } } -fn sampling_works(a: &Array2, strategy: SamplingStrategy, axis: Axis, n_samples: usize) -> bool -{ +fn sampling_works(a: &Array2, strategy: SamplingStrategy, axis: Axis, n_samples: usize) -> bool { let samples = a.sample_axis(axis, n_samples, strategy); samples .axis_iter(axis) @@ -126,15 +121,13 @@ fn sampling_works(a: &Array2, strategy: SamplingStrategy, axis: Axis, n_sam } // Check if, when sliced along `axis`, there is at least one lane in `a` equal to `b` -fn is_subset(a: &Array2, b: &ArrayView1, axis: Axis) -> bool -{ +fn is_subset(a: &Array2, b: &ArrayView1, axis: Axis) -> bool { a.axis_iter(axis).any(|lane| lane == b) } #[test] #[should_panic] -fn sampling_without_replacement_from_a_zero_length_axis_should_panic() -{ +fn sampling_without_replacement_from_a_zero_length_axis_should_panic() { let n = 5; let a = Array::random((0, n), Uniform::new(0., 2.).unwrap()); let _samples = a.sample_axis(Axis(0), 1, SamplingStrategy::WithoutReplacement); @@ -142,8 +135,7 @@ fn sampling_without_replacement_from_a_zero_length_axis_should_panic() #[test] #[should_panic] -fn sampling_with_replacement_from_a_zero_length_axis_should_panic() -{ +fn sampling_with_replacement_from_a_zero_length_axis_should_panic() { let n = 5; let a = Array::random((0, n), Uniform::new(0., 2.).unwrap()); let _samples = a.sample_axis(Axis(0), 1, SamplingStrategy::WithReplacement); diff --git a/rustfmt.toml b/rustfmt.toml index f3e376cc..a1905f86 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,26 +1,10 @@ -edition = "2018" +edition = "2021" +max_width = 120 array_width = 100 chain_width = 60 fn_call_width = 100 -max_width = 120 -brace_style = "AlwaysNextLine" -control_brace_style = "AlwaysSameLine" -fn_params_layout = "Compressed" # ? -format_macro_bodies = false -imports_granularity = "Preserve" -imports_indent = "Block" -imports_layout = "HorizontalVertical" -inline_attribute_width = 0 -indent_style = "Block" -match_arm_blocks = false +fn_params_layout = "Compressed" match_arm_leading_pipes = "Preserve" merge_derives = false -overflow_delimited_expr = true -reorder_modules = false # impacts rustdoc order +reorder_modules = false short_array_element_width_threshold = 32 -skip_macro_invocations = ["*"] -unstable_features = true -where_single_line = true - -# ignored files -ignore = [] diff --git a/src/alias_asref.rs b/src/alias_asref.rs index ab78af60..6c4ad74f 100644 --- a/src/alias_asref.rs +++ b/src/alias_asref.rs @@ -1,20 +1,10 @@ use crate::{ - iter::Axes, - ArrayBase, - Axis, - AxisDescription, - Dimension, - NdIndex, - RawArrayView, - RawData, - RawDataMut, - Slice, + iter::Axes, ArrayBase, Axis, AxisDescription, Dimension, NdIndex, RawArrayView, RawData, RawDataMut, Slice, SliceArg, }; /// Functions coming from RawRef -impl, D: Dimension> ArrayBase -{ +impl, D: Dimension> ArrayBase { /// Return a raw pointer to the element at `index`, or return `None` /// if the index is out of bounds. /// @@ -29,7 +19,8 @@ impl, D: Dimension> ArrayBase /// assert_eq!(unsafe { *p }, 2.); /// ``` pub fn get_ptr(&self, index: I) -> Option<*const A> - where I: NdIndex + where + I: NdIndex, { self.as_raw_ref().get_ptr(index) } @@ -69,22 +60,19 @@ impl, D: Dimension> ArrayBase /// /// where *d* is `self.ndim()`. #[inline(always)] - pub fn as_ptr(&self) -> *const A - { + pub fn as_ptr(&self) -> *const A { self.as_raw_ref().as_ptr() } /// Return a raw view of the array. 
#[inline] - pub fn raw_view(&self) -> RawArrayView - { + pub fn raw_view(&self) -> RawArrayView { self.as_raw_ref().raw_view() } } /// Functions coming from LayoutRef -impl ArrayBase -{ +impl ArrayBase { /// Slice the array in place without changing the number of dimensions. /// /// In particular, if an axis is sliced with an index, the axis is @@ -107,7 +95,8 @@ impl ArrayBase /// - if `D` is `IxDyn` and `info` does not match the number of array axes #[track_caller] pub fn slice_collapse(&mut self, info: I) - where I: SliceArg + where + I: SliceArg, { self.as_layout_ref_mut().slice_collapse(info); } @@ -117,8 +106,7 @@ impl ArrayBase /// **Panics** if an index is out of bounds or step size is zero.
/// **Panics** if `axis` is out of bounds. #[track_caller] - pub fn slice_axis_inplace(&mut self, axis: Axis, indices: Slice) - { + pub fn slice_axis_inplace(&mut self, axis: Axis, indices: Slice) { self.as_layout_ref_mut().slice_axis_inplace(axis, indices); } @@ -131,7 +119,8 @@ impl ArrayBase /// **Panics** if an index is out of bounds or step size is zero. #[track_caller] pub fn slice_each_axis_inplace(&mut self, f: F) - where F: FnMut(AxisDescription) -> Slice + where + F: FnMut(AxisDescription) -> Slice, { self.as_layout_ref_mut().slice_each_axis_inplace(f); } @@ -140,8 +129,7 @@ impl ArrayBase /// /// **Panics** if `axis` or `index` is out of bounds. #[track_caller] - pub fn collapse_axis(&mut self, axis: Axis, index: usize) - { + pub fn collapse_axis(&mut self, axis: Axis, index: usize) { self.as_layout_ref_mut().collapse_axis(axis, index); } @@ -150,20 +138,17 @@ impl ArrayBase /// /// Return `false` otherwise, i.e. the array is possibly not /// contiguous in memory, it has custom strides, etc. - pub fn is_standard_layout(&self) -> bool - { + pub fn is_standard_layout(&self) -> bool { self.as_layout_ref().is_standard_layout() } /// Return true if the array is known to be contiguous. - pub(crate) fn is_contiguous(&self) -> bool - { + pub(crate) fn is_contiguous(&self) -> bool { self.as_layout_ref().is_contiguous() } /// Return an iterator over the length and stride of each axis. - pub fn axes(&self) -> Axes<'_, D> - { + pub fn axes(&self) -> Axes<'_, D> { self.as_layout_ref().axes() } @@ -176,8 +161,7 @@ impl ArrayBase /// Return the axis with the greatest stride (by absolute value), /// preferring axes with len > 1. - pub fn max_stride_axis(&self) -> Axis - { + pub fn max_stride_axis(&self) -> Axis { self.as_layout_ref().max_stride_axis() } @@ -185,8 +169,7 @@ impl ArrayBase /// /// ***Panics*** if the axis is out of bounds. #[track_caller] - pub fn invert_axis(&mut self, axis: Axis) - { + pub fn invert_axis(&mut self, axis: Axis) { self.as_layout_ref_mut().invert_axis(axis); } @@ -207,8 +190,7 @@ impl ArrayBase /// ); /// ``` #[track_caller] - pub fn swap_axes(&mut self, ax: usize, bx: usize) - { + pub fn swap_axes(&mut self, ax: usize, bx: usize) { self.as_layout_ref_mut().swap_axes(ax, bx); } @@ -248,14 +230,12 @@ impl ArrayBase /// /// ***Panics*** if an axis is out of bounds. #[track_caller] - pub fn merge_axes(&mut self, take: Axis, into: Axis) -> bool - { + pub fn merge_axes(&mut self, take: Axis, into: Axis) -> bool { self.as_layout_ref_mut().merge_axes(take, into) } /// Return the total number of elements in the array. - pub fn len(&self) -> usize - { + pub fn len(&self) -> usize { self.as_layout_ref().len() } @@ -266,28 +246,24 @@ impl ArrayBase /// /// ***Panics*** if the axis is out of bounds. #[track_caller] - pub fn len_of(&self, axis: Axis) -> usize - { + pub fn len_of(&self, axis: Axis) -> usize { self.as_layout_ref().len_of(axis) } /// Return whether the array has any elements - pub fn is_empty(&self) -> bool - { + pub fn is_empty(&self) -> bool { self.as_layout_ref().is_empty() } /// Return the number of dimensions (axes) in the array - pub fn ndim(&self) -> usize - { + pub fn ndim(&self) -> usize { self.as_layout_ref().ndim() } /// Return the shape of the array in its “pattern” form, /// an integer in the one-dimensional case, tuple in the n-dimensional cases /// and so on. 
- pub fn dim(&self) -> D::Pattern - { + pub fn dim(&self) -> D::Pattern { self.as_layout_ref().dim() } @@ -305,8 +281,7 @@ impl ArrayBase /// // Create an array of zeros that's the same shape and dimensionality as `a`. /// let b = Array::::zeros(a.raw_dim()); /// ``` - pub fn raw_dim(&self) -> D - { + pub fn raw_dim(&self) -> D { self.as_layout_ref().raw_dim() } @@ -334,14 +309,12 @@ impl ArrayBase /// let c = Array::zeros(a.raw_dim()); /// assert_eq!(a, c); /// ``` - pub fn shape(&self) -> &[usize] - { + pub fn shape(&self) -> &[usize] { self.as_layout_ref().shape() } /// Return the strides of the array as a slice. - pub fn strides(&self) -> &[isize] - { + pub fn strides(&self) -> &[isize] { self.as_layout_ref().strides() } @@ -352,8 +325,7 @@ impl ArrayBase /// /// ***Panics*** if the axis is out of bounds. #[track_caller] - pub fn stride_of(&self, axis: Axis) -> isize - { + pub fn stride_of(&self, axis: Axis) -> isize { self.as_layout_ref().stride_of(axis) } } diff --git a/src/aliases.rs b/src/aliases.rs index 7dc6fe8e..dbe742f7 100644 --- a/src/aliases.rs +++ b/src/aliases.rs @@ -7,58 +7,50 @@ use crate::{ArcArray, Array, ArrayRef, ArrayView, ArrayViewMut, Ix, IxDynImpl, L /// Create a zero-dimensional index #[allow(non_snake_case)] #[inline(always)] -pub const fn Ix0() -> Ix0 -{ +pub const fn Ix0() -> Ix0 { Dim::new([]) } /// Create a one-dimensional index #[allow(non_snake_case)] #[inline(always)] -pub const fn Ix1(i0: Ix) -> Ix1 -{ +pub const fn Ix1(i0: Ix) -> Ix1 { Dim::new([i0]) } /// Create a two-dimensional index #[allow(non_snake_case)] #[inline(always)] -pub const fn Ix2(i0: Ix, i1: Ix) -> Ix2 -{ +pub const fn Ix2(i0: Ix, i1: Ix) -> Ix2 { Dim::new([i0, i1]) } /// Create a three-dimensional index #[allow(non_snake_case)] #[inline(always)] -pub const fn Ix3(i0: Ix, i1: Ix, i2: Ix) -> Ix3 -{ +pub const fn Ix3(i0: Ix, i1: Ix, i2: Ix) -> Ix3 { Dim::new([i0, i1, i2]) } /// Create a four-dimensional index #[allow(non_snake_case)] #[inline(always)] -pub const fn Ix4(i0: Ix, i1: Ix, i2: Ix, i3: Ix) -> Ix4 -{ +pub const fn Ix4(i0: Ix, i1: Ix, i2: Ix, i3: Ix) -> Ix4 { Dim::new([i0, i1, i2, i3]) } /// Create a five-dimensional index #[allow(non_snake_case)] #[inline(always)] -pub const fn Ix5(i0: Ix, i1: Ix, i2: Ix, i3: Ix, i4: Ix) -> Ix5 -{ +pub const fn Ix5(i0: Ix, i1: Ix, i2: Ix, i3: Ix, i4: Ix) -> Ix5 { Dim::new([i0, i1, i2, i3, i4]) } /// Create a six-dimensional index #[allow(non_snake_case)] #[inline(always)] -pub const fn Ix6(i0: Ix, i1: Ix, i2: Ix, i3: Ix, i4: Ix, i5: Ix) -> Ix6 -{ +pub const fn Ix6(i0: Ix, i1: Ix, i2: Ix, i3: Ix, i4: Ix, i5: Ix) -> Ix6 { Dim::new([i0, i1, i2, i3, i4, i5]) } /// Create a dynamic-dimensional index #[allow(non_snake_case)] #[inline(always)] -pub fn IxDyn(ix: &[Ix]) -> IxDyn -{ +pub fn IxDyn(ix: &[Ix]) -> IxDyn { Dim(ix) } diff --git a/src/argument_traits.rs b/src/argument_traits.rs index c4e85186..7a4694c4 100644 --- a/src/argument_traits.rs +++ b/src/argument_traits.rs @@ -4,45 +4,36 @@ use std::mem::MaybeUninit; use crate::math_cell::MathCell; /// A producer element that can be assigned to once -pub trait AssignElem -{ +pub trait AssignElem { /// Assign the value `input` to the element that self represents. fn assign_elem(self, input: T); } /// Assignable element, simply `*self = input`. -impl AssignElem for &mut T -{ - fn assign_elem(self, input: T) - { +impl AssignElem for &mut T { + fn assign_elem(self, input: T) { *self = input; } } /// Assignable element, simply `self.set(input)`. 
-impl AssignElem for &Cell -{ - fn assign_elem(self, input: T) - { +impl AssignElem for &Cell { + fn assign_elem(self, input: T) { self.set(input); } } /// Assignable element, simply `self.set(input)`. -impl AssignElem for &MathCell -{ - fn assign_elem(self, input: T) - { +impl AssignElem for &MathCell { + fn assign_elem(self, input: T) { self.set(input); } } /// Assignable element, the item in the MaybeUninit is overwritten (prior value, if any, is not /// read or dropped). -impl AssignElem for &mut MaybeUninit -{ - fn assign_elem(self, input: T) - { +impl AssignElem for &mut MaybeUninit { + fn assign_elem(self, input: T) { *self = MaybeUninit::new(input); } } diff --git a/src/array_approx.rs b/src/array_approx.rs index 93d1bd0a..78d41966 100644 --- a/src/array_approx.rs +++ b/src/array_approx.rs @@ -1,10 +1,8 @@ #[cfg(feature = "approx")] -mod approx_methods -{ +mod approx_methods { use crate::imp_prelude::*; - impl ArrayRef - { + impl ArrayRef { /// A test for equality that uses the elementwise absolute difference to compute the /// approximate equality of two arrays. pub fn abs_diff_eq(&self, other: &ArrayRef, epsilon: A::Epsilon) -> bool @@ -89,19 +87,14 @@ macro_rules! impl_approx_traits { A::default_max_relative() } - fn relative_eq( - &self, - other: &ArrayRef, - epsilon: A::Epsilon, - max_relative: A::Epsilon, - ) -> bool { + fn relative_eq(&self, other: &ArrayRef, epsilon: A::Epsilon, max_relative: A::Epsilon) -> bool { if self.shape() != other.shape() { return false; } - Zip::from(self).and(other).all(move |a, b| { - A::relative_eq(a, b, epsilon.clone(), max_relative.clone()) - }) + Zip::from(self) + .and(other) + .all(move |a, b| A::relative_eq(a, b, epsilon.clone(), max_relative.clone())) } } @@ -118,12 +111,7 @@ macro_rules! impl_approx_traits { A::default_max_relative() } - fn relative_eq( - &self, - other: &ArrayBase, - epsilon: A::Epsilon, - max_relative: A::Epsilon, - ) -> bool { + fn relative_eq(&self, other: &ArrayBase, epsilon: A::Epsilon, max_relative: A::Epsilon) -> bool { (**self).relative_eq(other, epsilon, max_relative) } } @@ -139,12 +127,7 @@ macro_rules! impl_approx_traits { A::default_max_ulps() } - fn ulps_eq( - &self, - other: &ArrayRef, - epsilon: A::Epsilon, - max_ulps: u32, - ) -> bool { + fn ulps_eq(&self, other: &ArrayRef, epsilon: A::Epsilon, max_ulps: u32) -> bool { if self.shape() != other.shape() { return false; } @@ -168,12 +151,7 @@ macro_rules! impl_approx_traits { A::default_max_ulps() } - fn ulps_eq( - &self, - other: &ArrayBase, - epsilon: A::Epsilon, - max_ulps: u32, - ) -> bool { + fn ulps_eq(&self, other: &ArrayBase, epsilon: A::Epsilon, max_ulps: u32) -> bool { (**self).ulps_eq(other, epsilon, max_ulps) } } @@ -183,8 +161,8 @@ macro_rules! impl_approx_traits { use crate::prelude::*; use alloc::vec; use $approx::{ - assert_abs_diff_eq, assert_abs_diff_ne, assert_relative_eq, assert_relative_ne, - assert_ulps_eq, assert_ulps_ne, + assert_abs_diff_eq, assert_abs_diff_ne, assert_relative_eq, assert_relative_ne, assert_ulps_eq, + assert_ulps_ne, }; #[test] diff --git a/src/array_serde.rs b/src/array_serde.rs index 5d51a801..9f46d3d0 100644 --- a/src/array_serde.rs +++ b/src/array_serde.rs @@ -24,7 +24,8 @@ use crate::IntoDimension; /// Verifies that the version of the deserialized array matches the current /// `ARRAY_FORMAT_VERSION`. 
pub fn verify_version(v: u8) -> Result<(), E> -where E: de::Error +where + E: de::Error, { if v != ARRAY_FORMAT_VERSION { let err_msg = format!("unknown array version: {}", v); @@ -36,10 +37,12 @@ where E: de::Error /// **Requires crate feature `"serde"`** impl Serialize for Dim -where I: Serialize +where + I: Serialize, { fn serialize(&self, serializer: Se) -> Result - where Se: Serializer + where + Se: Serializer, { self.ix().serialize(serializer) } @@ -47,30 +50,32 @@ where I: Serialize /// **Requires crate feature `"serde"`** impl<'de, I> Deserialize<'de> for Dim -where I: Deserialize<'de> +where + I: Deserialize<'de>, { fn deserialize(deserializer: D) -> Result - where D: Deserializer<'de> + where + D: Deserializer<'de>, { I::deserialize(deserializer).map(Dim::new) } } /// **Requires crate feature `"serde"`** -impl Serialize for IxDyn -{ +impl Serialize for IxDyn { fn serialize(&self, serializer: Se) -> Result - where Se: Serializer + where + Se: Serializer, { self.ix().serialize(serializer) } } /// **Requires crate feature `"serde"`** -impl<'de> Deserialize<'de> for IxDyn -{ +impl<'de> Deserialize<'de> for IxDyn { fn deserialize(deserializer: D) -> Result - where D: Deserializer<'de> + where + D: Deserializer<'de>, { let v = Vec::::deserialize(deserializer)?; Ok(v.into_dimension()) @@ -85,7 +90,8 @@ where S: Data, { fn serialize(&self, serializer: Se) -> Result - where Se: Serializer + where + Se: Serializer, { let mut state = serializer.serialize_struct("Array", 3)?; state.serialize_field("v", &ARRAY_FORMAT_VERSION)?; @@ -104,7 +110,8 @@ where D: Dimension + Serialize, { fn serialize(&self, serializer: S) -> Result - where S: Serializer + where + S: Serializer, { let iter = &self.0; let mut seq = serializer.serialize_seq(Some(iter.len()))?; @@ -115,23 +122,19 @@ where } } -struct ArrayVisitor -{ +struct ArrayVisitor { _marker_a: PhantomData, _marker_b: PhantomData, } -enum ArrayField -{ +enum ArrayField { Version, Dim, Data, } -impl ArrayVisitor -{ - pub fn new() -> Self - { +impl ArrayVisitor { + pub fn new() -> Self { ArrayVisitor { _marker_a: PhantomData, _marker_b: PhantomData, @@ -149,30 +152,30 @@ where S: DataOwned, { fn deserialize(deserializer: D) -> Result, D::Error> - where D: Deserializer<'de> + where + D: Deserializer<'de>, { deserializer.deserialize_struct("Array", ARRAY_FIELDS, ArrayVisitor::new()) } } -impl<'de> Deserialize<'de> for ArrayField -{ +impl<'de> Deserialize<'de> for ArrayField { fn deserialize(deserializer: D) -> Result - where D: Deserializer<'de> + where + D: Deserializer<'de>, { struct ArrayFieldVisitor; - impl Visitor<'_> for ArrayFieldVisitor - { + impl Visitor<'_> for ArrayFieldVisitor { type Value = ArrayField; - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result - { + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { formatter.write_str(r#""v", "dim", or "data""#) } fn visit_str(self, value: &str) -> Result - where E: de::Error + where + E: de::Error, { match value { "v" => Ok(ArrayField::Version), @@ -183,7 +186,8 @@ impl<'de> Deserialize<'de> for ArrayField } fn visit_bytes(self, value: &[u8]) -> Result - where E: de::Error + where + E: de::Error, { match value { b"v" => Ok(ArrayField::Version), @@ -206,13 +210,13 @@ where { type Value = ArrayBase; - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result - { + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { formatter.write_str("ndarray representation") } fn visit_seq(self, mut visitor: V) -> Result, V::Error> - 
where V: SeqAccess<'de> + where + V: SeqAccess<'de>, { let v: u8 = match visitor.next_element()? { Some(value) => value, @@ -245,7 +249,8 @@ where } fn visit_map(self, mut visitor: V) -> Result, V::Error> - where V: MapAccess<'de> + where + V: MapAccess<'de>, { let mut v: Option = None; let mut data: Option> = None; diff --git a/src/arrayformat.rs b/src/arrayformat.rs index 7e5e1b1c..169e49b9 100644 --- a/src/arrayformat.rs +++ b/src/arrayformat.rs @@ -32,17 +32,14 @@ const AXIS_2D_OVERFLOW_LIMIT: usize = 22; const ELLIPSIS: &str = "..."; #[derive(Clone, Debug)] -struct FormatOptions -{ +struct FormatOptions { axis_collapse_limit: usize, axis_collapse_limit_next_last: usize, axis_collapse_limit_last: usize, } -impl FormatOptions -{ - pub(crate) fn default_for_array(nelem: usize, no_limit: bool) -> Self - { +impl FormatOptions { + pub(crate) fn default_for_array(nelem: usize, no_limit: bool) -> Self { let default = Self { axis_collapse_limit: AXIS_LIMIT_STACKED, axis_collapse_limit_next_last: AXIS_LIMIT_COL, @@ -51,8 +48,7 @@ impl FormatOptions default.set_no_limit(no_limit || nelem < ARRAY_MANY_ELEMENT_LIMIT) } - fn set_no_limit(mut self, no_limit: bool) -> Self - { + fn set_no_limit(mut self, no_limit: bool) -> Self { if no_limit { self.axis_collapse_limit = usize::MAX; self.axis_collapse_limit_next_last = usize::MAX; @@ -63,8 +59,7 @@ impl FormatOptions /// Axis length collapse limit before ellipsizing, where `axis_rindex` is /// the index of the axis from the back. - pub(crate) fn collapse_limit(&self, axis_rindex: usize) -> usize - { + pub(crate) fn collapse_limit(&self, axis_rindex: usize) -> usize { match axis_rindex { 0 => self.axis_collapse_limit_last, 1 => self.axis_collapse_limit_next_last, @@ -88,8 +83,7 @@ impl FormatOptions fn format_with_overflow( f: &mut fmt::Formatter<'_>, length: usize, limit: usize, separator: &str, ellipsis: &str, fmt_elem: &mut dyn FnMut(&mut fmt::Formatter, usize) -> fmt::Result, -) -> fmt::Result -{ +) -> fmt::Result { if length == 0 { // no-op } else if length <= limit { @@ -175,10 +169,10 @@ where /// /// The array is shown in multiline style. impl fmt::Display for ArrayBase -where S: Data +where + S: Data, { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result - { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { (**self).fmt(f) } } @@ -187,10 +181,8 @@ where S: Data /// used to each element. /// /// The array is shown in multiline style. -impl fmt::Display for ArrayRef -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result - { +impl fmt::Display for ArrayRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let fmt_opt = FormatOptions::default_for_array(self.len(), f.alternate()); format_array(self, f, <_>::fmt, &fmt_opt) } @@ -201,10 +193,10 @@ impl fmt::Display for ArrayRef /// /// The array is shown in multiline style. impl fmt::Debug for ArrayBase -where S: Data +where + S: Data, { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result - { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { (**self).fmt(f) } } @@ -213,21 +205,13 @@ where S: Data /// to each element. /// /// The array is shown in multiline style. 
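// Round-trip sketch for the serde support above, before the formatting impls
// continue. Assumes the `serde` feature is enabled and `serde_json` is
// available as a dev-dependency (neither is asserted by this diff); the field
// names come from the visitor above ("v", "dim", "data").
use crate::{array, Array2};

fn serde_roundtrip_sketch() -> serde_json::Result<()> {
    let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
    let text = serde_json::to_string(&a)?;
    // Expected to look roughly like: {"v":1,"dim":[2,2],"data":[1.0,2.0,3.0,4.0]}
    let b: Array2<f64> = serde_json::from_str(&text)?;
    assert_eq!(a, b);
    Ok(())
}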
-impl fmt::Debug for ArrayRef -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result - { +impl fmt::Debug for ArrayRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let fmt_opt = FormatOptions::default_for_array(self.len(), f.alternate()); format_array(self, f, <_>::fmt, &fmt_opt)?; // Add extra information for Debug - write!( - f, - ", shape={:?}, strides={:?}, layout={:?}", - self.shape(), - self.strides(), - self.view().layout(), - )?; + write!(f, ", shape={:?}, strides={:?}, layout={:?}", self.shape(), self.strides(), self.view().layout(),)?; match D::NDIM { Some(ndim) => write!(f, ", const ndim={}", ndim)?, None => write!(f, ", dynamic ndim={}", self.ndim())?, @@ -241,10 +225,10 @@ impl fmt::Debug for ArrayRef /// /// The array is shown in multiline style. impl fmt::LowerExp for ArrayBase -where S: Data +where + S: Data, { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result - { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { (**self).fmt(f) } } @@ -253,10 +237,8 @@ where S: Data /// to each element. /// /// The array is shown in multiline style. -impl fmt::LowerExp for ArrayRef -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result - { +impl fmt::LowerExp for ArrayRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let fmt_opt = FormatOptions::default_for_array(self.len(), f.alternate()); format_array(self, f, <_>::fmt, &fmt_opt) } @@ -267,10 +249,10 @@ impl fmt::LowerExp for ArrayRef /// /// The array is shown in multiline style. impl fmt::UpperExp for ArrayBase -where S: Data +where + S: Data, { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result - { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { (**self).fmt(f) } } @@ -279,10 +261,8 @@ where S: Data /// to each element. /// /// The array is shown in multiline style. -impl fmt::UpperExp for ArrayRef -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result - { +impl fmt::UpperExp for ArrayRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let fmt_opt = FormatOptions::default_for_array(self.len(), f.alternate()); format_array(self, f, <_>::fmt, &fmt_opt) } @@ -293,10 +273,10 @@ impl fmt::UpperExp for ArrayRef /// /// The array is shown in multiline style. impl fmt::LowerHex for ArrayBase -where S: Data +where + S: Data, { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result - { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { (**self).fmt(f) } } @@ -305,10 +285,8 @@ where S: Data /// to each element. /// /// The array is shown in multiline style. -impl fmt::LowerHex for ArrayRef -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result - { +impl fmt::LowerHex for ArrayRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let fmt_opt = FormatOptions::default_for_array(self.len(), f.alternate()); format_array(self, f, <_>::fmt, &fmt_opt) } @@ -319,10 +297,10 @@ impl fmt::LowerHex for ArrayRef /// /// The array is shown in multiline style. impl fmt::Binary for ArrayBase -where S: Data +where + S: Data, { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result - { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { (**self).fmt(f) } } @@ -331,18 +309,15 @@ where S: Data /// to each element. /// /// The array is shown in multiline style. 
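// Minimal sketch of the formatting behaviour above: `{}` elides long axes with
// "...", the alternate flag `{:#}` disables the element limit, and `{:?}`
// appends the shape/stride/layout metadata written by the Debug impl. Only the
// structure of the output is asserted, since the collapse limits are internal.
use crate::{array, Array1};

fn display_debug_sketch() {
    let long = Array1::from_elem(10_000, 1);
    assert!(format!("{}", long).contains("..."));    // ellipsized
    assert!(!format!("{:#}", long).contains("...")); // full print

    let dbg = format!("{:?}", array![1, 2, 3]);
    assert!(dbg.contains("shape=[3]"));
    assert!(dbg.contains("const ndim=1"));
}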
-impl fmt::Binary for ArrayRef -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result - { +impl fmt::Binary for ArrayRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let fmt_opt = FormatOptions::default_for_array(self.len(), f.alternate()); format_array(self, f, <_>::fmt, &fmt_opt) } } #[cfg(test)] -mod formatting_with_omit -{ +mod formatting_with_omit { #[cfg(not(feature = "std"))] use alloc::string::String; #[cfg(not(feature = "std"))] @@ -352,19 +327,12 @@ mod formatting_with_omit use super::*; use crate::prelude::*; - fn assert_str_eq(expected: &str, actual: &str) - { + fn assert_str_eq(expected: &str, actual: &str) { // use assert to avoid printing the strings twice on failure - assert!( - expected == actual, - "formatting assertion failed\nexpected:\n{}\nactual:\n{}\n", - expected, - actual, - ); + assert!(expected == actual, "formatting assertion failed\nexpected:\n{}\nactual:\n{}\n", expected, actual,); } - fn ellipsize(limit: usize, sep: &str, elements: impl IntoIterator) -> String - { + fn ellipsize(limit: usize, sep: &str, elements: impl IntoIterator) -> String { let elements = elements.into_iter().collect::>(); let edge = limit / 2; if elements.len() <= limit { @@ -382,8 +350,7 @@ mod formatting_with_omit } #[test] - fn empty_arrays() - { + fn empty_arrays() { let a: Array2 = arr2(&[[], []]); let actual = format!("{}", a); let expected = "[[]]"; @@ -391,8 +358,7 @@ mod formatting_with_omit } #[test] - fn zero_length_axes() - { + fn zero_length_axes() { let a = Array3::::zeros((3, 0, 4)); let actual = format!("{}", a); let expected = "[[[]]]"; @@ -400,8 +366,7 @@ mod formatting_with_omit } #[test] - fn dim_0() - { + fn dim_0() { let element = 12; let a = arr0(element); let actual = format!("{}", a); @@ -410,8 +375,7 @@ mod formatting_with_omit } #[test] - fn dim_1() - { + fn dim_1() { let overflow: usize = 2; let a = Array1::from_elem(ARRAY_MANY_ELEMENT_LIMIT + overflow, 1); let actual = format!("{}", a); @@ -420,8 +384,7 @@ mod formatting_with_omit } #[test] - fn dim_1_alternate() - { + fn dim_1_alternate() { let overflow: usize = 2; let a = Array1::from_elem(ARRAY_MANY_ELEMENT_LIMIT + overflow, 1); let actual = format!("{:#}", a); @@ -430,8 +393,7 @@ mod formatting_with_omit } #[test] - fn dim_2_last_axis_overflow() - { + fn dim_2_last_axis_overflow() { let overflow: usize = 2; let a = Array2::from_elem((AXIS_2D_OVERFLOW_LIMIT, AXIS_2D_OVERFLOW_LIMIT + overflow), 1); let actual = format!("{}", a); @@ -451,21 +413,16 @@ mod formatting_with_omit } #[test] - fn dim_2_non_last_axis_overflow() - { + fn dim_2_non_last_axis_overflow() { let a = Array2::from_elem((ARRAY_MANY_ELEMENT_LIMIT / 10, 10), 1); let actual = format!("{}", a); let row = format!("{}", a.row(0)); - let expected = format!( - "[{}]", - ellipsize(AXIS_LIMIT_COL, ",\n ", (0..a.nrows()).map(|_| &row)) - ); + let expected = format!("[{}]", ellipsize(AXIS_LIMIT_COL, ",\n ", (0..a.nrows()).map(|_| &row))); assert_str_eq(&expected, &actual); } #[test] - fn dim_2_non_last_axis_overflow_alternate() - { + fn dim_2_non_last_axis_overflow_alternate() { let a = Array2::from_elem((AXIS_LIMIT_COL * 4, 6), 1); let actual = format!("{:#}", a); let row = format!("{}", a.row(0)); @@ -474,22 +431,17 @@ mod formatting_with_omit } #[test] - fn dim_2_multi_directional_overflow() - { + fn dim_2_multi_directional_overflow() { let overflow: usize = 2; let a = Array2::from_elem((AXIS_2D_OVERFLOW_LIMIT + overflow, AXIS_2D_OVERFLOW_LIMIT + overflow), 1); let actual = format!("{}", a); let row = format!("[{}]", 
ellipsize(AXIS_LIMIT_ROW, ", ", a.row(0))); - let expected = format!( - "[{}]", - ellipsize(AXIS_LIMIT_COL, ",\n ", (0..a.nrows()).map(|_| &row)) - ); + let expected = format!("[{}]", ellipsize(AXIS_LIMIT_COL, ",\n ", (0..a.nrows()).map(|_| &row))); assert_str_eq(&expected, &actual); } #[test] - fn dim_2_multi_directional_overflow_alternate() - { + fn dim_2_multi_directional_overflow_alternate() { let overflow: usize = 2; let a = Array2::from_elem((AXIS_2D_OVERFLOW_LIMIT + overflow, AXIS_2D_OVERFLOW_LIMIT + overflow), 1); let actual = format!("{:#}", a); @@ -499,8 +451,7 @@ mod formatting_with_omit } #[test] - fn dim_3_overflow_most() - { + fn dim_3_overflow_most() { let a = Array3::from_shape_fn((AXIS_LIMIT_STACKED + 1, AXIS_LIMIT_COL, AXIS_LIMIT_ROW + 1), |(i, j, k)| { 1000. + (100. * ((i as f64).sqrt() + (j as f64).sin() + k as f64)).round() / 100. }); @@ -583,8 +534,7 @@ mod formatting_with_omit } #[test] - fn dim_4_overflow_outer() - { + fn dim_4_overflow_outer() { let a = Array4::from_shape_fn((10, 10, 3, 3), |(i, j, k, l)| i + j + k + l); let actual = format!("{:2}", a); // Generated using NumPy with: diff --git a/src/arraytraits.rs b/src/arraytraits.rs index 8b214ac9..f594f8ed 100644 --- a/src/arraytraits.rs +++ b/src/arraytraits.rs @@ -22,17 +22,12 @@ use crate::Arc; use crate::{ dimension, iter::{Iter, IterMut}, - numeric_util, - FoldWhile, - NdIndex, - OwnedArcRepr, - Zip, + numeric_util, FoldWhile, NdIndex, OwnedArcRepr, Zip, }; #[cold] #[inline(never)] -pub(crate) fn array_out_of_bounds() -> ! -{ +pub(crate) fn array_out_of_bounds() -> ! { panic!("ndarray: index out of bounds"); } @@ -58,8 +53,7 @@ where type Output = A; #[inline] - fn index(&self, index: I) -> &Self::Output - { + fn index(&self, index: I) -> &Self::Output { debug_bounds_check_ref!(self, index); unsafe { &*self._ptr().as_ptr().offset( @@ -80,8 +74,7 @@ where I: NdIndex, { #[inline] - fn index_mut(&mut self, index: I) -> &mut A - { + fn index_mut(&mut self, index: I) -> &mut A { debug_bounds_check_ref!(self, index); unsafe { &mut *self.as_mut_ptr().offset( @@ -105,8 +98,7 @@ where type Output = S::Elem; #[inline] - fn index(&self, index: I) -> &S::Elem - { + fn index(&self, index: I) -> &S::Elem { Index::index(&**self, index) } } @@ -121,8 +113,7 @@ where S: DataMut, { #[inline] - fn index_mut(&mut self, index: I) -> &mut S::Elem - { + fn index_mut(&mut self, index: I) -> &mut S::Elem { IndexMut::index_mut(&mut (**self), index) } } @@ -134,8 +125,7 @@ where A: PartialEq, D: Dimension, { - fn eq(&self, rhs: &ArrayRef) -> bool - { + fn eq(&self, rhs: &ArrayRef) -> bool { if self.shape() != rhs.shape() { return false; } @@ -164,8 +154,7 @@ where A: PartialEq, D: Dimension, { - fn eq(&self, rhs: &&ArrayRef) -> bool - { + fn eq(&self, rhs: &&ArrayRef) -> bool { *self == **rhs } } @@ -177,8 +166,7 @@ where A: PartialEq, D: Dimension, { - fn eq(&self, rhs: &ArrayRef) -> bool - { + fn eq(&self, rhs: &ArrayRef) -> bool { **self == *rhs } } @@ -199,8 +187,7 @@ where S2: Data, D: Dimension, { - fn eq(&self, rhs: &ArrayBase) -> bool - { + fn eq(&self, rhs: &ArrayBase) -> bool { PartialEq::eq(&**self, &**rhs) } } @@ -214,8 +201,7 @@ where S2: Data, D: Dimension, { - fn eq(&self, rhs: &&ArrayBase) -> bool - { + fn eq(&self, rhs: &&ArrayBase) -> bool { *self == **rhs } } @@ -229,8 +215,7 @@ where S2: Data, D: Dimension, { - fn eq(&self, rhs: &ArrayBase) -> bool - { + fn eq(&self, rhs: &ArrayBase) -> bool { **self == *rhs } } @@ -248,8 +233,7 @@ where A: PartialEq, D: Dimension, { - fn eq(&self, other: &ArrayRef) -> bool - { + 
fn eq(&self, other: &ArrayRef) -> bool { **self == other } } @@ -260,8 +244,7 @@ where A: PartialEq, D: Dimension, { - fn eq(&self, other: &&ArrayRef) -> bool - { + fn eq(&self, other: &&ArrayRef) -> bool { **self == *other } } @@ -272,8 +255,7 @@ where A: PartialEq, D: Dimension, { - fn eq(&self, other: &ArrayRef) -> bool - { + fn eq(&self, other: &ArrayRef) -> bool { **self == other } } @@ -284,8 +266,7 @@ where A: PartialEq, D: Dimension, { - fn eq(&self, other: &ArrayBase) -> bool - { + fn eq(&self, other: &ArrayBase) -> bool { self == **other } } @@ -296,8 +277,7 @@ where A: PartialEq, D: Dimension, { - fn eq(&self, other: &&ArrayBase) -> bool - { + fn eq(&self, other: &&ArrayBase) -> bool { self == ***other } } @@ -308,26 +288,26 @@ where A: PartialEq, D: Dimension, { - fn eq(&self, other: &ArrayBase) -> bool - { + fn eq(&self, other: &ArrayBase) -> bool { *self == **other } } impl From> for ArrayBase -where S: DataOwned +where + S: DataOwned, { /// Create a one-dimensional array from a boxed slice (no copying needed). /// /// **Panics** if the length is greater than `isize::MAX`. - fn from(b: Box<[A]>) -> Self - { + fn from(b: Box<[A]>) -> Self { Self::from_vec(b.into_vec()) } } impl From> for ArrayBase -where S: DataOwned +where + S: DataOwned, { /// Create a one-dimensional array from a vector (no copying needed). /// @@ -338,14 +318,14 @@ where S: DataOwned /// /// let array = Array::from(vec![1., 2., 3., 4.]); /// ``` - fn from(v: Vec
) -> Self - { + fn from(v: Vec) -> Self { Self::from_vec(v) } } impl FromIterator for ArrayBase -where S: DataOwned +where + S: DataOwned, { /// Create a one-dimensional array from an iterable. /// @@ -359,34 +339,35 @@ where S: DataOwned /// assert!(array == arr1(&[0, 1, 4, 9, 16])) /// ``` fn from_iter(iterable: I) -> ArrayBase - where I: IntoIterator + where + I: IntoIterator, { Self::from_iter(iterable) } } impl<'a, A, D> IntoIterator for &'a ArrayRef -where D: Dimension +where + D: Dimension, { type Item = &'a A; type IntoIter = Iter<'a, A, D>; - fn into_iter(self) -> Self::IntoIter - { + fn into_iter(self) -> Self::IntoIter { self.iter() } } impl<'a, A, D> IntoIterator for &'a mut ArrayRef -where D: Dimension +where + D: Dimension, { type Item = &'a mut A; type IntoIter = IterMut<'a, A, D>; - fn into_iter(self) -> Self::IntoIter - { + fn into_iter(self) -> Self::IntoIter { self.iter_mut() } } @@ -399,8 +380,7 @@ where type Item = &'a S::Elem; type IntoIter = Iter<'a, S::Elem, D>; - fn into_iter(self) -> Self::IntoIter - { + fn into_iter(self) -> Self::IntoIter { self.iter() } } @@ -413,32 +393,31 @@ where type Item = &'a mut S::Elem; type IntoIter = IterMut<'a, S::Elem, D>; - fn into_iter(self) -> Self::IntoIter - { + fn into_iter(self) -> Self::IntoIter { self.iter_mut() } } impl<'a, A, D> IntoIterator for ArrayView<'a, A, D> -where D: Dimension +where + D: Dimension, { type Item = &'a A; type IntoIter = Iter<'a, A, D>; - fn into_iter(self) -> Self::IntoIter - { + fn into_iter(self) -> Self::IntoIter { Iter::new(self) } } impl<'a, A, D> IntoIterator for ArrayViewMut<'a, A, D> -where D: Dimension +where + D: Dimension, { type Item = &'a mut A; type IntoIter = IterMut<'a, A, D>; - fn into_iter(self) -> Self::IntoIter - { + fn into_iter(self) -> Self::IntoIter { IterMut::new(self) } } @@ -449,8 +428,7 @@ where A: hash::Hash, { // Note: elements are hashed in the logical order - fn hash(&self, state: &mut H) - { + fn hash(&self, state: &mut H) { self.shape().hash(state); if let Some(self_s) = self.as_slice() { hash::Hash::hash_slice(self_s, state); @@ -475,8 +453,7 @@ where S::Elem: hash::Hash, { // Note: elements are hashed in the logical order - fn hash(&self, state: &mut H) - { + fn hash(&self, state: &mut H) { (**self).hash(state) } } @@ -517,13 +494,13 @@ pub const ARRAY_FORMAT_VERSION: u8 = 1u8; /// occur if `A` is zero-sized, because slices cannot contain more than /// `isize::MAX` number of bytes.) impl<'a, A, Slice: ?Sized> From<&'a Slice> for ArrayView<'a, A, Ix1> -where Slice: AsRef<[A]> +where + Slice: AsRef<[A]>, { /// Create a one-dimensional read-only array view of the data in `slice`. /// /// **Panics** if the slice length is greater than `isize::MAX`. - fn from(slice: &'a Slice) -> Self - { + fn from(slice: &'a Slice) -> Self { aview1(slice.as_ref()) } } @@ -533,11 +510,9 @@ where Slice: AsRef<[A]> /// **Panics** if the product of non-zero axis lengths overflows `isize` (This can only occur if A /// is zero-sized because slices cannot contain more than `isize::MAX` number of bytes). /// **Panics** if N == 0 and the number of rows is greater than isize::MAX. 
-impl<'a, A, const M: usize, const N: usize> From<&'a [[A; N]; M]> for ArrayView<'a, A, Ix2> -{ +impl<'a, A, const M: usize, const N: usize> From<&'a [[A; N]; M]> for ArrayView<'a, A, Ix2> { /// Create a two-dimensional read-only array view of the data in `slice` - fn from(xs: &'a [[A; N]; M]) -> Self - { + fn from(xs: &'a [[A; N]; M]) -> Self { Self::from(&xs[..]) } } @@ -547,11 +522,9 @@ impl<'a, A, const M: usize, const N: usize> From<&'a [[A; N]; M]> for ArrayView< /// **Panics** if the product of non-zero axis lengths overflows `isize`. (This /// can only occur if A is zero-sized or if `N` is zero, because slices cannot /// contain more than `isize::MAX` number of bytes.) -impl<'a, A, const N: usize> From<&'a [[A; N]]> for ArrayView<'a, A, Ix2> -{ +impl<'a, A, const N: usize> From<&'a [[A; N]]> for ArrayView<'a, A, Ix2> { /// Create a two-dimensional read-only array view of the data in `slice` - fn from(xs: &'a [[A; N]]) -> Self - { + fn from(xs: &'a [[A; N]]) -> Self { aview2(xs) } } @@ -563,27 +536,23 @@ where D: Dimension, { /// Create a read-only array view of the array. - fn from(array: &'a ArrayBase) -> Self - { + fn from(array: &'a ArrayBase) -> Self { array.view() } } /// Implementation of `ArrayViewMut::from(&mut S)` where `S` is a slice or sliceable. impl<'a, A, Slice: ?Sized> From<&'a mut Slice> for ArrayViewMut<'a, A, Ix1> -where Slice: AsMut<[A]> +where + Slice: AsMut<[A]>, { /// Create a one-dimensional read-write array view of the data in `slice`. /// /// **Panics** if the slice length is greater than `isize::MAX`. - fn from(slice: &'a mut Slice) -> Self - { + fn from(slice: &'a mut Slice) -> Self { let xs = slice.as_mut(); if mem::size_of::() == 0 { - assert!( - xs.len() <= isize::MAX as usize, - "Slice length must fit in `isize`.", - ); + assert!(xs.len() <= isize::MAX as usize, "Slice length must fit in `isize`.",); } unsafe { Self::from_shape_ptr(xs.len(), xs.as_mut_ptr()) } } @@ -594,11 +563,9 @@ where Slice: AsMut<[A]> /// **Panics** if the product of non-zero axis lengths overflows `isize` (This can only occur if A /// is zero-sized because slices cannot contain more than `isize::MAX` number of bytes). /// **Panics** if N == 0 and the number of rows is greater than isize::MAX. -impl<'a, A, const M: usize, const N: usize> From<&'a mut [[A; N]; M]> for ArrayViewMut<'a, A, Ix2> -{ +impl<'a, A, const M: usize, const N: usize> From<&'a mut [[A; N]; M]> for ArrayViewMut<'a, A, Ix2> { /// Create a two-dimensional read-write array view of the data in `slice` - fn from(xs: &'a mut [[A; N]; M]) -> Self - { + fn from(xs: &'a mut [[A; N]; M]) -> Self { Self::from(&mut xs[..]) } } @@ -608,21 +575,16 @@ impl<'a, A, const M: usize, const N: usize> From<&'a mut [[A; N]; M]> for ArrayV /// **Panics** if the product of non-zero axis lengths overflows `isize`. (This /// can only occur if `A` is zero-sized or if `N` is zero, because slices /// cannot contain more than `isize::MAX` number of bytes.) 
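// Sketch of the `From` conversions in this file: borrowing fixed-size nested
// arrays and slices as 1-D / 2-D views without copying. Uses only the impls
// shown in this hunk.
use crate::{ArrayView1, ArrayView2, ArrayViewMut2};

fn view_from_sketch() {
    let data = [[1.0_f32, 2.0], [3.0, 4.0], [5.0, 6.0]];
    let v = ArrayView2::from(&data);
    assert_eq!(v.dim(), (3, 2));
    assert_eq!(v[(2, 1)], 6.0);

    let slice = [10_i32, 20, 30];
    let one_d = ArrayView1::from(&slice[..]);
    assert_eq!(one_d.len(), 3);

    let mut buf = [[0.0_f64; 2]; 2];
    let mut mv = ArrayViewMut2::from(&mut buf);
    mv[(0, 1)] = 7.0;
    assert_eq!(buf[0][1], 7.0);
}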
-impl<'a, A, const N: usize> From<&'a mut [[A; N]]> for ArrayViewMut<'a, A, Ix2> -{ +impl<'a, A, const N: usize> From<&'a mut [[A; N]]> for ArrayViewMut<'a, A, Ix2> { /// Create a two-dimensional read-write array view of the data in `slice` - fn from(xs: &'a mut [[A; N]]) -> Self - { + fn from(xs: &'a mut [[A; N]]) -> Self { let cols = N; let rows = xs.len(); let dim = Ix2(rows, cols); if size_of::() == 0 { dimension::size_of_shape_checked(&dim).expect("Product of non-zero axis lengths must not overflow isize."); } else if N == 0 { - assert!( - xs.len() <= isize::MAX as usize, - "Product of non-zero axis lengths must not overflow isize.", - ); + assert!(xs.len() <= isize::MAX as usize, "Product of non-zero axis lengths must not overflow isize.",); } // `cols * rows` is guaranteed to fit in `isize` because we checked that it fits in @@ -641,17 +603,16 @@ where D: Dimension, { /// Create a read-write array view of the array. - fn from(array: &'a mut ArrayBase) -> Self - { + fn from(array: &'a mut ArrayBase) -> Self { array.view_mut() } } impl From> for ArcArray -where D: Dimension +where + D: Dimension, { - fn from(arr: Array) -> ArcArray - { + fn from(arr: Array) -> ArcArray { let data = OwnedArcRepr(Arc::new(arr.data)); // safe because: equivalent unmoved data, ptr and dims remain valid unsafe { ArrayBase::from_data_ptr(data, arr.parts.ptr).with_strides_dim(arr.parts.strides, arr.parts.dim) } @@ -680,7 +641,8 @@ where D: Dimension /// /// ``` pub trait AsArray<'a, A: 'a, D = Ix1>: Into> -where D: Dimension +where + D: Dimension, { } impl<'a, A: 'a, D, T> AsArray<'a, A, D> for T @@ -710,21 +672,18 @@ where { // NOTE: We can implement Default for non-zero dimensional array views by // using an empty slice, however we need a trait for nonzero Dimension. - fn default() -> Self - { + fn default() -> Self { ArrayBase::default(D::default()) } } #[cfg(test)] -mod tests -{ +mod tests { use crate::array; use alloc::vec; #[test] - fn test_eq_traits() - { + fn test_eq_traits() { let a = array![1, 2, 3]; let a_ref = &*a; let b = array![1, 2, 3]; diff --git a/src/backend/kernels_avx512.rs b/src/backend/kernels_avx512.rs index 7116a731..8b1bec34 100644 --- a/src/backend/kernels_avx512.rs +++ b/src/backend/kernels_avx512.rs @@ -251,12 +251,17 @@ pub fn nrm2_f64(x: &[f64]) -> f64 { #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] pub fn iamax_f32(x: &[f32]) -> (usize, f32) { - if x.is_empty() { return (0, 0.0); } + if x.is_empty() { + return (0, 0.0); + } let mut max_idx = 0; let mut max_val = x[0].abs(); for (i, &v) in x.iter().enumerate().skip(1) { let a = v.abs(); - if a > max_val { max_val = a; max_idx = i; } + if a > max_val { + max_val = a; + max_idx = i; + } } (max_idx, x[max_idx]) } @@ -267,12 +272,17 @@ pub fn iamax_f32(x: &[f32]) -> (usize, f32) { #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] pub fn iamax_f64(x: &[f64]) -> (usize, f64) { - if x.is_empty() { return (0, 0.0); } + if x.is_empty() { + return (0, 0.0); + } let mut max_idx = 0; let mut max_val = x[0].abs(); for (i, &v) in x.iter().enumerate().skip(1) { let a = v.abs(); - if a > max_val { max_val = a; max_idx = i; } + if a > max_val { + max_val = a; + max_idx = i; + } } (max_idx, x[max_idx]) } @@ -286,50 +296,66 @@ pub fn iamax_f64(x: &[f64]) -> (usize, f64) { /// Caller must ensure AVX-512F is available (`simd_caps().avx512f`). 
#[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] -pub fn add_f32_scalar(a: &[f32], scalar: f32) -> Vec { ew_f32_s(a, scalar, EwOp::Add) } +pub fn add_f32_scalar(a: &[f32], scalar: f32) -> Vec { + ew_f32_s(a, scalar, EwOp::Add) +} /// Elementwise `out[i] = a[i] - scalar`. /// # Safety /// Caller must ensure AVX-512F is available (`simd_caps().avx512f`). #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] -pub fn sub_f32_scalar(a: &[f32], scalar: f32) -> Vec { ew_f32_s(a, scalar, EwOp::Sub) } +pub fn sub_f32_scalar(a: &[f32], scalar: f32) -> Vec { + ew_f32_s(a, scalar, EwOp::Sub) +} /// Elementwise `out[i] = a[i] * scalar`. /// # Safety /// Caller must ensure AVX-512F is available (`simd_caps().avx512f`). #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] -pub fn mul_f32_scalar(a: &[f32], scalar: f32) -> Vec { ew_f32_s(a, scalar, EwOp::Mul) } +pub fn mul_f32_scalar(a: &[f32], scalar: f32) -> Vec { + ew_f32_s(a, scalar, EwOp::Mul) +} /// Elementwise `out[i] = a[i] / scalar`. /// # Safety /// Caller must ensure AVX-512F is available (`simd_caps().avx512f`). #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] -pub fn div_f32_scalar(a: &[f32], scalar: f32) -> Vec { ew_f32_s(a, scalar, EwOp::Div) } +pub fn div_f32_scalar(a: &[f32], scalar: f32) -> Vec { + ew_f32_s(a, scalar, EwOp::Div) +} /// Elementwise `out[i] = a[i] + b[i]` (AVX-512 F32x16 kernel). /// # Safety /// Caller must ensure AVX-512F is available (`simd_caps().avx512f`). #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] -pub fn add_f32_vec(a: &[f32], b: &[f32]) -> Vec { ew_f32_v(a, b, EwOp::Add) } +pub fn add_f32_vec(a: &[f32], b: &[f32]) -> Vec { + ew_f32_v(a, b, EwOp::Add) +} /// Elementwise `out[i] = a[i] - b[i]`. /// # Safety /// Caller must ensure AVX-512F is available (`simd_caps().avx512f`). #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] -pub fn sub_f32_vec(a: &[f32], b: &[f32]) -> Vec { ew_f32_v(a, b, EwOp::Sub) } +pub fn sub_f32_vec(a: &[f32], b: &[f32]) -> Vec { + ew_f32_v(a, b, EwOp::Sub) +} /// Elementwise `out[i] = a[i] * b[i]`. /// # Safety /// Caller must ensure AVX-512F is available (`simd_caps().avx512f`). #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] -pub fn mul_f32_vec(a: &[f32], b: &[f32]) -> Vec { ew_f32_v(a, b, EwOp::Mul) } +pub fn mul_f32_vec(a: &[f32], b: &[f32]) -> Vec { + ew_f32_v(a, b, EwOp::Mul) +} /// Elementwise `out[i] = a[i] / b[i]`. /// # Safety /// Caller must ensure AVX-512F is available (`simd_caps().avx512f`). 
#[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] -pub fn div_f32_vec(a: &[f32], b: &[f32]) -> Vec { ew_f32_v(a, b, EwOp::Div) } +pub fn div_f32_vec(a: &[f32], b: &[f32]) -> Vec { + ew_f32_v(a, b, EwOp::Div) +} // ═══════════════════════════════════════════════════════════════════ // Element-wise f64 — 8 functions (8-wide, compat types) @@ -337,34 +363,55 @@ pub fn div_f32_vec(a: &[f32], b: &[f32]) -> Vec { ew_f32_v(a, b, EwOp::Div) #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] -pub fn add_f64_scalar(a: &[f64], scalar: f64) -> Vec { ew_f64_s(a, scalar, EwOp::Add) } +pub fn add_f64_scalar(a: &[f64], scalar: f64) -> Vec { + ew_f64_s(a, scalar, EwOp::Add) +} #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] -pub fn sub_f64_scalar(a: &[f64], scalar: f64) -> Vec { ew_f64_s(a, scalar, EwOp::Sub) } +pub fn sub_f64_scalar(a: &[f64], scalar: f64) -> Vec { + ew_f64_s(a, scalar, EwOp::Sub) +} #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] -pub fn mul_f64_scalar(a: &[f64], scalar: f64) -> Vec { ew_f64_s(a, scalar, EwOp::Mul) } +pub fn mul_f64_scalar(a: &[f64], scalar: f64) -> Vec { + ew_f64_s(a, scalar, EwOp::Mul) +} #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] -pub fn div_f64_scalar(a: &[f64], scalar: f64) -> Vec { ew_f64_s(a, scalar, EwOp::Div) } +pub fn div_f64_scalar(a: &[f64], scalar: f64) -> Vec { + ew_f64_s(a, scalar, EwOp::Div) +} #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] -pub fn add_f64_vec(a: &[f64], b: &[f64]) -> Vec { ew_f64_v(a, b, EwOp::Add) } +pub fn add_f64_vec(a: &[f64], b: &[f64]) -> Vec { + ew_f64_v(a, b, EwOp::Add) +} #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] -pub fn sub_f64_vec(a: &[f64], b: &[f64]) -> Vec { ew_f64_v(a, b, EwOp::Sub) } +pub fn sub_f64_vec(a: &[f64], b: &[f64]) -> Vec { + ew_f64_v(a, b, EwOp::Sub) +} #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] -pub fn mul_f64_vec(a: &[f64], b: &[f64]) -> Vec { ew_f64_v(a, b, EwOp::Mul) } +pub fn mul_f64_vec(a: &[f64], b: &[f64]) -> Vec { + ew_f64_v(a, b, EwOp::Mul) +} #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] -pub fn div_f64_vec(a: &[f64], b: &[f64]) -> Vec { ew_f64_v(a, b, EwOp::Div) } +pub fn div_f64_vec(a: &[f64], b: &[f64]) -> Vec { + ew_f64_v(a, b, EwOp::Div) +} // ─── Element-wise helpers (compat types) ───────────────────────── #[cfg(target_arch = "x86_64")] -enum EwOp { Add, Sub, Mul, Div } +enum EwOp { + Add, + Sub, + Mul, + Div, +} #[cfg(target_arch = "x86_64")] #[inline(always)] @@ -418,8 +465,7 @@ fn ew_f32_v(a: &[f32], b: &[f32], op: EwOp) -> Vec { let mut result = vec![0.0f32; n]; let mut i = 0; while i + 16 <= n { - apply_f32(F32x16::from_slice(&a[i..]), F32x16::from_slice(&b[i..]), &op) - .copy_to_slice(&mut result[i..]); + apply_f32(F32x16::from_slice(&a[i..]), F32x16::from_slice(&b[i..]), &op).copy_to_slice(&mut result[i..]); i += 16; } while i < n { @@ -464,8 +510,7 @@ fn ew_f64_v(a: &[f64], b: &[f64], op: EwOp) -> Vec { let mut result = vec![0.0f64; n]; let mut i = 0; while i + 8 <= n { - apply_f64(F64x8::from_slice(&a[i..]), F64x8::from_slice(&b[i..]), &op) - .copy_to_slice(&mut result[i..]); + apply_f64(F64x8::from_slice(&a[i..]), F64x8::from_slice(&b[i..]), &op).copy_to_slice(&mut result[i..]); i += 8; } while i < n { @@ -518,7 +563,11 @@ fn pack_a_f32(a: &[f32], lda: usize, mc: usize, kc: usize, i_start: usize, k_sta let rem = mc - ii; for p in 0..kc { for ir in 0..SGEMM_MR { - buf[idx] = if ir < 
rem { a[(i_start + ii + ir) * lda + (k_start + p)] } else { 0.0 }; + buf[idx] = if ir < rem { + a[(i_start + ii + ir) * lda + (k_start + p)] + } else { + 0.0 + }; idx += 1; } } @@ -543,7 +592,11 @@ fn pack_b_f32(b: &[f32], ldb: usize, kc: usize, nc: usize, k_start: usize, j_sta let rem = nc - jj; for p in 0..kc { for jr in 0..SGEMM_NR { - buf[idx] = if jr < rem { b[(k_start + p) * ldb + (j_start + jj + jr)] } else { 0.0 }; + buf[idx] = if jr < rem { + b[(k_start + p) * ldb + (j_start + jj + jr)] + } else { + 0.0 + }; idx += 1; } } @@ -558,14 +611,7 @@ fn pack_b_f32(b: &[f32], ldb: usize, kc: usize, nc: usize, k_start: usize, j_sta #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] unsafe fn sgemm_ukernel_6x16( - kc: usize, - alpha: f32, - a_packed: &[f32], - b_packed: &[f32], - c: &mut [f32], - ldc: usize, - mr_eff: usize, - nr_eff: usize, + kc: usize, alpha: f32, a_packed: &[f32], b_packed: &[f32], c: &mut [f32], ldc: usize, mr_eff: usize, nr_eff: usize, ) { let mut c0 = _mm512_setzero_ps(); let mut c1 = _mm512_setzero_ps(); @@ -615,10 +661,7 @@ unsafe fn sgemm_ukernel_6x16( #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] pub fn sgemm_blocked( - m: usize, n: usize, k: usize, - alpha: f32, a: &[f32], lda: usize, - b: &[f32], ldb: usize, - c: &mut [f32], ldc: usize, + m: usize, n: usize, k: usize, alpha: f32, a: &[f32], lda: usize, b: &[f32], ldb: usize, c: &mut [f32], ldc: usize, ) { let mut a_packed = vec![0.0f32; SGEMM_MC * SGEMM_KC]; let mut b_packed = vec![0.0f32; SGEMM_KC * SGEMM_NC]; @@ -648,11 +691,14 @@ pub fn sgemm_blocked( // SAFETY: tier() verified AVX-512F, buffers sized correctly unsafe { sgemm_ukernel_6x16( - kc, alpha, + kc, + alpha, &a_packed[a_off..], &b_packed[b_off..], &mut c[(ii + ir) * ldc + (jj + jr)..], - ldc, mr_eff, nr_eff, + ldc, + mr_eff, + nr_eff, ); } jr += SGEMM_NR; @@ -687,7 +733,11 @@ fn pack_a_f64(a: &[f64], lda: usize, mc: usize, kc: usize, i_start: usize, k_sta let rem = mc - ii; for p in 0..kc { for ir in 0..DGEMM_MR { - buf[idx] = if ir < rem { a[(i_start + ii + ir) * lda + (k_start + p)] } else { 0.0 }; + buf[idx] = if ir < rem { + a[(i_start + ii + ir) * lda + (k_start + p)] + } else { + 0.0 + }; idx += 1; } } @@ -712,7 +762,11 @@ fn pack_b_f64(b: &[f64], ldb: usize, kc: usize, nc: usize, k_start: usize, j_sta let rem = nc - jj; for p in 0..kc { for jr in 0..DGEMM_NR { - buf[idx] = if jr < rem { b[(k_start + p) * ldb + (j_start + jj + jr)] } else { 0.0 }; + buf[idx] = if jr < rem { + b[(k_start + p) * ldb + (j_start + jj + jr)] + } else { + 0.0 + }; idx += 1; } } @@ -725,14 +779,7 @@ fn pack_b_f64(b: &[f64], ldb: usize, kc: usize, nc: usize, k_start: usize, j_sta #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] unsafe fn dgemm_ukernel_6x8( - kc: usize, - alpha: f64, - a_packed: &[f64], - b_packed: &[f64], - c: &mut [f64], - ldc: usize, - mr_eff: usize, - nr_eff: usize, + kc: usize, alpha: f64, a_packed: &[f64], b_packed: &[f64], c: &mut [f64], ldc: usize, mr_eff: usize, nr_eff: usize, ) { let mut c0 = _mm512_setzero_pd(); let mut c1 = _mm512_setzero_pd(); @@ -782,10 +829,7 @@ unsafe fn dgemm_ukernel_6x8( #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] pub fn dgemm_blocked( - m: usize, n: usize, k: usize, - alpha: f64, a: &[f64], lda: usize, - b: &[f64], ldb: usize, - c: &mut [f64], ldc: usize, + m: usize, n: usize, k: usize, alpha: f64, a: &[f64], lda: usize, b: &[f64], ldb: usize, c: &mut [f64], ldc: usize, ) { let mut a_packed = vec![0.0f64; DGEMM_MC * DGEMM_KC]; let mut 
b_packed = vec![0.0f64; DGEMM_KC * DGEMM_NC]; @@ -814,11 +858,14 @@ pub fn dgemm_blocked( unsafe { dgemm_ukernel_6x8( - kc, alpha, + kc, + alpha, &a_packed[a_off..], &b_packed[b_off..], &mut c[(ii + ir) * ldc + (jj + jr)..], - ldc, mr_eff, nr_eff, + ldc, + mr_eff, + nr_eff, ); } jr += DGEMM_NR; diff --git a/src/backend/mkl.rs b/src/backend/mkl.rs index f8424850..be78601e 100644 --- a/src/backend/mkl.rs +++ b/src/backend/mkl.rs @@ -36,43 +36,28 @@ extern "C" { fn cblas_sasum(n: c_int, x: *const c_float, incx: c_int) -> c_float; fn cblas_dasum(n: c_int, x: *const c_double, incx: c_int) -> c_double; fn cblas_sgemm( - layout: c_int, transa: c_int, transb: c_int, - m: c_int, n: c_int, k: c_int, - alpha: c_float, a: *const c_float, lda: c_int, - b: *const c_float, ldb: c_int, - beta: c_float, c: *mut c_float, ldc: c_int, + layout: c_int, transa: c_int, transb: c_int, m: c_int, n: c_int, k: c_int, alpha: c_float, a: *const c_float, + lda: c_int, b: *const c_float, ldb: c_int, beta: c_float, c: *mut c_float, ldc: c_int, ); fn cblas_dgemm( - layout: c_int, transa: c_int, transb: c_int, - m: c_int, n: c_int, k: c_int, - alpha: c_double, a: *const c_double, lda: c_int, - b: *const c_double, ldb: c_int, - beta: c_double, c: *mut c_double, ldc: c_int, + layout: c_int, transa: c_int, transb: c_int, m: c_int, n: c_int, k: c_int, alpha: c_double, a: *const c_double, + lda: c_int, b: *const c_double, ldb: c_int, beta: c_double, c: *mut c_double, ldc: c_int, ); fn cblas_sgemv( - layout: c_int, trans: c_int, - m: c_int, n: c_int, - alpha: c_float, a: *const c_float, lda: c_int, - x: *const c_float, incx: c_int, - beta: c_float, y: *mut c_float, incy: c_int, + layout: c_int, trans: c_int, m: c_int, n: c_int, alpha: c_float, a: *const c_float, lda: c_int, + x: *const c_float, incx: c_int, beta: c_float, y: *mut c_float, incy: c_int, ); fn cblas_dgemv( - layout: c_int, trans: c_int, - m: c_int, n: c_int, - alpha: c_double, a: *const c_double, lda: c_int, - x: *const c_double, incx: c_int, - beta: c_double, y: *mut c_double, incy: c_int, + layout: c_int, trans: c_int, m: c_int, n: c_int, alpha: c_double, a: *const c_double, lda: c_int, + x: *const c_double, incx: c_int, beta: c_double, y: *mut c_double, incy: c_int, ); // Mixed-precision GEMM: BF16 inputs, F32 accumulator. // MKL takes `*const u16` for BF16 operands (no native bf16 type in C ABI). // Reference: oneAPI MKL Developer Reference, "cblas_gemm_bf16bf16f32". fn cblas_gemm_bf16bf16f32( - layout: c_int, transa: c_int, transb: c_int, - m: c_int, n: c_int, k: c_int, - alpha: c_float, a: *const u16, lda: c_int, - b: *const u16, ldb: c_int, - beta: c_float, c: *mut c_float, ldc: c_int, + layout: c_int, transa: c_int, transb: c_int, m: c_int, n: c_int, k: c_int, alpha: c_float, a: *const u16, + lda: c_int, b: *const u16, ldb: c_int, beta: c_float, c: *mut c_float, ldc: c_int, ); // Integer GEMM: i8 × i8 → i32. @@ -81,12 +66,8 @@ extern "C" { // matmul without zero-point correction. // Reference: oneAPI MKL Developer Reference, "cblas_gemm_s8s8s32". 
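// Hedged correctness check for `sgemm_blocked` from kernels_avx512.rs above:
// compare against a naive triple loop on small dense row-major matrices.
// Assumptions not guaranteed by this hunk: `c` starts zeroed and the kernel
// accumulates `alpha * A * B` into it, and the leading dimensions equal the
// row lengths (lda = k, ldb = ldc = n). Same AVX-512 safety contract as the
// element-wise kernels: runtime check, then an `unsafe` call.
#[cfg(target_arch = "x86_64")]
fn check_sgemm_blocked(m: usize, n: usize, k: usize, a: &[f32], b: &[f32]) {
    use crate::backend::kernels_avx512::sgemm_blocked;

    if !is_x86_feature_detected!("avx512f") {
        return; // nothing to check on this machine
    }
    let mut c = vec![0.0f32; m * n];
    // SAFETY: AVX-512F availability verified above.
    unsafe { sgemm_blocked(m, n, k, 1.0, a, k, b, n, &mut c, n) };

    // Naive reference: C[i][j] = sum over p of A[i][p] * B[p][j].
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0f32;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            assert!((c[i * n + j] - acc).abs() <= 1e-4 * acc.abs().max(1.0));
        }
    }
}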
fn cblas_gemm_s8s8s32( - layout: c_int, transa: c_int, transb: c_int, offsetc: c_int, - m: c_int, n: c_int, k: c_int, - alpha: c_float, - a: *const i8, lda: c_int, oa: i8, - b: *const i8, ldb: c_int, ob: i8, - beta: c_float, c: *mut i32, ldc: c_int, + layout: c_int, transa: c_int, transb: c_int, offsetc: c_int, m: c_int, n: c_int, k: c_int, alpha: c_float, + a: *const i8, lda: c_int, oa: i8, b: *const i8, ldb: c_int, ob: i8, beta: c_float, c: *mut i32, ldc: c_int, co: *const i32, ); } @@ -98,13 +79,22 @@ extern "C" { extern "C" { pub fn LAPACKE_sgetrf(layout: c_int, m: c_int, n: c_int, a: *mut c_float, lda: c_int, ipiv: *mut c_int) -> c_int; pub fn LAPACKE_dgetrf(layout: c_int, m: c_int, n: c_int, a: *mut c_double, lda: c_int, ipiv: *mut c_int) -> c_int; - pub fn LAPACKE_sgetrs(layout: c_int, trans: u8, n: c_int, nrhs: c_int, a: *const c_float, lda: c_int, ipiv: *const c_int, b: *mut c_float, ldb: c_int) -> c_int; - pub fn LAPACKE_dgetrs(layout: c_int, trans: u8, n: c_int, nrhs: c_int, a: *const c_double, lda: c_int, ipiv: *const c_int, b: *mut c_double, ldb: c_int) -> c_int; + pub fn LAPACKE_sgetrs( + layout: c_int, trans: u8, n: c_int, nrhs: c_int, a: *const c_float, lda: c_int, ipiv: *const c_int, + b: *mut c_float, ldb: c_int, + ) -> c_int; + pub fn LAPACKE_dgetrs( + layout: c_int, trans: u8, n: c_int, nrhs: c_int, a: *const c_double, lda: c_int, ipiv: *const c_int, + b: *mut c_double, ldb: c_int, + ) -> c_int; pub fn LAPACKE_spotrf(layout: c_int, uplo: u8, n: c_int, a: *mut c_float, lda: c_int) -> c_int; pub fn LAPACKE_dpotrf(layout: c_int, uplo: u8, n: c_int, a: *mut c_double, lda: c_int) -> c_int; - pub fn LAPACKE_spotrs(layout: c_int, uplo: u8, n: c_int, nrhs: c_int, a: *const c_float, lda: c_int, b: *mut c_float, ldb: c_int) -> c_int; + pub fn LAPACKE_spotrs( + layout: c_int, uplo: u8, n: c_int, nrhs: c_int, a: *const c_float, lda: c_int, b: *mut c_float, ldb: c_int, + ) -> c_int; pub fn LAPACKE_sgeqrf(layout: c_int, m: c_int, n: c_int, a: *mut c_float, lda: c_int, tau: *mut c_float) -> c_int; - pub fn LAPACKE_dgeqrf(layout: c_int, m: c_int, n: c_int, a: *mut c_double, lda: c_int, tau: *mut c_double) -> c_int; + pub fn LAPACKE_dgeqrf(layout: c_int, m: c_int, n: c_int, a: *mut c_double, lda: c_int, tau: *mut c_double) + -> c_int; } // ═══════════════════════════════════════════════════════════════ @@ -144,7 +134,9 @@ pub const DFTI_NOT_INPLACE: c_int = 44; pub const DFTI_BACKWARD_SCALE: c_int = 5; extern "C" { - pub fn DftiCreateDescriptor(handle: *mut DftiDescriptorHandle, precision: c_int, domain: c_int, dimension: c_int, length: c_long) -> c_long; + pub fn DftiCreateDescriptor( + handle: *mut DftiDescriptorHandle, precision: c_int, domain: c_int, dimension: c_int, length: c_long, + ) -> c_long; pub fn DftiSetValue(handle: DftiDescriptorHandle, param: c_int, ...) -> c_long; pub fn DftiCommitDescriptor(handle: DftiDescriptorHandle) -> c_long; pub fn DftiComputeForward(handle: DftiDescriptorHandle, x_inout: *mut c_void, ...) 
-> c_long; @@ -201,73 +193,103 @@ pub fn asum_f64(x: &[f64]) -> f64 { } pub fn gemm_f32( - m: usize, n: usize, k: usize, - alpha: f32, a: &[f32], lda: usize, - b: &[f32], ldb: usize, - beta: f32, c: &mut [f32], ldc: usize, + m: usize, n: usize, k: usize, alpha: f32, a: &[f32], lda: usize, b: &[f32], ldb: usize, beta: f32, c: &mut [f32], + ldc: usize, ) { unsafe { cblas_sgemm( - CBLAS_ROW_MAJOR, CBLAS_NO_TRANS, CBLAS_NO_TRANS, - m as c_int, n as c_int, k as c_int, - alpha, a.as_ptr(), lda as c_int, - b.as_ptr(), ldb as c_int, - beta, c.as_mut_ptr(), ldc as c_int, + CBLAS_ROW_MAJOR, + CBLAS_NO_TRANS, + CBLAS_NO_TRANS, + m as c_int, + n as c_int, + k as c_int, + alpha, + a.as_ptr(), + lda as c_int, + b.as_ptr(), + ldb as c_int, + beta, + c.as_mut_ptr(), + ldc as c_int, ); } } pub fn gemm_f64( - m: usize, n: usize, k: usize, - alpha: f64, a: &[f64], lda: usize, - b: &[f64], ldb: usize, - beta: f64, c: &mut [f64], ldc: usize, + m: usize, n: usize, k: usize, alpha: f64, a: &[f64], lda: usize, b: &[f64], ldb: usize, beta: f64, c: &mut [f64], + ldc: usize, ) { unsafe { cblas_dgemm( - CBLAS_ROW_MAJOR, CBLAS_NO_TRANS, CBLAS_NO_TRANS, - m as c_int, n as c_int, k as c_int, - alpha, a.as_ptr(), lda as c_int, - b.as_ptr(), ldb as c_int, - beta, c.as_mut_ptr(), ldc as c_int, + CBLAS_ROW_MAJOR, + CBLAS_NO_TRANS, + CBLAS_NO_TRANS, + m as c_int, + n as c_int, + k as c_int, + alpha, + a.as_ptr(), + lda as c_int, + b.as_ptr(), + ldb as c_int, + beta, + c.as_mut_ptr(), + ldc as c_int, ); } } -pub fn gemv_f32( - m: usize, n: usize, - alpha: f32, a: &[f32], lda: usize, - x: &[f32], beta: f32, y: &mut [f32], -) { +pub fn gemv_f32(m: usize, n: usize, alpha: f32, a: &[f32], lda: usize, x: &[f32], beta: f32, y: &mut [f32]) { unsafe { cblas_sgemv( - CBLAS_ROW_MAJOR, CBLAS_NO_TRANS, - m as c_int, n as c_int, - alpha, a.as_ptr(), lda as c_int, - x.as_ptr(), 1, beta, y.as_mut_ptr(), 1, + CBLAS_ROW_MAJOR, + CBLAS_NO_TRANS, + m as c_int, + n as c_int, + alpha, + a.as_ptr(), + lda as c_int, + x.as_ptr(), + 1, + beta, + y.as_mut_ptr(), + 1, ); } } -pub fn gemv_f64( - m: usize, n: usize, - alpha: f64, a: &[f64], lda: usize, - x: &[f64], beta: f64, y: &mut [f64], -) { +pub fn gemv_f64(m: usize, n: usize, alpha: f64, a: &[f64], lda: usize, x: &[f64], beta: f64, y: &mut [f64]) { unsafe { cblas_dgemv( - CBLAS_ROW_MAJOR, CBLAS_NO_TRANS, - m as c_int, n as c_int, - alpha, a.as_ptr(), lda as c_int, - x.as_ptr(), 1, beta, y.as_mut_ptr(), 1, + CBLAS_ROW_MAJOR, + CBLAS_NO_TRANS, + m as c_int, + n as c_int, + alpha, + a.as_ptr(), + lda as c_int, + x.as_ptr(), + 1, + beta, + y.as_mut_ptr(), + 1, ); } } -pub const fn sgemm_nr() -> usize { 16 } -pub const fn sgemm_mr() -> usize { 6 } -pub const fn dgemm_nr() -> usize { 8 } -pub const fn dgemm_mr() -> usize { 6 } +pub const fn sgemm_nr() -> usize { + 16 +} +pub const fn sgemm_mr() -> usize { + 6 +} +pub const fn dgemm_nr() -> usize { + 8 +} +pub const fn dgemm_mr() -> usize { + 6 +} // ═══════════════════════════════════════════════════════════════ // Public ndarray-shaped GEMM API (Burn integration surface) @@ -306,16 +328,12 @@ pub enum MklError { impl core::fmt::Display for MklError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { - MklError::ShapeMismatch { a_shape, b_shape } => write!( - f, - "MKL GEMM shape mismatch: A is {:?}, B is {:?}", - a_shape, b_shape - ), - MklError::OutputShapeMismatch { expected, got } => write!( - f, - "MKL GEMM output shape mismatch: expected {:?}, got {:?}", - expected, got - ), + MklError::ShapeMismatch { a_shape, 
b_shape } => { + write!(f, "MKL GEMM shape mismatch: A is {:?}, B is {:?}", a_shape, b_shape) + } + MklError::OutputShapeMismatch { expected, got } => { + write!(f, "MKL GEMM output shape mismatch: expected {:?}, got {:?}", expected, got) + } MklError::NonContiguous { which } => { write!(f, "MKL GEMM operand `{}` is not stride-compatible with CBLAS", which) } @@ -343,32 +361,42 @@ fn blas_layout(view: &crate::ArrayBase) -> Opt let cs = strides[1]; // Row-major: stride between rows is the leading dim, columns are stride 1. if cs == 1 && (rs >= cols as isize || rows <= 1) { - return Some(BlasLayout { layout: CBLAS_ROW_MAJOR, trans: CBLAS_NO_TRANS, ld: rs.max(1) as c_int }); + return Some(BlasLayout { + layout: CBLAS_ROW_MAJOR, + trans: CBLAS_NO_TRANS, + ld: rs.max(1) as c_int, + }); } // Column-major: stride between cols is the leading dim, rows are stride 1. // We expose this to CBLAS as a *row-major transposed* matrix so we keep a // single `layout` argument across all three operands. if rs == 1 && (cs >= rows as isize || cols <= 1) { - return Some(BlasLayout { layout: CBLAS_ROW_MAJOR, trans: 112 /* CblasTrans */, ld: cs.max(1) as c_int }); + return Some(BlasLayout { + layout: CBLAS_ROW_MAJOR, + trans: 112, /* CblasTrans */ + ld: cs.max(1) as c_int, + }); } None } /// `C := alpha * A * B + beta * C` for `f32` matrices via MKL `cblas_sgemm`. pub fn sgemm( - a: ArrayView2, - b: ArrayView2, - mut c: ArrayViewMut2, - alpha: f32, - beta: f32, + a: ArrayView2, b: ArrayView2, mut c: ArrayViewMut2, alpha: f32, beta: f32, ) -> Result<(), MklError> { let (m, k) = a.dim(); let (kb, n) = b.dim(); if k != kb { - return Err(MklError::ShapeMismatch { a_shape: a.dim(), b_shape: b.dim() }); + return Err(MklError::ShapeMismatch { + a_shape: a.dim(), + b_shape: b.dim(), + }); } if c.dim() != (m, n) { - return Err(MklError::OutputShapeMismatch { expected: (m, n), got: c.dim() }); + return Err(MklError::OutputShapeMismatch { + expected: (m, n), + got: c.dim(), + }); } let la = blas_layout(&a).ok_or(MklError::NonContiguous { which: "a" })?; let lb = blas_layout(&b).ok_or(MklError::NonContiguous { which: "b" })?; @@ -378,11 +406,20 @@ pub fn sgemm( } unsafe { cblas_sgemm( - lc.layout, la.trans, lb.trans, - m as c_int, n as c_int, k as c_int, - alpha, a.as_ptr(), la.ld, - b.as_ptr(), lb.ld, - beta, c.as_mut_ptr(), lc.ld, + lc.layout, + la.trans, + lb.trans, + m as c_int, + n as c_int, + k as c_int, + alpha, + a.as_ptr(), + la.ld, + b.as_ptr(), + lb.ld, + beta, + c.as_mut_ptr(), + lc.ld, ); } Ok(()) @@ -390,19 +427,21 @@ pub fn sgemm( /// `C := alpha * A * B + beta * C` for `f64` matrices via MKL `cblas_dgemm`. 
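// Usage sketch for the ndarray-shaped MKL wrapper above, written as if it sat
// in a crate-internal test. Assumes the `intel-mkl` feature is enabled and MKL
// is linked. With `beta = 0.0` the output is overwritten, per the documented
// contract `C := alpha * A * B + beta * C`.
#[cfg(feature = "intel-mkl")]
fn mkl_sgemm_sketch() {
    use crate::backend::mkl;
    use crate::{array, Array2};

    let a = array![[1.0_f32, 2.0], [3.0, 4.0]];
    let b = array![[5.0_f32, 6.0], [7.0, 8.0]];
    let mut c = Array2::<f32>::zeros((2, 2));

    mkl::sgemm(a.view(), b.view(), c.view_mut(), 1.0, 0.0)
        .expect("shapes are compatible and operands are contiguous");
    assert_eq!(c, array![[19.0, 22.0], [43.0, 50.0]]);
}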
pub fn dgemm( - a: ArrayView2, - b: ArrayView2, - mut c: ArrayViewMut2, - alpha: f64, - beta: f64, + a: ArrayView2, b: ArrayView2, mut c: ArrayViewMut2, alpha: f64, beta: f64, ) -> Result<(), MklError> { let (m, k) = a.dim(); let (kb, n) = b.dim(); if k != kb { - return Err(MklError::ShapeMismatch { a_shape: a.dim(), b_shape: b.dim() }); + return Err(MklError::ShapeMismatch { + a_shape: a.dim(), + b_shape: b.dim(), + }); } if c.dim() != (m, n) { - return Err(MklError::OutputShapeMismatch { expected: (m, n), got: c.dim() }); + return Err(MklError::OutputShapeMismatch { + expected: (m, n), + got: c.dim(), + }); } let la = blas_layout(&a).ok_or(MklError::NonContiguous { which: "a" })?; let lb = blas_layout(&b).ok_or(MklError::NonContiguous { which: "b" })?; @@ -412,11 +451,20 @@ pub fn dgemm( } unsafe { cblas_dgemm( - lc.layout, la.trans, lb.trans, - m as c_int, n as c_int, k as c_int, - alpha, a.as_ptr(), la.ld, - b.as_ptr(), lb.ld, - beta, c.as_mut_ptr(), lc.ld, + lc.layout, + la.trans, + lb.trans, + m as c_int, + n as c_int, + k as c_int, + alpha, + a.as_ptr(), + la.ld, + b.as_ptr(), + lb.ld, + beta, + c.as_mut_ptr(), + lc.ld, ); } Ok(()) @@ -429,19 +477,22 @@ pub fn dgemm( /// builds the symbol is missing and linking will fail at runtime — there is /// no compile-time fallback. pub fn sgemm_bf16( - a: ArrayView2, - b: ArrayView2, - mut c: ArrayViewMut2, - alpha: f32, - beta: f32, + a: ArrayView2, b: ArrayView2, mut c: ArrayViewMut2, + alpha: f32, beta: f32, ) -> Result<(), MklError> { let (m, k) = a.dim(); let (kb, n) = b.dim(); if k != kb { - return Err(MklError::ShapeMismatch { a_shape: a.dim(), b_shape: b.dim() }); + return Err(MklError::ShapeMismatch { + a_shape: a.dim(), + b_shape: b.dim(), + }); } if c.dim() != (m, n) { - return Err(MklError::OutputShapeMismatch { expected: (m, n), got: c.dim() }); + return Err(MklError::OutputShapeMismatch { + expected: (m, n), + got: c.dim(), + }); } let la = blas_layout(&a).ok_or(MklError::NonContiguous { which: "a" })?; let lb = blas_layout(&b).ok_or(MklError::NonContiguous { which: "b" })?; @@ -452,12 +503,20 @@ pub fn sgemm_bf16( // BF16 is `#[repr(transparent)] (pub u16)`, so the pointer cast is sound. unsafe { cblas_gemm_bf16bf16f32( - lc.layout, la.trans, lb.trans, - m as c_int, n as c_int, k as c_int, + lc.layout, + la.trans, + lb.trans, + m as c_int, + n as c_int, + k as c_int, alpha, - a.as_ptr() as *const u16, la.ld, - b.as_ptr() as *const u16, lb.ld, - beta, c.as_mut_ptr(), lc.ld, + a.as_ptr() as *const u16, + la.ld, + b.as_ptr() as *const u16, + lb.ld, + beta, + c.as_mut_ptr(), + lc.ld, ); } Ok(()) @@ -469,18 +528,20 @@ pub fn sgemm_bf16( /// Note: alpha/beta are fixed at `1.0` / `0.0` for the simple `Burn`-style /// signature. If you need scaling, call the FFI directly. This requires /// Intel MKL >= 2018 (when integer GEMM was introduced). 
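// Reference semantics for the zero-point-free INT8 path documented above: with
// alpha = 1, beta = 0 and all offsets zero, the wrapper defined next should
// reduce to a plain widening matmul. This scalar sketch is ground truth a test
// could compare against; it is not the MKL call itself, and dense row-major
// layout is assumed.
fn gemm_s8s8s32_reference(a: &[i8], b: &[i8], m: usize, n: usize, k: usize) -> Vec<i32> {
    let mut c = vec![0i32; m * n];
    for i in 0..m {
        for p in 0..k {
            let ap = i32::from(a[i * k + p]); // widen once per (i, p)
            for j in 0..n {
                c[i * n + j] += ap * i32::from(b[p * n + j]);
            }
        }
    }
    c
}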
-pub fn sgemm_int8( - a: ArrayView2, - b: ArrayView2, - mut c: ArrayViewMut2, -) -> Result<(), MklError> { +pub fn sgemm_int8(a: ArrayView2, b: ArrayView2, mut c: ArrayViewMut2) -> Result<(), MklError> { let (m, k) = a.dim(); let (kb, n) = b.dim(); if k != kb { - return Err(MklError::ShapeMismatch { a_shape: a.dim(), b_shape: b.dim() }); + return Err(MklError::ShapeMismatch { + a_shape: a.dim(), + b_shape: b.dim(), + }); } if c.dim() != (m, n) { - return Err(MklError::OutputShapeMismatch { expected: (m, n), got: c.dim() }); + return Err(MklError::OutputShapeMismatch { + expected: (m, n), + got: c.dim(), + }); } let la = blas_layout(&a).ok_or(MklError::NonContiguous { which: "a" })?; let lb = blas_layout(&b).ok_or(MklError::NonContiguous { which: "b" })?; @@ -491,12 +552,23 @@ pub fn sgemm_int8( let co: i32 = 0; unsafe { cblas_gemm_s8s8s32( - lc.layout, la.trans, lb.trans, CBLAS_OFFSET_FIX, - m as c_int, n as c_int, k as c_int, + lc.layout, + la.trans, + lb.trans, + CBLAS_OFFSET_FIX, + m as c_int, + n as c_int, + k as c_int, 1.0_f32, - a.as_ptr(), la.ld, 0_i8, - b.as_ptr(), lb.ld, 0_i8, - 0.0_f32, c.as_mut_ptr(), lc.ld, + a.as_ptr(), + la.ld, + 0_i8, + b.as_ptr(), + lb.ld, + 0_i8, + 0.0_f32, + c.as_mut_ptr(), + lc.ld, &co as *const i32, ); } diff --git a/src/backend/mod.rs b/src/backend/mod.rs index 020ff0d4..df71a701 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -14,8 +14,6 @@ pub mod native; #[cfg(target_arch = "x86_64")] pub(crate) mod kernels_avx512; - - #[cfg(feature = "intel-mkl")] pub mod mkl; #[cfg(feature = "openblas")] @@ -40,38 +38,20 @@ compile_error!("Features `intel-mkl` and `openblas` are mutually exclusive. Enab #[cfg(feature = "intel-mkl")] pub use mkl::{ - dot_f32, dot_f64, - axpy_f32, axpy_f64, - scal_f32, scal_f64, - nrm2_f32, nrm2_f64, - asum_f32, asum_f64, - gemm_f32, gemm_f64, - gemv_f32, gemv_f64, - sgemm_nr, sgemm_mr, dgemm_nr, dgemm_mr, + asum_f32, asum_f64, axpy_f32, axpy_f64, dgemm_mr, dgemm_nr, dot_f32, dot_f64, gemm_f32, gemm_f64, gemv_f32, + gemv_f64, nrm2_f32, nrm2_f64, scal_f32, scal_f64, sgemm_mr, sgemm_nr, }; #[cfg(all(feature = "openblas", not(feature = "intel-mkl")))] pub use openblas::{ - dot_f32, dot_f64, - axpy_f32, axpy_f64, - scal_f32, scal_f64, - nrm2_f32, nrm2_f64, - asum_f32, asum_f64, - gemm_f32, gemm_f64, - gemv_f32, gemv_f64, - sgemm_nr, sgemm_mr, dgemm_nr, dgemm_mr, + asum_f32, asum_f64, axpy_f32, axpy_f64, dgemm_mr, dgemm_nr, dot_f32, dot_f64, gemm_f32, gemm_f64, gemv_f32, + gemv_f64, nrm2_f32, nrm2_f64, scal_f32, scal_f64, sgemm_mr, sgemm_nr, }; #[cfg(not(any(feature = "intel-mkl", feature = "openblas")))] pub use native::{ - dot_f32, dot_f64, - axpy_f32, axpy_f64, - scal_f32, scal_f64, - nrm2_f32, nrm2_f64, - asum_f32, asum_f64, - gemm_f32, gemm_f64, - gemv_f32, gemv_f64, - sgemm_nr, sgemm_mr, dgemm_nr, dgemm_mr, + asum_f32, asum_f64, axpy_f32, axpy_f64, dgemm_mr, dgemm_nr, dot_f32, dot_f64, gemm_f32, gemm_f64, gemv_f32, + gemv_f64, nrm2_f32, nrm2_f64, scal_f32, scal_f64, sgemm_mr, sgemm_nr, }; // ─── BlasFloat: type-level dispatch for generic code ────────────── @@ -93,17 +73,11 @@ pub trait BlasFloat: num_traits::Float + Default + Send + Sync + 'static { fn backend_asum(x: &[Self]) -> Self; /// GEMM using the active backend. 
fn backend_gemm( - m: usize, n: usize, k: usize, - alpha: Self, a: &[Self], lda: usize, - b: &[Self], ldb: usize, - beta: Self, c: &mut [Self], ldc: usize, + m: usize, n: usize, k: usize, alpha: Self, a: &[Self], lda: usize, b: &[Self], ldb: usize, beta: Self, + c: &mut [Self], ldc: usize, ); /// GEMV using the active backend. - fn backend_gemv( - m: usize, n: usize, - alpha: Self, a: &[Self], lda: usize, - x: &[Self], beta: Self, y: &mut [Self], - ); + fn backend_gemv(m: usize, n: usize, alpha: Self, a: &[Self], lda: usize, x: &[Self], beta: Self, y: &mut [Self]); } impl BlasFloat for f32 { @@ -123,18 +97,12 @@ impl BlasFloat for f32 { asum_f32(x) } fn backend_gemm( - m: usize, n: usize, k: usize, - alpha: Self, a: &[Self], lda: usize, - b: &[Self], ldb: usize, - beta: Self, c: &mut [Self], ldc: usize, + m: usize, n: usize, k: usize, alpha: Self, a: &[Self], lda: usize, b: &[Self], ldb: usize, beta: Self, + c: &mut [Self], ldc: usize, ) { gemm_f32(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - fn backend_gemv( - m: usize, n: usize, - alpha: Self, a: &[Self], lda: usize, - x: &[Self], beta: Self, y: &mut [Self], - ) { + fn backend_gemv(m: usize, n: usize, alpha: Self, a: &[Self], lda: usize, x: &[Self], beta: Self, y: &mut [Self]) { gemv_f32(m, n, alpha, a, lda, x, beta, y); } } @@ -156,18 +124,12 @@ impl BlasFloat for f64 { asum_f64(x) } fn backend_gemm( - m: usize, n: usize, k: usize, - alpha: Self, a: &[Self], lda: usize, - b: &[Self], ldb: usize, - beta: Self, c: &mut [Self], ldc: usize, + m: usize, n: usize, k: usize, alpha: Self, a: &[Self], lda: usize, b: &[Self], ldb: usize, beta: Self, + c: &mut [Self], ldc: usize, ) { gemm_f64(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - fn backend_gemv( - m: usize, n: usize, - alpha: Self, a: &[Self], lda: usize, - x: &[Self], beta: Self, y: &mut [Self], - ) { + fn backend_gemv(m: usize, n: usize, alpha: Self, a: &[Self], lda: usize, x: &[Self], beta: Self, y: &mut [Self]) { gemv_f64(m, n, alpha, a, lda, x, beta, y); } } @@ -185,10 +147,8 @@ impl BlasFloat for f64 { /// `cblas_sgemm` equivalent — pure Rust SIMD-dispatched f32 GEMM. #[inline] pub fn cblas_sgemm( - m: usize, n: usize, k: usize, - alpha: f32, a: &[f32], lda: usize, - b: &[f32], ldb: usize, - beta: f32, c: &mut [f32], ldc: usize, + m: usize, n: usize, k: usize, alpha: f32, a: &[f32], lda: usize, b: &[f32], ldb: usize, beta: f32, c: &mut [f32], + ldc: usize, ) { gemm_f32(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) } @@ -196,10 +156,8 @@ pub fn cblas_sgemm( /// `cblas_dgemm` equivalent — pure Rust SIMD-dispatched f64 GEMM. #[inline] pub fn cblas_dgemm( - m: usize, n: usize, k: usize, - alpha: f64, a: &[f64], lda: usize, - b: &[f64], ldb: usize, - beta: f64, c: &mut [f64], ldc: usize, + m: usize, n: usize, k: usize, alpha: f64, a: &[f64], lda: usize, b: &[f64], ldb: usize, beta: f64, c: &mut [f64], + ldc: usize, ) { gemm_f64(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) } @@ -215,10 +173,7 @@ pub fn cblas_dgemm( /// Same signature across all paths. #[inline] #[allow(clippy::needless_return)] -pub fn gemm_i8( - a: &[u8], b: &[i8], c: &mut [i32], - m: usize, n: usize, k: usize, -) { +pub fn gemm_i8(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) { // VNNI path (Ice Lake, Sapphire Rapids, Zen 4) — includes AMX fallback #[cfg(feature = "std")] { @@ -239,10 +194,7 @@ pub fn gemm_i8( /// `ndarray::hpc::quantized::BF16`). 
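Editor's note: a sketch of what the `BlasFloat` trait buys, namely one dtype-generic function body that resolves to the f32 or f64 backend implementation at compile time. It assumes the trait and `backend_gemv` signature shown in this diff; the import path is illustrative.

```rust
use crate::backend::BlasFloat;

/// y = A * x for a row-major m x n matrix, generic over f32/f64.
fn mat_vec<T: BlasFloat>(m: usize, n: usize, a: &[T], x: &[T]) -> Vec<T> {
    assert_eq!(a.len(), m * n);
    assert_eq!(x.len(), n);
    let mut y = vec![T::zero(); m];
    // y = 1 * A * x + 0 * y; leading dimension of row-major A is n.
    T::backend_gemv(m, n, T::one(), a, n, x, T::zero(), &mut y);
    y
}
```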
#[inline] #[allow(clippy::needless_return)] -pub fn gemm_bf16( - a: &[u16], b: &[u16], c: &mut [f32], - m: usize, n: usize, k: usize, -) { +pub fn gemm_bf16(a: &[u16], b: &[u16], c: &mut [f32], m: usize, n: usize, k: usize) { // Reinterpret u16 slices as BF16 slices (repr(transparent)) #[cfg(feature = "std")] { @@ -250,9 +202,8 @@ pub fn gemm_bf16( // SAFETY: BF16 is #[repr(transparent)] over u16 core::slice::from_raw_parts(a.as_ptr() as *const crate::hpc::quantized::BF16, a.len()) }; - let b_bf16: &[crate::hpc::quantized::BF16] = unsafe { - core::slice::from_raw_parts(b.as_ptr() as *const crate::hpc::quantized::BF16, b.len()) - }; + let b_bf16: &[crate::hpc::quantized::BF16] = + unsafe { core::slice::from_raw_parts(b.as_ptr() as *const crate::hpc::quantized::BF16, b.len()) }; crate::hpc::quantized::bf16_gemm_f32(a_bf16, b_bf16, c, m, n, k, 1.0, 0.0); return; } @@ -265,19 +216,13 @@ pub fn gemm_bf16( /// CBLAS-compat alias for INT8 GEMM. #[inline] -pub fn cblas_gemm_s8s8s32( - a: &[u8], b: &[i8], c: &mut [i32], - m: usize, n: usize, k: usize, -) { +pub fn cblas_gemm_s8s8s32(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) { gemm_i8(a, b, c, m, n, k) } /// CBLAS-compat alias for BF16 GEMM. #[inline] -pub fn cblas_gemm_bf16bf16f32( - a: &[u16], b: &[u16], c: &mut [f32], - m: usize, n: usize, k: usize, -) { +pub fn cblas_gemm_bf16bf16f32(a: &[u16], b: &[u16], c: &mut [f32], m: usize, n: usize, k: usize) { gemm_bf16(a, b, c, m, n, k) } @@ -293,9 +238,8 @@ pub fn cblas_gemm_bf16bf16f32( #[cfg(target_arch = "x86_64")] pub use kernels_avx512::{ - add_f32_vec, sub_f32_vec, mul_f32_vec, div_f32_vec, - add_f32_scalar, sub_f32_scalar, mul_f32_scalar, div_f32_scalar, - iamax_f32, iamax_f64, + add_f32_scalar, add_f32_vec, div_f32_scalar, div_f32_vec, iamax_f32, iamax_f64, mul_f32_scalar, mul_f32_vec, + sub_f32_scalar, sub_f32_vec, }; // ─── Slice-level ops by dtype (unified re-exports) ────────────── @@ -304,23 +248,15 @@ pub use kernels_avx512::{ // Integer: simd_int_ops. Half: simd_half. Float: kernels_avx512 + reductions. 
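Editor's note: the BF16 kernels above lean on the fact that bfloat16 is simply the upper 16 bits of an IEEE-754 f32, which is why a `#[repr(transparent)]` wrapper over `u16` can be reinterpreted freely. A small, self-contained illustration of that bit-level relationship (round-to-nearest-even on narrowing is omitted for brevity):

```rust
/// Widen bf16 bits to f32: shift into the high half of the f32 bit pattern.
fn bf16_bits_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// Narrow f32 to bf16 bits by truncation (real kernels usually round).
fn f32_to_bf16_bits_truncate(x: f32) -> u16 {
    (x.to_bits() >> 16) as u16
}

fn main() {
    let x = 1.5f32;
    let b = f32_to_bf16_bits_truncate(x);
    // 1.5 is exactly representable in bf16, so the round trip is lossless.
    assert_eq!(bf16_bits_to_f32(b), 1.5);
    println!("bf16 bits of 1.5: 0x{b:04x}");
}
```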
#[cfg(feature = "std")] -pub use crate::simd_int_ops::{ - add_i8, sub_i8, add_i16, - dot_i8, dot_i16, - min_i8, max_i8, -}; +pub use crate::simd_int_ops::{add_i16, add_i8, dot_i16, dot_i8, max_i8, min_i8, sub_i8}; #[cfg(feature = "std")] pub use crate::simd_half::{ - add_bf16_inplace, mul_bf16_inplace, - add_f16_inplace, mul_f16_inplace, - cast_bf16_to_f32_batch, cast_f16_to_f32_batch, - cast_f32_to_bf16_batch, cast_f32_to_f16_batch, + add_bf16_inplace, add_f16_inplace, cast_bf16_to_f32_batch, cast_f16_to_f32_batch, cast_f32_to_bf16_batch, + cast_f32_to_f16_batch, mul_bf16_inplace, mul_f16_inplace, }; #[cfg(feature = "std")] pub use crate::hpc::reductions::{ - sum_f32, sum_f64, mean_f32, mean_f64, - max_f32, min_f32, argmax_f32, argmin_f32, - nrm2_f32 as nrm2_f32_simd, + argmax_f32, argmin_f32, max_f32, mean_f32, mean_f64, min_f32, nrm2_f32 as nrm2_f32_simd, sum_f32, sum_f64, }; diff --git a/src/backend/native.rs b/src/backend/native.rs index 164a37ca..102180d1 100644 --- a/src/backend/native.rs +++ b/src/backend/native.rs @@ -13,12 +13,18 @@ use std::sync::LazyLock; // ─── Tier detection: happens ONCE, at first access ───────────────── #[derive(Clone, Copy, PartialEq)] -enum Tier { Avx512, Avx2, Scalar } +enum Tier { + Avx512, + Avx2, + Scalar, +} static TIER: LazyLock = LazyLock::new(|| { #[cfg(target_arch = "x86_64")] { - if is_x86_feature_detected!("avx512f") { return Tier::Avx512; } + if is_x86_feature_detected!("avx512f") { + return Tier::Avx512; + } if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") { return Tier::Avx2; } @@ -27,7 +33,9 @@ static TIER: LazyLock = LazyLock::new(|| { }); #[inline(always)] -fn tier() -> Tier { *TIER } +fn tier() -> Tier { + *TIER +} // ─── Runtime GEMM tile constants ─────────────────────────────────── @@ -36,7 +44,7 @@ fn tier() -> Tier { *TIER } pub fn sgemm_nr() -> usize { match tier() { Tier::Avx512 => 16, - Tier::Avx2 => 8, + Tier::Avx2 => 8, Tier::Scalar => 4, } } @@ -45,7 +53,7 @@ pub fn sgemm_nr() -> usize { pub fn sgemm_mr() -> usize { match tier() { Tier::Avx512 => 6, - Tier::Avx2 => 6, + Tier::Avx2 => 6, Tier::Scalar => 4, } } @@ -54,7 +62,7 @@ pub fn sgemm_mr() -> usize { pub fn dgemm_nr() -> usize { match tier() { Tier::Avx512 => 8, - Tier::Avx2 => 4, + Tier::Avx2 => 4, Tier::Scalar => 4, } } @@ -63,7 +71,7 @@ pub fn dgemm_nr() -> usize { pub fn dgemm_mr() -> usize { match tier() { Tier::Avx512 => 6, - Tier::Avx2 => 6, + Tier::Avx2 => 6, Tier::Scalar => 4, } } @@ -194,10 +202,8 @@ dispatch!( /// The custom AVX-512 kernels in `kernels_avx512` are retained for /// non-GEMM paths (Hamming, bitwise) where matrixmultiply has no equivalent. pub fn gemm_f32( - m: usize, n: usize, k: usize, - alpha: f32, a: &[f32], lda: usize, - b: &[f32], ldb: usize, - beta: f32, c: &mut [f32], ldc: usize, + m: usize, n: usize, k: usize, alpha: f32, a: &[f32], lda: usize, b: &[f32], ldb: usize, beta: f32, c: &mut [f32], + ldc: usize, ) { if m == 0 || n == 0 { return; @@ -206,22 +212,28 @@ pub fn gemm_f32( // Row-major: row stride = lda/ldb/ldc, col stride = 1. unsafe { matrixmultiply::sgemm( - m, k, n, + m, + k, + n, alpha, - a.as_ptr(), lda as isize, 1, - b.as_ptr(), ldb as isize, 1, + a.as_ptr(), + lda as isize, + 1, + b.as_ptr(), + ldb as isize, + 1, beta, - c.as_mut_ptr(), ldc as isize, 1, + c.as_mut_ptr(), + ldc as isize, + 1, ); } } /// GEMM: C = alpha * A * B + beta * C (f64, row-major). 
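Editor's note: the native backend above detects the CPU tier exactly once and caches it in a `LazyLock`, so the hot GEMM paths only pay a pointer dereference. A self-contained sketch of that pattern using only the standard library; the `Tier` name mirrors the diff, everything else is illustrative.

```rust
use std::sync::LazyLock;

#[derive(Clone, Copy, PartialEq, Debug)]
enum Tier {
    Avx512,
    Avx2,
    Scalar,
}

// The closure runs exactly once, on first access, from whichever thread gets there first.
static TIER: LazyLock<Tier> = LazyLock::new(|| {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f") {
            return Tier::Avx512;
        }
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            return Tier::Avx2;
        }
    }
    Tier::Scalar
});

fn main() {
    println!("detected tier: {:?}", *TIER);
}
```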
pub fn gemm_f64( - m: usize, n: usize, k: usize, - alpha: f64, a: &[f64], lda: usize, - b: &[f64], ldb: usize, - beta: f64, c: &mut [f64], ldc: usize, + m: usize, n: usize, k: usize, alpha: f64, a: &[f64], lda: usize, b: &[f64], ldb: usize, beta: f64, c: &mut [f64], + ldc: usize, ) { if m == 0 || n == 0 { return; @@ -229,12 +241,20 @@ pub fn gemm_f64( // SAFETY: same as sgemm — valid slices, row-major strides. unsafe { matrixmultiply::dgemm( - m, k, n, + m, + k, + n, alpha, - a.as_ptr(), lda as isize, 1, - b.as_ptr(), ldb as isize, 1, + a.as_ptr(), + lda as isize, + 1, + b.as_ptr(), + ldb as isize, + 1, beta, - c.as_mut_ptr(), ldc as isize, 1, + c.as_mut_ptr(), + ldc as isize, + 1, ); } } @@ -242,20 +262,12 @@ pub fn gemm_f64( // ─── GEMV dispatch ─────────────────────────────────────────────── /// GEMV: y = alpha * A * x + beta * y (f32) -pub fn gemv_f32( - m: usize, n: usize, - alpha: f32, a: &[f32], lda: usize, - x: &[f32], beta: f32, y: &mut [f32], -) { +pub fn gemv_f32(m: usize, n: usize, alpha: f32, a: &[f32], lda: usize, x: &[f32], beta: f32, y: &mut [f32]) { scalar::gemv_f32(m, n, alpha, a, lda, x, beta, y); } /// GEMV: y = alpha * A * x + beta * y (f64) -pub fn gemv_f64( - m: usize, n: usize, - alpha: f64, a: &[f64], lda: usize, - x: &[f64], beta: f64, y: &mut [f64], -) { +pub fn gemv_f64(m: usize, n: usize, alpha: f64, a: &[f64], lda: usize, x: &[f64], beta: f64, y: &mut [f64]) { scalar::gemv_f64(m, n, alpha, a, lda, x, beta, y); } @@ -271,8 +283,7 @@ mod scalar { let mut sum = 0.0f32; let mut i = 0; while i + 4 <= n { - sum += x[i] * y[i] + x[i + 1] * y[i + 1] - + x[i + 2] * y[i + 2] + x[i + 3] * y[i + 3]; + sum += x[i] * y[i] + x[i + 1] * y[i + 1] + x[i + 2] * y[i + 2] + x[i + 3] * y[i + 3]; i += 4; } while i < n { @@ -287,8 +298,7 @@ mod scalar { let mut sum = 0.0f64; let mut i = 0; while i + 4 <= n { - sum += x[i] * y[i] + x[i + 1] * y[i + 1] - + x[i + 2] * y[i + 2] + x[i + 3] * y[i + 3]; + sum += x[i] * y[i] + x[i + 1] * y[i + 1] + x[i + 2] * y[i + 2] + x[i + 3] * y[i + 3]; i += 4; } while i < n { @@ -359,10 +369,8 @@ mod scalar { /// Tiled GEMM: C = alpha * A * B + beta * C (scalar reference) #[allow(dead_code)] pub fn gemm_f32_tiled( - m: usize, n: usize, k: usize, - alpha: f32, a: &[f32], lda: usize, - b: &[f32], ldb: usize, - beta: f32, c: &mut [f32], ldc: usize, + m: usize, n: usize, k: usize, alpha: f32, a: &[f32], lda: usize, b: &[f32], ldb: usize, beta: f32, + c: &mut [f32], ldc: usize, ) { const TILE: usize = 64; @@ -408,10 +416,8 @@ mod scalar { /// Tiled GEMM (f64, scalar reference) #[allow(dead_code)] pub fn gemm_f64_tiled( - m: usize, n: usize, k: usize, - alpha: f64, a: &[f64], lda: usize, - b: &[f64], ldb: usize, - beta: f64, c: &mut [f64], ldc: usize, + m: usize, n: usize, k: usize, alpha: f64, a: &[f64], lda: usize, b: &[f64], ldb: usize, beta: f64, + c: &mut [f64], ldc: usize, ) { const TILE: usize = 64; @@ -454,11 +460,7 @@ mod scalar { } } - pub fn gemv_f32( - m: usize, n: usize, - alpha: f32, a: &[f32], lda: usize, - x: &[f32], beta: f32, y: &mut [f32], - ) { + pub fn gemv_f32(m: usize, n: usize, alpha: f32, a: &[f32], lda: usize, x: &[f32], beta: f32, y: &mut [f32]) { for i in 0..m { let mut sum = 0.0f32; for j in 0..n { @@ -468,11 +470,7 @@ mod scalar { } } - pub fn gemv_f64( - m: usize, n: usize, - alpha: f64, a: &[f64], lda: usize, - x: &[f64], beta: f64, y: &mut [f64], - ) { + pub fn gemv_f64(m: usize, n: usize, alpha: f64, a: &[f64], lda: usize, x: &[f64], beta: f64, y: &mut [f64]) { for i in 0..m { let mut sum = 0.0f64; for j in 0..n { @@ -537,12 
+535,24 @@ mod avx2 { } // No AVX2 specialization — fall through to scalar - pub fn scal_f32(alpha: f32, x: &mut [f32]) { super::scalar::scal_f32(alpha, x); } - pub fn scal_f64(alpha: f64, x: &mut [f64]) { super::scalar::scal_f64(alpha, x); } - pub fn nrm2_f32(x: &[f32]) -> f32 { super::scalar::nrm2_f32(x) } - pub fn nrm2_f64(x: &[f64]) -> f64 { super::scalar::nrm2_f64(x) } - pub fn asum_f32(x: &[f32]) -> f32 { super::scalar::asum_f32(x) } - pub fn asum_f64(x: &[f64]) -> f64 { super::scalar::asum_f64(x) } + pub fn scal_f32(alpha: f32, x: &mut [f32]) { + super::scalar::scal_f32(alpha, x); + } + pub fn scal_f64(alpha: f64, x: &mut [f64]) { + super::scalar::scal_f64(alpha, x); + } + pub fn nrm2_f32(x: &[f32]) -> f32 { + super::scalar::nrm2_f32(x) + } + pub fn nrm2_f64(x: &[f64]) -> f64 { + super::scalar::nrm2_f64(x) + } + pub fn asum_f32(x: &[f32]) -> f32 { + super::scalar::asum_f32(x) + } + pub fn asum_f64(x: &[f64]) -> f64 { + super::scalar::asum_f64(x) + } // ── AVX2 intrinsic implementations ───────────────────────────── diff --git a/src/backend/openblas.rs b/src/backend/openblas.rs index 81299efe..4885345e 100644 --- a/src/backend/openblas.rs +++ b/src/backend/openblas.rs @@ -23,32 +23,20 @@ extern "C" { fn cblas_sasum(n: c_int, x: *const c_float, incx: c_int) -> c_float; fn cblas_dasum(n: c_int, x: *const c_double, incx: c_int) -> c_double; fn cblas_sgemm( - layout: c_int, transa: c_int, transb: c_int, - m: c_int, n: c_int, k: c_int, - alpha: c_float, a: *const c_float, lda: c_int, - b: *const c_float, ldb: c_int, - beta: c_float, c: *mut c_float, ldc: c_int, + layout: c_int, transa: c_int, transb: c_int, m: c_int, n: c_int, k: c_int, alpha: c_float, a: *const c_float, + lda: c_int, b: *const c_float, ldb: c_int, beta: c_float, c: *mut c_float, ldc: c_int, ); fn cblas_dgemm( - layout: c_int, transa: c_int, transb: c_int, - m: c_int, n: c_int, k: c_int, - alpha: c_double, a: *const c_double, lda: c_int, - b: *const c_double, ldb: c_int, - beta: c_double, c: *mut c_double, ldc: c_int, + layout: c_int, transa: c_int, transb: c_int, m: c_int, n: c_int, k: c_int, alpha: c_double, a: *const c_double, + lda: c_int, b: *const c_double, ldb: c_int, beta: c_double, c: *mut c_double, ldc: c_int, ); fn cblas_sgemv( - layout: c_int, trans: c_int, - m: c_int, n: c_int, - alpha: c_float, a: *const c_float, lda: c_int, - x: *const c_float, incx: c_int, - beta: c_float, y: *mut c_float, incy: c_int, + layout: c_int, trans: c_int, m: c_int, n: c_int, alpha: c_float, a: *const c_float, lda: c_int, + x: *const c_float, incx: c_int, beta: c_float, y: *mut c_float, incy: c_int, ); fn cblas_dgemv( - layout: c_int, trans: c_int, - m: c_int, n: c_int, - alpha: c_double, a: *const c_double, lda: c_int, - x: *const c_double, incx: c_int, - beta: c_double, y: *mut c_double, incy: c_int, + layout: c_int, trans: c_int, m: c_int, n: c_int, alpha: c_double, a: *const c_double, lda: c_int, + x: *const c_double, incx: c_int, beta: c_double, y: *mut c_double, incy: c_int, ); } @@ -100,73 +88,103 @@ pub fn asum_f64(x: &[f64]) -> f64 { } pub fn gemm_f32( - m: usize, n: usize, k: usize, - alpha: f32, a: &[f32], lda: usize, - b: &[f32], ldb: usize, - beta: f32, c: &mut [f32], ldc: usize, + m: usize, n: usize, k: usize, alpha: f32, a: &[f32], lda: usize, b: &[f32], ldb: usize, beta: f32, c: &mut [f32], + ldc: usize, ) { // SAFETY: caller guarantees a is m×k (stride lda), b is k×n (stride ldb), // c is m×n (stride ldc), all row-major. 
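Editor's note: the scalar `gemm_f32_tiled` / `gemm_f64_tiled` reference kernels in `native.rs` above block the three loops so that a TILE x TILE working set of A, B, and C stays cache-resident. A self-contained sketch of that blocking idea; `TILE = 64` mirrors the constant in the diff, but the best value is hardware dependent and this sketch is not the crate's actual kernel.

```rust
const TILE: usize = 64;

/// C += A * B for row-major slices; C is assumed pre-zeroed (no alpha/beta).
fn gemm_tiled(m: usize, n: usize, k: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
    for i0 in (0..m).step_by(TILE) {
        for j0 in (0..n).step_by(TILE) {
            for k0 in (0..k).step_by(TILE) {
                // Inner loops touch only one block of each operand.
                for i in i0..(i0 + TILE).min(m) {
                    for kk in k0..(k0 + TILE).min(k) {
                        let aik = a[i * k + kk];
                        for j in j0..(j0 + TILE).min(n) {
                            c[i * n + j] += aik * b[kk * n + j];
                        }
                    }
                }
            }
        }
    }
}
```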
unsafe { cblas_sgemm( - CBLAS_ROW_MAJOR, CBLAS_NO_TRANS, CBLAS_NO_TRANS, - m as c_int, n as c_int, k as c_int, - alpha, a.as_ptr(), lda as c_int, - b.as_ptr(), ldb as c_int, - beta, c.as_mut_ptr(), ldc as c_int, + CBLAS_ROW_MAJOR, + CBLAS_NO_TRANS, + CBLAS_NO_TRANS, + m as c_int, + n as c_int, + k as c_int, + alpha, + a.as_ptr(), + lda as c_int, + b.as_ptr(), + ldb as c_int, + beta, + c.as_mut_ptr(), + ldc as c_int, ); } } pub fn gemm_f64( - m: usize, n: usize, k: usize, - alpha: f64, a: &[f64], lda: usize, - b: &[f64], ldb: usize, - beta: f64, c: &mut [f64], ldc: usize, + m: usize, n: usize, k: usize, alpha: f64, a: &[f64], lda: usize, b: &[f64], ldb: usize, beta: f64, c: &mut [f64], + ldc: usize, ) { unsafe { cblas_dgemm( - CBLAS_ROW_MAJOR, CBLAS_NO_TRANS, CBLAS_NO_TRANS, - m as c_int, n as c_int, k as c_int, - alpha, a.as_ptr(), lda as c_int, - b.as_ptr(), ldb as c_int, - beta, c.as_mut_ptr(), ldc as c_int, + CBLAS_ROW_MAJOR, + CBLAS_NO_TRANS, + CBLAS_NO_TRANS, + m as c_int, + n as c_int, + k as c_int, + alpha, + a.as_ptr(), + lda as c_int, + b.as_ptr(), + ldb as c_int, + beta, + c.as_mut_ptr(), + ldc as c_int, ); } } -pub fn gemv_f32( - m: usize, n: usize, - alpha: f32, a: &[f32], lda: usize, - x: &[f32], beta: f32, y: &mut [f32], -) { +pub fn gemv_f32(m: usize, n: usize, alpha: f32, a: &[f32], lda: usize, x: &[f32], beta: f32, y: &mut [f32]) { unsafe { cblas_sgemv( - CBLAS_ROW_MAJOR, CBLAS_NO_TRANS, - m as c_int, n as c_int, - alpha, a.as_ptr(), lda as c_int, - x.as_ptr(), 1, beta, y.as_mut_ptr(), 1, + CBLAS_ROW_MAJOR, + CBLAS_NO_TRANS, + m as c_int, + n as c_int, + alpha, + a.as_ptr(), + lda as c_int, + x.as_ptr(), + 1, + beta, + y.as_mut_ptr(), + 1, ); } } -pub fn gemv_f64( - m: usize, n: usize, - alpha: f64, a: &[f64], lda: usize, - x: &[f64], beta: f64, y: &mut [f64], -) { +pub fn gemv_f64(m: usize, n: usize, alpha: f64, a: &[f64], lda: usize, x: &[f64], beta: f64, y: &mut [f64]) { unsafe { cblas_dgemv( - CBLAS_ROW_MAJOR, CBLAS_NO_TRANS, - m as c_int, n as c_int, - alpha, a.as_ptr(), lda as c_int, - x.as_ptr(), 1, beta, y.as_mut_ptr(), 1, + CBLAS_ROW_MAJOR, + CBLAS_NO_TRANS, + m as c_int, + n as c_int, + alpha, + a.as_ptr(), + lda as c_int, + x.as_ptr(), + 1, + beta, + y.as_mut_ptr(), + 1, ); } } // Tile size constants (not meaningful for FFI, but needed for API compat) -pub const fn sgemm_nr() -> usize { 16 } -pub const fn sgemm_mr() -> usize { 6 } -pub const fn dgemm_nr() -> usize { 8 } -pub const fn dgemm_mr() -> usize { 6 } +pub const fn sgemm_nr() -> usize { + 16 +} +pub const fn sgemm_mr() -> usize { + 6 +} +pub const fn dgemm_nr() -> usize { + 8 +} +pub const fn dgemm_mr() -> usize { + 6 +} diff --git a/src/data_repr.rs b/src/data_repr.rs index 4041c192..be599be4 100644 --- a/src/data_repr.rs +++ b/src/data_repr.rs @@ -21,17 +21,14 @@ use rawpointer::PointerExt; // transmutable A -> B. 
#[derive(Debug)] #[repr(C)] -pub struct OwnedRepr -{ +pub struct OwnedRepr { ptr: NonNull, len: usize, capacity: usize, } -impl OwnedRepr -{ - pub(crate) fn from(v: Vec) -> Self - { +impl OwnedRepr { + pub(crate) fn from(v: Vec) -> Self { let mut v = ManuallyDrop::new(v); let len = v.len(); let capacity = v.capacity(); @@ -39,34 +36,28 @@ impl OwnedRepr Self { ptr, len, capacity } } - pub(crate) fn into_vec(self) -> Vec - { + pub(crate) fn into_vec(self) -> Vec { ManuallyDrop::new(self).take_as_vec() } - pub(crate) fn as_slice(&self) -> &[A] - { + pub(crate) fn as_slice(&self) -> &[A] { unsafe { slice::from_raw_parts(self.ptr.as_ptr(), self.len) } } - pub(crate) fn len(&self) -> usize - { + pub(crate) fn len(&self) -> usize { self.len } - pub(crate) fn as_ptr(&self) -> *const A - { + pub(crate) fn as_ptr(&self) -> *const A { self.ptr.as_ptr() } - pub(crate) fn as_nonnull_mut(&mut self) -> NonNull - { + pub(crate) fn as_nonnull_mut(&mut self) -> NonNull { self.ptr } /// Return end pointer - pub(crate) fn as_end_nonnull(&self) -> NonNull - { + pub(crate) fn as_end_nonnull(&self) -> NonNull { unsafe { self.ptr.add(self.len) } } @@ -76,8 +67,7 @@ impl OwnedRepr /// /// Note that existing pointers into the data are invalidated #[must_use = "must use new pointer to update existing pointers"] - pub(crate) fn reserve(&mut self, additional: usize) -> NonNull - { + pub(crate) fn reserve(&mut self, additional: usize) -> NonNull { self.modify_as_vec(|mut v| { v.reserve(additional); v @@ -90,15 +80,13 @@ impl OwnedRepr /// ## Safety /// /// The first `new_len` elements of the data should be valid. - pub(crate) unsafe fn set_len(&mut self, new_len: usize) - { + pub(crate) unsafe fn set_len(&mut self, new_len: usize) { debug_assert!(new_len <= self.capacity); self.len = new_len; } /// Return the length (number of elements in total) - pub(crate) fn release_all_elements(&mut self) -> usize - { + pub(crate) fn release_all_elements(&mut self) -> usize { let ret = self.len; self.len = 0; ret @@ -110,8 +98,7 @@ impl OwnedRepr /// /// Caller must ensure the two types have the same representation. /// **Panics** if sizes don't match (which is not a sufficient check). - pub(crate) unsafe fn data_subst(self) -> OwnedRepr - { + pub(crate) unsafe fn data_subst(self) -> OwnedRepr { // necessary but not sufficient check assert_eq!(mem::size_of::(), mem::size_of::()); let self_ = ManuallyDrop::new(self); @@ -122,14 +109,12 @@ impl OwnedRepr } } - fn modify_as_vec(&mut self, f: impl FnOnce(Vec) -> Vec) - { + fn modify_as_vec(&mut self, f: impl FnOnce(Vec) -> Vec) { let v = self.take_as_vec(); *self = Self::from(f(v)); } - fn take_as_vec(&mut self) -> Vec - { + fn take_as_vec(&mut self) -> Vec { let capacity = self.capacity; let len = self.len; self.len = 0; @@ -139,15 +124,14 @@ impl OwnedRepr } impl Clone for OwnedRepr -where A: Clone +where + A: Clone, { - fn clone(&self) -> Self - { + fn clone(&self) -> Self { Self::from(self.as_slice().to_owned()) } - fn clone_from(&mut self, other: &Self) - { + fn clone_from(&mut self, other: &Self) { let mut v = self.take_as_vec(); let other = other.as_slice(); @@ -161,10 +145,8 @@ where A: Clone } } -impl Drop for OwnedRepr -{ - fn drop(&mut self) - { +impl Drop for OwnedRepr { + fn drop(&mut self) { if self.capacity > 0 { // correct because: If the elements don't need dropping, an // empty Vec is ok. Only the Vec's allocation needs dropping. 
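Editor's note: `OwnedRepr` above is built on one ownership trick: move the `Vec` into `ManuallyDrop`, record `(ptr, len, capacity)`, and rebuild the `Vec` later so the allocation is freed exactly once. A stripped-down sketch of that pattern; unlike the real type it omits the `Drop` impl, so dropping it without `into_vec` would leak.

```rust
use std::mem::ManuallyDrop;
use std::ptr::NonNull;

struct RawVec<T> {
    ptr: NonNull<T>,
    len: usize,
    capacity: usize,
}

impl<T> RawVec<T> {
    fn from_vec(v: Vec<T>) -> Self {
        let mut v = ManuallyDrop::new(v); // suppress Vec's destructor
        let ptr = NonNull::new(v.as_mut_ptr()).expect("Vec pointers are never null");
        Self { ptr, len: v.len(), capacity: v.capacity() }
    }

    fn into_vec(self) -> Vec<T> {
        // SAFETY: (ptr, len, capacity) came from a Vec we own and never freed.
        unsafe { Vec::from_raw_parts(self.ptr.as_ptr(), self.len, self.capacity) }
    }
}

fn main() {
    let raw = RawVec::from_vec(vec![1, 2, 3]);
    assert_eq!(raw.into_vec(), vec![1, 2, 3]);
}
```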
diff --git a/src/data_traits.rs b/src/data_traits.rs index a0b33ea1..3ae0f409 100644 --- a/src/data_traits.rs +++ b/src/data_traits.rs @@ -35,8 +35,7 @@ use crate::{ArcArray, Array, ArrayBase, ArrayRef, CowRepr, Dimension, OwnedArcRe /// Traits in Rust can serve many different roles. This trait is public because /// it is used as a bound on public methods. #[allow(clippy::missing_safety_doc)] // not implementable downstream -pub unsafe trait RawData: Sized -{ +pub unsafe trait RawData: Sized { /// The array element type. type Elem; @@ -52,8 +51,7 @@ pub unsafe trait RawData: Sized /// /// ***Internal trait, see `RawData`.*** #[allow(clippy::missing_safety_doc)] // not implementable downstream -pub unsafe trait RawDataMut: RawData -{ +pub unsafe trait RawDataMut: RawData { /// If possible, ensures that the array has unique access to its data. /// /// The implementer must ensure that if the input is contiguous, then the @@ -81,15 +79,13 @@ pub unsafe trait RawDataMut: RawData /// /// ***Internal trait, see `RawData`.*** #[allow(clippy::missing_safety_doc)] // not implementable downstream -pub unsafe trait RawDataClone: RawData -{ +pub unsafe trait RawDataClone: RawData { #[doc(hidden)] /// Unsafe because, `ptr` must point inside the current storage. unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull); #[doc(hidden)] - unsafe fn clone_from_with_ptr(&mut self, other: &Self, ptr: NonNull) -> NonNull - { + unsafe fn clone_from_with_ptr(&mut self, other: &Self, ptr: NonNull) -> NonNull { let (data, ptr) = other.clone_with_ptr(ptr); *self = data; ptr @@ -102,8 +98,7 @@ pub unsafe trait RawDataClone: RawData /// /// ***Internal trait, see `RawData`.*** #[allow(clippy::missing_safety_doc)] // not implementable downstream -pub unsafe trait Data: RawData -{ +pub unsafe trait Data: RawData { /// Converts the array to a uniquely owned array, cloning elements if necessary. #[doc(hidden)] #[allow(clippy::wrong_self_convention)] @@ -116,7 +111,8 @@ pub unsafe trait Data: RawData /// cloning the array elements. Otherwise, returns `self_` unchanged. #[doc(hidden)] fn try_into_owned_nocopy(self_: ArrayBase) -> Result, ArrayBase> - where D: Dimension; + where + D: Dimension; /// Return a shared ownership (copy on write) array based on the existing one, /// cloning elements if necessary. @@ -145,8 +141,7 @@ pub unsafe trait Data: RawData // the data is unique. You are also guaranteeing that `try_is_unique` always // returns `Some(_)`. #[allow(clippy::missing_safety_doc)] // not implementable downstream -pub unsafe trait DataMut: Data + RawDataMut -{ +pub unsafe trait DataMut: Data + RawDataMut { /// Ensures that the array has unique access to its data. #[doc(hidden)] #[inline] @@ -162,48 +157,40 @@ pub unsafe trait DataMut: Data + RawDataMut #[doc(hidden)] #[inline] #[allow(clippy::wrong_self_convention)] // mut needed for Arc types - fn is_unique(&mut self) -> bool - { + fn is_unique(&mut self) -> bool { self.try_is_unique().unwrap() } } -unsafe impl RawData for RawViewRepr<*const A> -{ +unsafe impl RawData for RawViewRepr<*const A> { type Elem = A; #[inline(always)] - fn _is_pointer_inbounds(&self, _ptr: *const Self::Elem) -> bool - { + fn _is_pointer_inbounds(&self, _ptr: *const Self::Elem) -> bool { true } private_impl! 
{} } -unsafe impl RawDataClone for RawViewRepr<*const A> -{ - unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) - { +unsafe impl RawDataClone for RawViewRepr<*const A> { + unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) { (*self, ptr) } } -unsafe impl RawData for RawViewRepr<*mut A> -{ +unsafe impl RawData for RawViewRepr<*mut A> { type Elem = A; #[inline(always)] - fn _is_pointer_inbounds(&self, _ptr: *const Self::Elem) -> bool - { + fn _is_pointer_inbounds(&self, _ptr: *const Self::Elem) -> bool { true } private_impl! {} } -unsafe impl RawDataMut for RawViewRepr<*mut A> -{ +unsafe impl RawDataMut for RawViewRepr<*mut A> { #[inline] fn try_ensure_unique(_: &mut ArrayBase) where @@ -213,26 +200,21 @@ unsafe impl RawDataMut for RawViewRepr<*mut A> } #[inline] - fn try_is_unique(&mut self) -> Option - { + fn try_is_unique(&mut self) -> Option { None } } -unsafe impl RawDataClone for RawViewRepr<*mut A> -{ - unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) - { +unsafe impl RawDataClone for RawViewRepr<*mut A> { + unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) { (*self, ptr) } } -unsafe impl RawData for OwnedArcRepr -{ +unsafe impl RawData for OwnedArcRepr { type Elem = A; - fn _is_pointer_inbounds(&self, self_ptr: *const Self::Elem) -> bool - { + fn _is_pointer_inbounds(&self, self_ptr: *const Self::Elem) -> bool { self.0._is_pointer_inbounds(self_ptr) } @@ -241,7 +223,8 @@ unsafe impl RawData for OwnedArcRepr // NOTE: Copy on write unsafe impl RawDataMut for OwnedArcRepr -where A: Clone +where + A: Clone, { fn try_ensure_unique(self_: &mut ArrayBase) where @@ -270,14 +253,12 @@ where A: Clone } } - fn try_is_unique(&mut self) -> Option - { + fn try_is_unique(&mut self) -> Option { Some(Arc::get_mut(&mut self.0).is_some()) } } -unsafe impl Data for OwnedArcRepr -{ +unsafe impl Data for OwnedArcRepr { fn into_owned(mut self_: ArrayBase) -> Array where A: Clone, @@ -292,7 +273,8 @@ unsafe impl Data for OwnedArcRepr } fn try_into_owned_nocopy(self_: ArrayBase) -> Result, ArrayBase> - where D: Dimension + where + D: Dimension, { match Arc::try_unwrap(self_.data.0) { Ok(owned_data) => unsafe { @@ -322,21 +304,17 @@ unsafe impl Data for OwnedArcRepr unsafe impl DataMut for OwnedArcRepr where A: Clone {} -unsafe impl RawDataClone for OwnedArcRepr -{ - unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) - { +unsafe impl RawDataClone for OwnedArcRepr { + unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) { // pointer is preserved (self.clone(), ptr) } } -unsafe impl RawData for OwnedRepr -{ +unsafe impl RawData for OwnedRepr { type Elem = A; - fn _is_pointer_inbounds(&self, self_ptr: *const Self::Elem) -> bool - { + fn _is_pointer_inbounds(&self, self_ptr: *const Self::Elem) -> bool { let slc = self.as_slice(); let ptr = slc.as_ptr() as *mut A; let end = unsafe { ptr.add(slc.len()) }; @@ -346,8 +324,7 @@ unsafe impl RawData for OwnedRepr private_impl! 
{} } -unsafe impl RawDataMut for OwnedRepr -{ +unsafe impl RawDataMut for OwnedRepr { #[inline] fn try_ensure_unique(_: &mut ArrayBase) where @@ -357,14 +334,12 @@ unsafe impl RawDataMut for OwnedRepr } #[inline] - fn try_is_unique(&mut self) -> Option - { + fn try_is_unique(&mut self) -> Option { Some(true) } } -unsafe impl Data for OwnedRepr -{ +unsafe impl Data for OwnedRepr { #[inline] fn into_owned(self_: ArrayBase) -> Array where @@ -376,7 +351,8 @@ unsafe impl Data for OwnedRepr #[inline] fn try_into_owned_nocopy(self_: ArrayBase) -> Result, ArrayBase> - where D: Dimension + where + D: Dimension, { Ok(self_) } @@ -385,10 +361,10 @@ unsafe impl Data for OwnedRepr unsafe impl DataMut for OwnedRepr {} unsafe impl RawDataClone for OwnedRepr -where A: Clone +where + A: Clone, { - unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) - { + unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) { let mut u = self.clone(); let mut new_ptr = u.as_nonnull_mut(); if size_of::() != 0 { @@ -398,8 +374,7 @@ where A: Clone (u, new_ptr) } - unsafe fn clone_from_with_ptr(&mut self, other: &Self, ptr: NonNull) -> NonNull - { + unsafe fn clone_from_with_ptr(&mut self, other: &Self, ptr: NonNull) -> NonNull { let our_off = if size_of::() != 0 { (ptr.as_ptr() as isize - other.as_ptr() as isize) / mem::size_of::() as isize } else { @@ -410,21 +385,18 @@ where A: Clone } } -unsafe impl RawData for ViewRepr<&A> -{ +unsafe impl RawData for ViewRepr<&A> { type Elem = A; #[inline(always)] - fn _is_pointer_inbounds(&self, _ptr: *const Self::Elem) -> bool - { + fn _is_pointer_inbounds(&self, _ptr: *const Self::Elem) -> bool { true } private_impl! {} } -unsafe impl Data for ViewRepr<&A> -{ +unsafe impl Data for ViewRepr<&A> { fn into_owned(self_: ArrayBase) -> Array where Self::Elem: Clone, @@ -434,35 +406,31 @@ unsafe impl Data for ViewRepr<&A> } fn try_into_owned_nocopy(self_: ArrayBase) -> Result, ArrayBase> - where D: Dimension + where + D: Dimension, { Err(self_) } } -unsafe impl RawDataClone for ViewRepr<&A> -{ - unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) - { +unsafe impl RawDataClone for ViewRepr<&A> { + unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) { (*self, ptr) } } -unsafe impl RawData for ViewRepr<&mut A> -{ +unsafe impl RawData for ViewRepr<&mut A> { type Elem = A; #[inline(always)] - fn _is_pointer_inbounds(&self, _ptr: *const Self::Elem) -> bool - { + fn _is_pointer_inbounds(&self, _ptr: *const Self::Elem) -> bool { true } private_impl! {} } -unsafe impl RawDataMut for ViewRepr<&mut A> -{ +unsafe impl RawDataMut for ViewRepr<&mut A> { #[inline] fn try_ensure_unique(_: &mut ArrayBase) where @@ -472,14 +440,12 @@ unsafe impl RawDataMut for ViewRepr<&mut A> } #[inline] - fn try_is_unique(&mut self) -> Option - { + fn try_is_unique(&mut self) -> Option { Some(true) } } -unsafe impl Data for ViewRepr<&mut A> -{ +unsafe impl Data for ViewRepr<&mut A> { fn into_owned(self_: ArrayBase) -> Array where Self::Elem: Clone, @@ -489,7 +455,8 @@ unsafe impl Data for ViewRepr<&mut A> } fn try_into_owned_nocopy(self_: ArrayBase) -> Result, ArrayBase> - where D: Dimension + where + D: Dimension, { Err(self_) } @@ -510,8 +477,7 @@ unsafe impl DataMut for ViewRepr<&mut A> {} // unsharing storage before mutating it. The initially allocated storage must be mutable so // that it can be mutated directly - through .raw_view_mut_unchecked() - for initialization. 
#[allow(clippy::missing_safety_doc)] // not implementable downstream -pub unsafe trait DataOwned: Data -{ +pub unsafe trait DataOwned: Data { /// Corresponding owned data with MaybeUninit elements type MaybeUninit: DataOwned> + RawDataSubst; #[doc(hidden)] @@ -538,12 +504,10 @@ pub unsafe trait DataShared: Clone + Data + RawDataClone {} unsafe impl DataShared for OwnedArcRepr {} unsafe impl DataShared for ViewRepr<&A> {} -unsafe impl DataOwned for OwnedRepr -{ +unsafe impl DataOwned for OwnedRepr { type MaybeUninit = OwnedRepr>; - fn new(elements: Vec) -> Self - { + fn new(elements: Vec) -> Self { OwnedRepr::from(elements) } @@ -556,12 +520,10 @@ unsafe impl DataOwned for OwnedRepr } } -unsafe impl DataOwned for OwnedArcRepr -{ +unsafe impl DataOwned for OwnedArcRepr { type MaybeUninit = OwnedArcRepr>; - fn new(elements: Vec) -> Self - { + fn new(elements: Vec) -> Self { OwnedArcRepr(Arc::new(OwnedRepr::from(elements))) } @@ -574,13 +536,11 @@ unsafe impl DataOwned for OwnedArcRepr } } -unsafe impl RawData for CowRepr<'_, A> -{ +unsafe impl RawData for CowRepr<'_, A> { type Elem = A; #[inline] - fn _is_pointer_inbounds(&self, ptr: *const Self::Elem) -> bool - { + fn _is_pointer_inbounds(&self, ptr: *const Self::Elem) -> bool { match self { CowRepr::View(view) => view._is_pointer_inbounds(ptr), CowRepr::Owned(data) => data._is_pointer_inbounds(ptr), @@ -591,7 +551,8 @@ unsafe impl RawData for CowRepr<'_, A> } unsafe impl RawDataMut for CowRepr<'_, A> -where A: Clone +where + A: Clone, { #[inline] fn try_ensure_unique(array: &mut ArrayBase) @@ -612,17 +573,16 @@ where A: Clone } #[inline] - fn try_is_unique(&mut self) -> Option - { + fn try_is_unique(&mut self) -> Option { Some(self.is_owned()) } } unsafe impl RawDataClone for CowRepr<'_, A> -where A: Clone +where + A: Clone, { - unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) - { + unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) { match self { CowRepr::View(view) => { let (new_view, ptr) = view.clone_with_ptr(ptr); @@ -635,8 +595,7 @@ where A: Clone } } - unsafe fn clone_from_with_ptr(&mut self, other: &Self, ptr: NonNull) -> NonNull - { + unsafe fn clone_from_with_ptr(&mut self, other: &Self, ptr: NonNull) -> NonNull { match (&mut *self, other) { (CowRepr::View(self_), CowRepr::View(other)) => self_.clone_from_with_ptr(other, ptr), (CowRepr::Owned(self_), CowRepr::Owned(other)) => self_.clone_from_with_ptr(other, ptr), @@ -654,8 +613,7 @@ where A: Clone } } -unsafe impl<'a, A> Data for CowRepr<'a, A> -{ +unsafe impl<'a, A> Data for CowRepr<'a, A> { #[inline] fn into_owned(self_: ArrayBase, D>) -> Array where @@ -672,7 +630,8 @@ unsafe impl<'a, A> Data for CowRepr<'a, A> } fn try_into_owned_nocopy(self_: ArrayBase) -> Result, ArrayBase> - where D: Dimension + where + D: Dimension, { match self_.data { CowRepr::View(_) => Err(self_), @@ -687,12 +646,10 @@ unsafe impl<'a, A> Data for CowRepr<'a, A> unsafe impl DataMut for CowRepr<'_, A> where A: Clone {} -unsafe impl<'a, A> DataOwned for CowRepr<'a, A> -{ +unsafe impl<'a, A> DataOwned for CowRepr<'a, A> { type MaybeUninit = CowRepr<'a, MaybeUninit>; - fn new(elements: Vec) -> Self - { + fn new(elements: Vec) -> Self { CowRepr::Owned(OwnedRepr::new(elements)) } @@ -711,8 +668,7 @@ unsafe impl<'a, A> DataOwned for CowRepr<'a, A> /// keeping the same kind of storage. /// /// For example, `RawDataSubst` can map the type `OwnedRepr` to `OwnedRepr`. 
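Editor's note: at the API surface, the `Data`/`DataMut`/`RawDataClone` machinery above is what makes `ArcArray` copy-on-write: clones share the buffer until a mutation calls `try_ensure_unique`. A short sketch of that behavior, assuming the fork keeps upstream ndarray's `ArcArray2` and `into_shared` API.

```rust
use ndarray::{array, ArcArray2};

fn main() {
    let a: ArcArray2<f64> = array![[1.0, 2.0], [3.0, 4.0]].into_shared();
    let mut b = a.clone(); // cheap: bumps the Arc, no element copy

    b[[0, 0]] = 9.0; // mutable access unshares the storage (deep copy) first

    assert_eq!(a[[0, 0]], 1.0); // original buffer untouched
    assert_eq!(b[[0, 0]], 9.0);
}
```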
-pub trait RawDataSubst: RawData -{ +pub trait RawDataSubst: RawData { /// The resulting array storage of the same kind but substituted element type type Output: RawData; @@ -725,72 +681,58 @@ pub trait RawDataSubst: RawData unsafe fn data_subst(self) -> Self::Output; } -impl RawDataSubst for OwnedRepr -{ +impl RawDataSubst for OwnedRepr { type Output = OwnedRepr; - unsafe fn data_subst(self) -> Self::Output - { + unsafe fn data_subst(self) -> Self::Output { self.data_subst() } } -impl RawDataSubst for OwnedArcRepr -{ +impl RawDataSubst for OwnedArcRepr { type Output = OwnedArcRepr; - unsafe fn data_subst(self) -> Self::Output - { + unsafe fn data_subst(self) -> Self::Output { OwnedArcRepr(Arc::from_raw(Arc::into_raw(self.0) as *const OwnedRepr)) } } -impl RawDataSubst for RawViewRepr<*const A> -{ +impl RawDataSubst for RawViewRepr<*const A> { type Output = RawViewRepr<*const B>; - unsafe fn data_subst(self) -> Self::Output - { + unsafe fn data_subst(self) -> Self::Output { RawViewRepr::new() } } -impl RawDataSubst for RawViewRepr<*mut A> -{ +impl RawDataSubst for RawViewRepr<*mut A> { type Output = RawViewRepr<*mut B>; - unsafe fn data_subst(self) -> Self::Output - { + unsafe fn data_subst(self) -> Self::Output { RawViewRepr::new() } } -impl<'a, A: 'a, B: 'a> RawDataSubst for ViewRepr<&'a A> -{ +impl<'a, A: 'a, B: 'a> RawDataSubst for ViewRepr<&'a A> { type Output = ViewRepr<&'a B>; - unsafe fn data_subst(self) -> Self::Output - { + unsafe fn data_subst(self) -> Self::Output { ViewRepr::new() } } -impl<'a, A: 'a, B: 'a> RawDataSubst for ViewRepr<&'a mut A> -{ +impl<'a, A: 'a, B: 'a> RawDataSubst for ViewRepr<&'a mut A> { type Output = ViewRepr<&'a mut B>; - unsafe fn data_subst(self) -> Self::Output - { + unsafe fn data_subst(self) -> Self::Output { ViewRepr::new() } } -impl<'a, A: 'a, B: 'a> RawDataSubst for CowRepr<'a, A> -{ +impl<'a, A: 'a, B: 'a> RawDataSubst for CowRepr<'a, A> { type Output = CowRepr<'a, B>; - unsafe fn data_subst(self) -> Self::Output - { + unsafe fn data_subst(self) -> Self::Output { match self { CowRepr::View(view) => CowRepr::View(view.data_subst()), CowRepr::Owned(owned) => CowRepr::Owned(owned.data_subst()), diff --git a/src/dimension/axes.rs b/src/dimension/axes.rs index c7aaff14..629199d6 100644 --- a/src/dimension/axes.rs +++ b/src/dimension/axes.rs @@ -2,7 +2,8 @@ use crate::{Axis, Dimension, Ixs}; /// Create a new Axes iterator pub(crate) fn axes_of<'a, D>(d: &'a D, strides: &'a D) -> Axes<'a, D> -where D: Dimension +where + D: Dimension, { Axes { dim: d, @@ -37,8 +38,7 @@ where D: Dimension /// assert_eq!(largest_axis.len, 5); /// ``` #[derive(Debug)] -pub struct Axes<'a, D> -{ +pub struct Axes<'a, D> { dim: &'a D, strides: &'a D, start: usize, @@ -47,8 +47,7 @@ pub struct Axes<'a, D> /// Description of the axis, its length and its stride. #[derive(Debug)] -pub struct AxisDescription -{ +pub struct AxisDescription { /// Axis identifier (index) pub axis: Axis, /// Length in count of elements of the current axis @@ -61,13 +60,13 @@ copy_and_clone!(AxisDescription); copy_and_clone!(['a, D] Axes<'a, D>); impl Iterator for Axes<'_, D> -where D: Dimension +where + D: Dimension, { /// Description of the axis, its length and its stride. 
type Item = AxisDescription; - fn next(&mut self) -> Option - { + fn next(&mut self) -> Option { if self.start < self.end { let i = self.start.post_inc(); Some(AxisDescription { @@ -81,7 +80,8 @@ where D: Dimension } fn fold(self, init: B, f: F) -> B - where F: FnMut(B, AxisDescription) -> B + where + F: FnMut(B, AxisDescription) -> B, { (self.start..self.end) .map(move |i| AxisDescription { @@ -92,18 +92,17 @@ where D: Dimension .fold(init, f) } - fn size_hint(&self) -> (usize, Option) - { + fn size_hint(&self) -> (usize, Option) { let len = self.end - self.start; (len, Some(len)) } } impl DoubleEndedIterator for Axes<'_, D> -where D: Dimension +where + D: Dimension, { - fn next_back(&mut self) -> Option - { + fn next_back(&mut self) -> Option { if self.start < self.end { let i = self.end.pre_dec(); Some(AxisDescription { @@ -117,24 +116,20 @@ where D: Dimension } } -trait IncOps: Copy -{ +trait IncOps: Copy { fn post_inc(&mut self) -> Self; fn pre_dec(&mut self) -> Self; } -impl IncOps for usize -{ +impl IncOps for usize { #[inline(always)] - fn post_inc(&mut self) -> Self - { + fn post_inc(&mut self) -> Self { let x = *self; *self += 1; x } #[inline(always)] - fn pre_dec(&mut self) -> Self - { + fn pre_dec(&mut self) -> Self { *self -= 1; *self } diff --git a/src/dimension/axis.rs b/src/dimension/axis.rs index 8c896f6b..611c62b3 100644 --- a/src/dimension/axis.rs +++ b/src/dimension/axis.rs @@ -26,12 +26,10 @@ #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct Axis(pub usize); -impl Axis -{ +impl Axis { /// Return the index of the axis. #[inline(always)] - pub fn index(self) -> usize - { + pub fn index(self) -> usize { self.0 } } diff --git a/src/dimension/broadcast.rs b/src/dimension/broadcast.rs index 8fb445cb..a6dd4328 100644 --- a/src/dimension/broadcast.rs +++ b/src/dimension/broadcast.rs @@ -41,8 +41,7 @@ where /// This trait is what determines that typing. /// /// For example, `Ix1: DimMax`, but not vice-versa. -pub trait DimMax -{ +pub trait DimMax { /// The resulting dimension type after broadcasting. type Output: Dimension; } @@ -50,8 +49,7 @@ pub trait DimMax /// Dimensions of the same type remain unchanged when co_broadcast. /// So you can directly use `D` as the resulting type. /// (Instead of `>::BroadcastOutput`) -impl DimMax for D -{ +impl DimMax for D { type Output = D; } @@ -98,14 +96,12 @@ impl_broadcast_distinct_fixed!(Ix6, IxDyn); #[cfg(test)] #[cfg(feature = "std")] -mod tests -{ +mod tests { use super::co_broadcast; use crate::{Dim, DimMax, Dimension, ErrorKind, Ix0, IxDynImpl, ShapeError}; #[test] - fn test_broadcast_shape() - { + fn test_broadcast_shape() { fn test_co(d1: &D1, d2: &D2, r: Result<>::Output, ShapeError>) where D1: Dimension + DimMax, diff --git a/src/dimension/conversion.rs b/src/dimension/conversion.rs index cee8a2eb..e241d259 100644 --- a/src/dimension/conversion.rs +++ b/src/dimension/conversion.rs @@ -40,8 +40,7 @@ macro_rules! index_item { } /// Argument conversion a dimension. -pub trait IntoDimension -{ +pub trait IntoDimension { /// The concrete type of the resultant dimension. 
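Editor's note: `DimMax` above encodes, at the type level, which dimensionality results when two shapes are co-broadcast. A small usage sketch of the behavior it enables, assuming the fork keeps upstream ndarray's broadcasting arithmetic on views.

```rust
use ndarray::{array, Array2};

fn main() {
    let a: Array2<f64> = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; // shape (2, 3)
    let row = array![10.0, 20.0, 30.0];                            // shape (3,)

    // The 1-D operand is broadcast across axis 0; the output takes the larger dim.
    let sum = &a + &row;

    assert_eq!(sum, array![[11.0, 22.0, 33.0], [14.0, 25.0, 36.0]]);
}
```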
type Dim: Dimension; @@ -49,49 +48,42 @@ pub trait IntoDimension fn into_dimension(self) -> Self::Dim; } -impl IntoDimension for Ix -{ +impl IntoDimension for Ix { type Dim = Ix1; #[inline(always)] - fn into_dimension(self) -> Ix1 - { + fn into_dimension(self) -> Ix1 { Ix1(self) } } impl IntoDimension for D -where D: Dimension +where + D: Dimension, { type Dim = D; #[inline(always)] - fn into_dimension(self) -> Self - { + fn into_dimension(self) -> Self { self } } -impl IntoDimension for IxDynImpl -{ +impl IntoDimension for IxDynImpl { type Dim = IxDyn; #[inline(always)] - fn into_dimension(self) -> Self::Dim - { + fn into_dimension(self) -> Self::Dim { Dim::new(self) } } -impl IntoDimension for Vec -{ +impl IntoDimension for Vec { type Dim = IxDyn; #[inline(always)] - fn into_dimension(self) -> Self::Dim - { + fn into_dimension(self) -> Self::Dim { Dim::new(IxDynImpl::from(self)) } } -pub trait Convert -{ +pub trait Convert { type To; fn convert(self) -> Self::To; } diff --git a/src/dimension/dim.rs b/src/dimension/dim.rs index 96e433bb..411430e4 100644 --- a/src/dimension/dim.rs +++ b/src/dimension/dim.rs @@ -35,26 +35,21 @@ use std::fmt; /// assert_eq!(array.raw_dim(), Dim([3, 2])); /// ``` #[derive(Copy, Clone, PartialEq, Eq, Hash, Default)] -pub struct Dim -{ +pub struct Dim { index: I, } -impl Dim -{ +impl Dim { /// Private constructor and accessors for Dim - pub(crate) const fn new(index: I) -> Dim - { + pub(crate) const fn new(index: I) -> Dim { Dim { index } } #[inline(always)] - pub(crate) fn ix(&self) -> &I - { + pub(crate) fn ix(&self) -> &I { &self.index } #[inline(always)] - pub(crate) fn ixm(&mut self) -> &mut I - { + pub(crate) fn ixm(&mut self) -> &mut I { &mut self.index } } @@ -62,25 +57,26 @@ impl Dim /// Create a new dimension value. #[allow(non_snake_case)] pub fn Dim(index: T) -> T::Dim -where T: IntoDimension +where + T: IntoDimension, { index.into_dimension() } impl PartialEq for Dim -where I: PartialEq +where + I: PartialEq, { - fn eq(&self, rhs: &I) -> bool - { + fn eq(&self, rhs: &I) -> bool { self.index == *rhs } } impl fmt::Debug for Dim -where I: fmt::Debug +where + I: fmt::Debug, { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result - { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{:?}", self.index) } } diff --git a/src/dimension/dimension_trait.rs b/src/dimension/dimension_trait.rs index 373edb35..2bef009d 100644 --- a/src/dimension/dimension_trait.rs +++ b/src/dimension/dimension_trait.rs @@ -83,14 +83,12 @@ pub trait Dimension: fn into_pattern(self) -> Self::Pattern; /// Compute the size of the dimension (number of elements) - fn size(&self) -> usize - { + fn size(&self) -> usize { self.slice().iter().product() } /// Compute the size while checking for overflow. - fn size_checked(&self) -> Option - { + fn size_checked(&self) -> Option { self.slice() .iter() .try_fold(1_usize, |s, &a| s.checked_mul(a)) @@ -103,20 +101,17 @@ pub trait Dimension: fn slice_mut(&mut self) -> &mut [Ix]; /// Borrow as a read-only array view. - fn as_array_view(&self) -> ArrayView1<'_, Ix> - { + fn as_array_view(&self) -> ArrayView1<'_, Ix> { ArrayView1::from(self.slice()) } /// Borrow as a read-write array view. 
- fn as_array_view_mut(&mut self) -> ArrayViewMut1<'_, Ix> - { + fn as_array_view_mut(&mut self) -> ArrayViewMut1<'_, Ix> { ArrayViewMut1::from(self.slice_mut()) } #[doc(hidden)] - fn equal(&self, rhs: &Self) -> bool - { + fn equal(&self, rhs: &Self) -> bool { self.slice() == rhs.slice() } @@ -125,8 +120,7 @@ pub trait Dimension: /// If the array is non-empty, the strides result in contiguous layout; if /// the array is empty, the strides are all zeros. #[doc(hidden)] - fn default_strides(&self) -> Self - { + fn default_strides(&self) -> Self { // Compute default array strides // Shape (a, b, c) => Give strides (b * c, c, 1) let mut strides = Self::zeros(self.ndim()); @@ -151,8 +145,7 @@ pub trait Dimension: /// If the array is non-empty, the strides result in contiguous layout; if /// the array is empty, the strides are all zeros. #[doc(hidden)] - fn fortran_strides(&self) -> Self - { + fn fortran_strides(&self) -> Self { // Compute fortran array strides // Shape (a, b, c) => Give strides (1, a, a * b) let mut strides = Self::zeros(self.ndim()); @@ -182,8 +175,7 @@ pub trait Dimension: #[doc(hidden)] #[inline] - fn first_index(&self) -> Option - { + fn first_index(&self) -> Option { for ax in self.slice().iter() { if *ax == 0 { return None; @@ -197,8 +189,7 @@ pub trait Dimension: /// or None if there are no more. // FIXME: use &Self for index or even &mut? #[inline] - fn next_for(&self, mut index: Self) -> Option - { + fn next_for(&self, mut index: Self) -> Option { let mut done = false; for (&dim, ix) in zip(self.slice(), index.slice_mut()).rev() { *ix += 1; @@ -222,8 +213,7 @@ pub trait Dimension: /// /// Next in f-order #[inline] - fn next_for_f(&self, index: &mut Self) -> bool - { + fn next_for_f(&self, index: &mut Self) -> bool { let mut end_iteration = true; for (&dim, ix) in zip(self.slice(), index.slice_mut()) { *ix += 1; @@ -246,7 +236,8 @@ pub trait Dimension: /// Note: Returns `false` if any of the ndims don't match. #[doc(hidden)] fn strides_equivalent(&self, strides1: &Self, strides2: &D) -> bool - where D: Dimension + where + D: Dimension, { let shape_ndim = self.ndim(); shape_ndim == strides1.ndim() @@ -257,8 +248,7 @@ pub trait Dimension: #[doc(hidden)] /// Return stride offset for index. - fn stride_offset(index: &Self, strides: &Self) -> isize - { + fn stride_offset(index: &Self, strides: &Self) -> isize { let mut offset = 0; for (&i, &s) in izip!(index.slice(), strides.slice()) { offset += stride_offset(i, s); @@ -268,14 +258,12 @@ pub trait Dimension: #[doc(hidden)] /// Return stride offset for this dimension and index. - fn stride_offset_checked(&self, strides: &Self, index: &Self) -> Option - { + fn stride_offset_checked(&self, strides: &Self, index: &Self) -> Option { stride_offset_checked(self.slice(), strides.slice(), index.slice()) } #[doc(hidden)] - fn last_elem(&self) -> usize - { + fn last_elem(&self) -> usize { if self.ndim() == 0 { 0 } else { @@ -284,15 +272,13 @@ pub trait Dimension: } #[doc(hidden)] - fn set_last_elem(&mut self, i: usize) - { + fn set_last_elem(&mut self, i: usize) { let nd = self.ndim(); self.slice_mut()[nd - 1] = i; } #[doc(hidden)] - fn is_contiguous(dim: &Self, strides: &Self) -> bool - { + fn is_contiguous(dim: &Self, strides: &Self) -> bool { let defaults = dim.default_strides(); if strides.equal(&defaults) { return true; @@ -324,8 +310,7 @@ pub trait Dimension: /// /// Assumes that no stride value appears twice. 
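Editor's note: the stride formulas documented above (C order gives `(b*c, c, 1)` for shape `(a, b, c)`, Fortran order gives `(1, a, a*b)`) are easy to verify with a tiny standalone computation. This sketch covers only the non-empty case; as the docs note, an empty shape gets all-zero strides.

```rust
/// Row-major (C-order) strides for a non-empty shape.
fn default_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![0; shape.len()];
    let mut s = 1;
    for (st, &len) in strides.iter_mut().zip(shape).rev() {
        *st = s;
        s *= len;
    }
    strides
}

/// Column-major (Fortran-order) strides for a non-empty shape.
fn fortran_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![0; shape.len()];
    let mut s = 1;
    for (st, &len) in strides.iter_mut().zip(shape) {
        *st = s;
        s *= len;
    }
    strides
}

fn main() {
    assert_eq!(default_strides(&[2, 3, 4]), vec![12, 4, 1]); // (b*c, c, 1)
    assert_eq!(fortran_strides(&[2, 3, 4]), vec![1, 2, 6]);  // (1, a, a*b)
}
```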
#[doc(hidden)] - fn _fastest_varying_stride_order(&self) -> Self - { + fn _fastest_varying_stride_order(&self) -> Self { let mut indices = self.clone(); for (i, elt) in enumerate(indices.slice_mut()) { *elt = i; @@ -340,8 +325,7 @@ pub trait Dimension: /// Compute the minimum stride axis (absolute value), under the constraint /// that the length of the axis is > 1; #[doc(hidden)] - fn min_stride_axis(&self, strides: &Self) -> Axis - { + fn min_stride_axis(&self, strides: &Self) -> Axis { let n = match self.ndim() { 0 => panic!("min_stride_axis: Array must have ndim > 0"), 1 => return Axis(0), @@ -356,8 +340,7 @@ pub trait Dimension: /// Compute the maximum stride axis (absolute value), under the constraint /// that the length of the axis is > 1; #[doc(hidden)] - fn max_stride_axis(&self, strides: &Self) -> Axis - { + fn max_stride_axis(&self, strides: &Self) -> Axis { match self.ndim() { 0 => panic!("max_stride_axis: Array must have ndim > 0"), 1 => return Axis(0), @@ -370,14 +353,12 @@ pub trait Dimension: } /// Convert the dimensional into a dynamic dimensional (IxDyn). - fn into_dyn(self) -> IxDyn - { + fn into_dyn(self) -> IxDyn { IxDyn(self.slice()) } #[doc(hidden)] - fn from_dimension(d: &D2) -> Option - { + fn from_dimension(d: &D2) -> Option { let mut s = Self::default(); if s.ndim() == d.ndim() { for i in 0..d.ndim() { @@ -413,91 +394,76 @@ macro_rules! impl_insert_axis_array( ); ); -impl Dimension for Dim<[Ix; 0]> -{ +impl Dimension for Dim<[Ix; 0]> { const NDIM: Option = Some(0); type Pattern = (); type Smaller = Self; type Larger = Ix1; // empty product is 1 -> size is 1 #[inline] - fn ndim(&self) -> usize - { + fn ndim(&self) -> usize { 0 } #[inline] - fn slice(&self) -> &[Ix] - { + fn slice(&self) -> &[Ix] { &[] } #[inline] - fn slice_mut(&mut self) -> &mut [Ix] - { + fn slice_mut(&mut self) -> &mut [Ix] { &mut [] } #[inline] - fn _fastest_varying_stride_order(&self) -> Self - { + fn _fastest_varying_stride_order(&self) -> Self { Ix0() } #[inline] fn into_pattern(self) -> Self::Pattern {} #[inline] - fn zeros(ndim: usize) -> Self - { + fn zeros(ndim: usize) -> Self { assert_eq!(ndim, 0); Self::default() } #[inline] - fn next_for(&self, _index: Self) -> Option - { + fn next_for(&self, _index: Self) -> Option { None } impl_insert_axis_array!(0); #[inline] - fn try_remove_axis(&self, _ignore: Axis) -> Self::Smaller - { + fn try_remove_axis(&self, _ignore: Axis) -> Self::Smaller { *self } private_impl! 
{} } -impl Dimension for Dim<[Ix; 1]> -{ +impl Dimension for Dim<[Ix; 1]> { const NDIM: Option = Some(1); type Pattern = Ix; type Smaller = Ix0; type Larger = Ix2; #[inline] - fn ndim(&self) -> usize - { + fn ndim(&self) -> usize { 1 } #[inline] - fn slice(&self) -> &[Ix] - { + fn slice(&self) -> &[Ix] { self.ix() } #[inline] - fn slice_mut(&mut self) -> &mut [Ix] - { + fn slice_mut(&mut self) -> &mut [Ix] { self.ixm() } #[inline] - fn into_pattern(self) -> Self::Pattern - { + fn into_pattern(self) -> Self::Pattern { get!(&self, 0) } #[inline] - fn zeros(ndim: usize) -> Self - { + fn zeros(ndim: usize) -> Self { assert_eq!(ndim, 1); Self::default() } #[inline] - fn next_for(&self, mut index: Self) -> Option - { + fn next_for(&self, mut index: Self) -> Option { getm!(index, 0) += 1; if get!(&index, 0) < get!(self, 0) { Some(index) @@ -507,25 +473,21 @@ impl Dimension for Dim<[Ix; 1]> } #[inline] - fn equal(&self, rhs: &Self) -> bool - { + fn equal(&self, rhs: &Self) -> bool { get!(self, 0) == get!(rhs, 0) } #[inline] - fn size(&self) -> usize - { + fn size(&self) -> usize { get!(self, 0) } #[inline] - fn size_checked(&self) -> Option - { + fn size_checked(&self) -> Option { Some(get!(self, 0)) } #[inline] - fn default_strides(&self) -> Self - { + fn default_strides(&self) -> Self { if get!(self, 0) == 0 { Ix1(0) } else { @@ -534,26 +496,22 @@ impl Dimension for Dim<[Ix; 1]> } #[inline] - fn _fastest_varying_stride_order(&self) -> Self - { + fn _fastest_varying_stride_order(&self) -> Self { Ix1(0) } #[inline(always)] - fn min_stride_axis(&self, _: &Self) -> Axis - { + fn min_stride_axis(&self, _: &Self) -> Axis { Axis(0) } #[inline(always)] - fn max_stride_axis(&self, _: &Self) -> Axis - { + fn max_stride_axis(&self, _: &Self) -> Axis { Axis(0) } #[inline] - fn first_index(&self) -> Option - { + fn first_index(&self) -> Option { if get!(self, 0) != 0 { Some(Ix1(0)) } else { @@ -563,15 +521,13 @@ impl Dimension for Dim<[Ix; 1]> /// Self is an index, return the stride offset #[inline(always)] - fn stride_offset(index: &Self, stride: &Self) -> isize - { + fn stride_offset(index: &Self, stride: &Self) -> isize { stride_offset(get!(index, 0), get!(stride, 0)) } /// Return stride offset for this dimension and index. #[inline] - fn stride_offset_checked(&self, stride: &Self, index: &Self) -> Option - { + fn stride_offset_checked(&self, stride: &Self, index: &Self) -> Option { if get!(index, 0) < get!(self, 0) { Some(stride_offset(get!(index, 0), get!(stride, 0))) } else { @@ -580,13 +536,11 @@ impl Dimension for Dim<[Ix; 1]> } impl_insert_axis_array!(1); #[inline] - fn try_remove_axis(&self, axis: Axis) -> Self::Smaller - { + fn try_remove_axis(&self, axis: Axis) -> Self::Smaller { self.remove_axis(axis) } - fn from_dimension(d: &D2) -> Option - { + fn from_dimension(d: &D2) -> Option { if 1 == d.ndim() { Some(Ix1(d[0])) } else { @@ -596,41 +550,34 @@ impl Dimension for Dim<[Ix; 1]> private_impl! 
{} } -impl Dimension for Dim<[Ix; 2]> -{ +impl Dimension for Dim<[Ix; 2]> { const NDIM: Option = Some(2); type Pattern = (Ix, Ix); type Smaller = Ix1; type Larger = Ix3; #[inline] - fn ndim(&self) -> usize - { + fn ndim(&self) -> usize { 2 } #[inline] - fn into_pattern(self) -> Self::Pattern - { + fn into_pattern(self) -> Self::Pattern { self.ix().convert() } #[inline] - fn slice(&self) -> &[Ix] - { + fn slice(&self) -> &[Ix] { self.ix() } #[inline] - fn slice_mut(&mut self) -> &mut [Ix] - { + fn slice_mut(&mut self) -> &mut [Ix] { self.ixm() } #[inline] - fn zeros(ndim: usize) -> Self - { + fn zeros(ndim: usize) -> Self { assert_eq!(ndim, 2); Self::default() } #[inline] - fn next_for(&self, index: Self) -> Option - { + fn next_for(&self, index: Self) -> Option { let mut i = get!(&index, 0); let mut j = get!(&index, 1); let imax = get!(self, 0); @@ -647,40 +594,34 @@ impl Dimension for Dim<[Ix; 2]> } #[inline] - fn equal(&self, rhs: &Self) -> bool - { + fn equal(&self, rhs: &Self) -> bool { get!(self, 0) == get!(rhs, 0) && get!(self, 1) == get!(rhs, 1) } #[inline] - fn size(&self) -> usize - { + fn size(&self) -> usize { get!(self, 0) * get!(self, 1) } #[inline] - fn size_checked(&self) -> Option - { + fn size_checked(&self) -> Option { let m = get!(self, 0); let n = get!(self, 1); m.checked_mul(n) } #[inline] - fn last_elem(&self) -> usize - { + fn last_elem(&self) -> usize { get!(self, 1) } #[inline] - fn set_last_elem(&mut self, i: usize) - { + fn set_last_elem(&mut self, i: usize) { getm!(self, 1) = i; } #[inline] - fn default_strides(&self) -> Self - { + fn default_strides(&self) -> Self { let m = get!(self, 0); let n = get!(self, 1); if m == 0 || n == 0 { @@ -690,8 +631,7 @@ impl Dimension for Dim<[Ix; 2]> } } #[inline] - fn fortran_strides(&self) -> Self - { + fn fortran_strides(&self) -> Self { let m = get!(self, 0); let n = get!(self, 1); if m == 0 || n == 0 { @@ -702,8 +642,7 @@ impl Dimension for Dim<[Ix; 2]> } #[inline] - fn _fastest_varying_stride_order(&self) -> Self - { + fn _fastest_varying_stride_order(&self) -> Self { if (get!(self, 0) as Ixs).abs() <= (get!(self, 1) as Ixs).abs() { Ix2(0, 1) } else { @@ -712,8 +651,7 @@ impl Dimension for Dim<[Ix; 2]> } #[inline] - fn min_stride_axis(&self, strides: &Self) -> Axis - { + fn min_stride_axis(&self, strides: &Self) -> Axis { let s = get!(strides, 0) as Ixs; let t = get!(strides, 1) as Ixs; if s.abs() < t.abs() { @@ -724,8 +662,7 @@ impl Dimension for Dim<[Ix; 2]> } #[inline] - fn first_index(&self) -> Option - { + fn first_index(&self) -> Option { let m = get!(self, 0); let n = get!(self, 1); if m != 0 && n != 0 { @@ -737,8 +674,7 @@ impl Dimension for Dim<[Ix; 2]> /// Self is an index, return the stride offset #[inline(always)] - fn stride_offset(index: &Self, strides: &Self) -> isize - { + fn stride_offset(index: &Self, strides: &Self) -> isize { let i = get!(index, 0); let j = get!(index, 1); let s = get!(strides, 0); @@ -748,8 +684,7 @@ impl Dimension for Dim<[Ix; 2]> /// Return stride offset for this dimension and index. #[inline] - fn stride_offset_checked(&self, strides: &Self, index: &Self) -> Option - { + fn stride_offset_checked(&self, strides: &Self, index: &Self) -> Option { let m = get!(self, 0); let n = get!(self, 1); let i = get!(index, 0); @@ -764,43 +699,36 @@ impl Dimension for Dim<[Ix; 2]> } impl_insert_axis_array!(2); #[inline] - fn try_remove_axis(&self, axis: Axis) -> Self::Smaller - { + fn try_remove_axis(&self, axis: Axis) -> Self::Smaller { self.remove_axis(axis) } private_impl! 
{} } -impl Dimension for Dim<[Ix; 3]> -{ +impl Dimension for Dim<[Ix; 3]> { const NDIM: Option = Some(3); type Pattern = (Ix, Ix, Ix); type Smaller = Ix2; type Larger = Ix4; #[inline] - fn ndim(&self) -> usize - { + fn ndim(&self) -> usize { 3 } #[inline] - fn into_pattern(self) -> Self::Pattern - { + fn into_pattern(self) -> Self::Pattern { self.ix().convert() } #[inline] - fn slice(&self) -> &[Ix] - { + fn slice(&self) -> &[Ix] { self.ix() } #[inline] - fn slice_mut(&mut self) -> &mut [Ix] - { + fn slice_mut(&mut self) -> &mut [Ix] { self.ixm() } #[inline] - fn size(&self) -> usize - { + fn size(&self) -> usize { let m = get!(self, 0); let n = get!(self, 1); let o = get!(self, 2); @@ -808,15 +736,13 @@ impl Dimension for Dim<[Ix; 3]> } #[inline] - fn zeros(ndim: usize) -> Self - { + fn zeros(ndim: usize) -> Self { assert_eq!(ndim, 3); Self::default() } #[inline] - fn next_for(&self, index: Self) -> Option - { + fn next_for(&self, index: Self) -> Option { let mut i = get!(&index, 0); let mut j = get!(&index, 1); let mut k = get!(&index, 2); @@ -840,8 +766,7 @@ impl Dimension for Dim<[Ix; 3]> /// Self is an index, return the stride offset #[inline] - fn stride_offset(index: &Self, strides: &Self) -> isize - { + fn stride_offset(index: &Self, strides: &Self) -> isize { let i = get!(index, 0); let j = get!(index, 1); let k = get!(index, 2); @@ -853,8 +778,7 @@ impl Dimension for Dim<[Ix; 3]> /// Return stride offset for this dimension and index. #[inline] - fn stride_offset_checked(&self, strides: &Self, index: &Self) -> Option - { + fn stride_offset_checked(&self, strides: &Self, index: &Self) -> Option { let m = get!(self, 0); let n = get!(self, 1); let l = get!(self, 2); @@ -872,8 +796,7 @@ impl Dimension for Dim<[Ix; 3]> } #[inline] - fn _fastest_varying_stride_order(&self) -> Self - { + fn _fastest_varying_stride_order(&self) -> Self { let mut stride = *self; let mut order = Ix3(0, 1, 2); macro_rules! swap { @@ -895,8 +818,7 @@ impl Dimension for Dim<[Ix; 3]> } impl_insert_axis_array!(3); #[inline] - fn try_remove_axis(&self, axis: Axis) -> Self::Smaller - { + fn try_remove_axis(&self, axis: Axis) -> Self::Smaller { self.remove_axis(axis) } private_impl! {} @@ -953,49 +875,41 @@ large_dim!(6, Ix6, (Ix, Ix, Ix, Ix, Ix, Ix), IxDyn, { /// IxDyn is a "dynamic" index, pretty hard to use when indexing, /// and memory wasteful, but it allows an arbitrary and dynamic number of axes. 
-impl Dimension for IxDyn -{ +impl Dimension for IxDyn { const NDIM: Option = None; type Pattern = Self; type Smaller = Self; type Larger = Self; #[inline] - fn ndim(&self) -> usize - { + fn ndim(&self) -> usize { self.ix().len() } #[inline] - fn slice(&self) -> &[Ix] - { + fn slice(&self) -> &[Ix] { self.ix() } #[inline] - fn slice_mut(&mut self) -> &mut [Ix] - { + fn slice_mut(&mut self) -> &mut [Ix] { self.ixm() } #[inline] - fn into_pattern(self) -> Self::Pattern - { + fn into_pattern(self) -> Self::Pattern { self } #[inline] - fn zeros(ndim: usize) -> Self - { + fn zeros(ndim: usize) -> Self { IxDyn::zeros(ndim) } #[inline] - fn insert_axis(&self, axis: Axis) -> Self::Larger - { + fn insert_axis(&self, axis: Axis) -> Self::Larger { debug_assert!(axis.index() <= self.ndim()); Dim::new(self.ix().insert(axis.index())) } #[inline] - fn try_remove_axis(&self, axis: Axis) -> Self::Smaller - { + fn try_remove_axis(&self, axis: Axis) -> Self::Smaller { if self.ndim() > 0 { self.remove_axis(axis) } else { @@ -1003,32 +917,26 @@ impl Dimension for IxDyn } } - fn from_dimension(d: &D2) -> Option - { + fn from_dimension(d: &D2) -> Option { Some(IxDyn(d.slice())) } - fn into_dyn(self) -> IxDyn - { + fn into_dyn(self) -> IxDyn { self } private_impl! {} } -impl Index for Dim -{ +impl Index for Dim { type Output = >::Output; - fn index(&self, index: usize) -> &Self::Output - { + fn index(&self, index: usize) -> &Self::Output { &self.ix()[index] } } -impl IndexMut for Dim -{ - fn index_mut(&mut self, index: usize) -> &mut Self::Output - { +impl IndexMut for Dim { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { &mut self.ixm()[index] } } diff --git a/src/dimension/dynindeximpl.rs b/src/dimension/dynindeximpl.rs index 60aeacd8..cd30abc6 100644 --- a/src/dimension/dynindeximpl.rs +++ b/src/dimension/dynindeximpl.rs @@ -10,17 +10,14 @@ const CAP: usize = 4; /// T is usize or isize #[derive(Debug)] -enum IxDynRepr -{ +enum IxDynRepr { Inline(u32, [T; CAP]), Alloc(Box<[T]>), } -impl Deref for IxDynRepr -{ +impl Deref for IxDynRepr { type Target = [T]; - fn deref(&self) -> &[T] - { + fn deref(&self) -> &[T] { match *self { IxDynRepr::Inline(len, ref ar) => { debug_assert!(len as usize <= ar.len()); @@ -31,10 +28,8 @@ impl Deref for IxDynRepr } } -impl DerefMut for IxDynRepr -{ - fn deref_mut(&mut self) -> &mut [T] - { +impl DerefMut for IxDynRepr { + fn deref_mut(&mut self) -> &mut [T] { match *self { IxDynRepr::Inline(len, ref mut ar) => { debug_assert!(len as usize <= ar.len()); @@ -46,20 +41,16 @@ impl DerefMut for IxDynRepr } /// The default is equivalent to `Self::from(&[0])`. 
-impl Default for IxDynRepr -{ - fn default() -> Self - { +impl Default for IxDynRepr { + fn default() -> Self { Self::copy_from(&[0]) } } use num_traits::Zero; -impl IxDynRepr -{ - pub fn copy_from(x: &[T]) -> Self - { +impl IxDynRepr { + pub fn copy_from(x: &[T]) -> Self { if x.len() <= CAP { let mut arr = [T::zero(); CAP]; arr[..x.len()].copy_from_slice(x); @@ -70,11 +61,9 @@ impl IxDynRepr } } -impl IxDynRepr -{ +impl IxDynRepr { // make an Inline or Alloc version as appropriate - fn from_vec_auto(v: Vec) -> Self - { + fn from_vec_auto(v: Vec) -> Self { if v.len() <= CAP { Self::copy_from(&v) } else { @@ -83,23 +72,18 @@ impl IxDynRepr } } -impl IxDynRepr -{ - fn from_vec(v: Vec) -> Self - { +impl IxDynRepr { + fn from_vec(v: Vec) -> Self { IxDynRepr::Alloc(v.into_boxed_slice()) } - fn from(x: &[T]) -> Self - { + fn from(x: &[T]) -> Self { Self::from_vec(x.to_vec()) } } -impl Clone for IxDynRepr -{ - fn clone(&self) -> Self - { +impl Clone for IxDynRepr { + fn clone(&self) -> Self { match *self { IxDynRepr::Inline(len, arr) => IxDynRepr::Inline(len, arr), _ => Self::from(&self[..]), @@ -109,25 +93,22 @@ impl Clone for IxDynRepr impl Eq for IxDynRepr {} -impl PartialEq for IxDynRepr -{ - fn eq(&self, rhs: &Self) -> bool - { +impl PartialEq for IxDynRepr { + fn eq(&self, rhs: &Self) -> bool { match (self, rhs) { - (&IxDynRepr::Inline(slen, ref sarr), &IxDynRepr::Inline(rlen, ref rarr)) => + (&IxDynRepr::Inline(slen, ref sarr), &IxDynRepr::Inline(rlen, ref rarr)) => { slen == rlen && (0..CAP) .filter(|&i| i < slen as usize) - .all(|i| sarr[i] == rarr[i]), + .all(|i| sarr[i] == rarr[i]) + } _ => self[..] == rhs[..], } } } -impl Hash for IxDynRepr -{ - fn hash(&self, state: &mut H) - { +impl Hash for IxDynRepr { + fn hash(&self, state: &mut H) { Hash::hash(&self[..], state) } } @@ -140,10 +121,8 @@ impl Hash for IxDynRepr #[derive(Debug, Clone, PartialEq, Eq, Hash, Default)] pub struct IxDynImpl(IxDynRepr); -impl IxDynImpl -{ - pub(crate) fn insert(&self, i: usize) -> Self - { +impl IxDynImpl { + pub(crate) fn insert(&self, i: usize) -> Self { let len = self.len(); debug_assert!(i <= len); IxDynImpl(if len < CAP { @@ -160,8 +139,7 @@ impl IxDynImpl }) } - fn remove(&self, i: usize) -> Self - { + fn remove(&self, i: usize) -> Self { IxDynImpl(match self.0 { IxDynRepr::Inline(0, _) => IxDynRepr::Inline(0, [0; CAP]), IxDynRepr::Inline(1, _) => IxDynRepr::Inline(0, [0; CAP]), @@ -182,88 +160,74 @@ impl IxDynImpl } } -impl<'a> From<&'a [Ix]> for IxDynImpl -{ +impl<'a> From<&'a [Ix]> for IxDynImpl { #[inline] - fn from(ix: &'a [Ix]) -> Self - { + fn from(ix: &'a [Ix]) -> Self { IxDynImpl(IxDynRepr::copy_from(ix)) } } -impl From> for IxDynImpl -{ +impl From> for IxDynImpl { #[inline] - fn from(ix: Vec) -> Self - { + fn from(ix: Vec) -> Self { IxDynImpl(IxDynRepr::from_vec_auto(ix)) } } impl Index for IxDynImpl -where [Ix]: Index +where + [Ix]: Index, { type Output = <[Ix] as Index>::Output; - fn index(&self, index: J) -> &Self::Output - { + fn index(&self, index: J) -> &Self::Output { &self.0[index] } } impl IndexMut for IxDynImpl -where [Ix]: IndexMut +where + [Ix]: IndexMut, { - fn index_mut(&mut self, index: J) -> &mut Self::Output - { + fn index_mut(&mut self, index: J) -> &mut Self::Output { &mut self.0[index] } } -impl Deref for IxDynImpl -{ +impl Deref for IxDynImpl { type Target = [Ix]; #[inline] - fn deref(&self) -> &[Ix] - { + fn deref(&self) -> &[Ix] { &self.0 } } -impl DerefMut for IxDynImpl -{ +impl DerefMut for IxDynImpl { #[inline] - fn deref_mut(&mut self) -> &mut [Ix] - { + fn 
deref_mut(&mut self) -> &mut [Ix] { &mut self.0 } } -impl<'a> IntoIterator for &'a IxDynImpl -{ +impl<'a> IntoIterator for &'a IxDynImpl { type Item = &'a Ix; type IntoIter = <&'a [Ix] as IntoIterator>::IntoIter; #[inline] - fn into_iter(self) -> Self::IntoIter - { + fn into_iter(self) -> Self::IntoIter { self[..].iter() } } -impl RemoveAxis for Dim -{ - fn remove_axis(&self, axis: Axis) -> Self - { +impl RemoveAxis for Dim { + fn remove_axis(&self, axis: Axis) -> Self { debug_assert!(axis.index() < self.ndim()); Dim::new(self.ix().remove(axis.index())) } } -impl IxDyn -{ +impl IxDyn { /// Create a new dimension value with `n` axes, all zeros #[inline] - pub fn zeros(n: usize) -> IxDyn - { + pub fn zeros(n: usize) -> IxDyn { const ZEROS: &[usize] = &[0; 4]; if n <= ZEROS.len() { Dim(&ZEROS[..n]) diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index 4731da1b..a0bf89bc 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -45,8 +45,7 @@ mod sequence; /// Calculate offset from `Ix` stride converting sign properly #[inline(always)] -pub fn stride_offset(n: Ix, stride: Ix) -> isize -{ +pub fn stride_offset(n: Ix, stride: Ix) -> isize { (n as isize) * (stride as Ixs) } @@ -55,8 +54,7 @@ pub fn stride_offset(n: Ix, stride: Ix) -> isize /// There is overlap if, when iterating through the dimensions in order of /// increasing stride, the current stride is less than or equal to the maximum /// possible offset along the preceding axes. (Axes of length ≤1 are ignored.) -pub(crate) fn dim_stride_overlap(dim: &D, strides: &D) -> bool -{ +pub(crate) fn dim_stride_overlap(dim: &D, strides: &D) -> bool { let order = strides._fastest_varying_stride_order(); let mut sum_prev_offsets = 0; for &index in order.slice() { @@ -85,8 +83,7 @@ pub(crate) fn dim_stride_overlap(dim: &D, strides: &D) -> bool /// are met to construct an array from the data buffer, `dim`, and `strides`. /// (The data buffer being a slice or `Vec` guarantees that it contains no more /// than `isize::MAX` bytes.) -pub fn size_of_shape_checked(dim: &D) -> Result -{ +pub fn size_of_shape_checked(dim: &D) -> Result { let size_nonzero = dim .slice() .iter() @@ -107,8 +104,7 @@ pub fn size_of_shape_checked(dim: &D) -> Result /// The strides must not allow any element to be referenced by two different indices. /// #[derive(Copy, Clone, PartialEq)] -pub(crate) enum CanIndexCheckMode -{ +pub(crate) enum CanIndexCheckMode { /// Owned or mutable: No aliasing OwnedMutable, /// Aliasing @@ -141,8 +137,7 @@ pub(crate) enum CanIndexCheckMode /// accessible by moving along all axes does not exceed `isize::MAX`. pub(crate) fn can_index_slice_with_strides( data: &[A], dim: &D, strides: &Strides, mode: CanIndexCheckMode, -) -> Result<(), ShapeError> -{ +) -> Result<(), ShapeError> { if let Strides::Custom(strides) = strides { can_index_slice(data, dim, strides, mode) } else { @@ -151,8 +146,7 @@ pub(crate) fn can_index_slice_with_strides( } } -pub(crate) fn can_index_slice_not_custom(data_len: usize, dim: &D) -> Result<(), ShapeError> -{ +pub(crate) fn can_index_slice_not_custom(data_len: usize, dim: &D) -> Result<(), ShapeError> { // Condition 1. let len = size_of_shape_checked(dim)?; // Condition 2. @@ -177,13 +171,15 @@ pub(crate) fn can_index_slice_not_custom(data_len: usize, dim: &D) /// also implies that the length of any individual axis does not exceed /// `isize::MAX`.) 
pub fn max_abs_offset_check_overflow(dim: &D, strides: &D) -> Result -where D: Dimension +where + D: Dimension, { max_abs_offset_check_overflow_impl(mem::size_of::(), dim, strides) } fn max_abs_offset_check_overflow_impl(elem_size: usize, dim: &D, strides: &D) -> Result -where D: Dimension +where + D: Dimension, { // Condition 1. if dim.ndim() != strides.ndim() { @@ -257,8 +253,7 @@ where D: Dimension /// negative strides are correctly handled.) pub(crate) fn can_index_slice( data: &[A], dim: &D, strides: &D, mode: CanIndexCheckMode, -) -> Result<(), ShapeError> -{ +) -> Result<(), ShapeError> { // Check conditions 1 and 2 and calculate `max_offset`. let max_offset = max_abs_offset_check_overflow::(dim, strides)?; can_index_slice_impl(max_offset, data.len(), dim, strides, mode) @@ -266,8 +261,7 @@ pub(crate) fn can_index_slice( fn can_index_slice_impl( max_offset: usize, data_len: usize, dim: &D, strides: &D, mode: CanIndexCheckMode, -) -> Result<(), ShapeError> -{ +) -> Result<(), ShapeError> { // Check condition 3. let is_empty = dim.slice().contains(&0); if is_empty && max_offset > data_len { @@ -287,8 +281,7 @@ fn can_index_slice_impl( /// Stride offset checked general version (slices) #[inline] -pub fn stride_offset_checked(dim: &[Ix], strides: &[Ix], index: &[Ix]) -> Option -{ +pub fn stride_offset_checked(dim: &[Ix], strides: &[Ix], index: &[Ix]) -> Option { if index.len() != dim.len() { return None; } @@ -304,7 +297,8 @@ pub fn stride_offset_checked(dim: &[Ix], strides: &[Ix], index: &[Ix]) -> Option /// Checks if strides are non-negative. pub fn strides_non_negative(strides: &D) -> Result<(), ShapeError> -where D: Dimension +where + D: Dimension, { for &stride in strides.slice() { if (stride as isize) < 0 { @@ -315,8 +309,7 @@ where D: Dimension } /// Implementation-specific extensions to `Dimension` -pub trait DimensionExt -{ +pub trait DimensionExt { // note: many extensions go in the main trait if they need to be special- // cased per dimension /// Get the dimension at `axis`. @@ -333,32 +326,28 @@ pub trait DimensionExt } impl DimensionExt for D -where D: Dimension +where + D: Dimension, { #[inline] - fn axis(&self, axis: Axis) -> Ix - { + fn axis(&self, axis: Axis) -> Ix { self[axis.index()] } #[inline] - fn set_axis(&mut self, axis: Axis, value: Ix) - { + fn set_axis(&mut self, axis: Axis, value: Ix) { self[axis.index()] = value; } } -impl DimensionExt for [Ix] -{ +impl DimensionExt for [Ix] { #[inline] - fn axis(&self, axis: Axis) -> Ix - { + fn axis(&self, axis: Axis) -> Ix { self[axis.index()] } #[inline] - fn set_axis(&mut self, axis: Axis, value: Ix) - { + fn set_axis(&mut self, axis: Axis, value: Ix) { self[axis.index()] = value; } } @@ -369,8 +358,7 @@ impl DimensionExt for [Ix] /// **Panics** if `index` is larger than the size of the axis #[track_caller] // FIXME: Move to Dimension trait -pub fn do_collapse_axis(dims: &mut D, strides: &D, axis: usize, index: usize) -> isize -{ +pub fn do_collapse_axis(dims: &mut D, strides: &D, axis: usize, index: usize) -> isize { let dim = dims.slice()[axis]; let stride = strides.slice()[axis]; ndassert!( @@ -387,8 +375,7 @@ pub fn do_collapse_axis(dims: &mut D, strides: &D, axis: usize, in /// Compute the equivalent unsigned index given the axis length and signed index. 
#[inline] -pub fn abs_index(len: Ix, index: Ixs) -> Ix -{ +pub fn abs_index(len: Ix, index: Ixs) -> Ix { if index < 0 { len - (-index as Ix) } else { @@ -402,34 +389,22 @@ pub fn abs_index(len: Ix, index: Ixs) -> Ix /// /// **Panics** if stride is 0 or if any index is out of bounds. #[track_caller] -fn to_abs_slice(axis_len: usize, slice: Slice) -> (usize, usize, isize) -{ +fn to_abs_slice(axis_len: usize, slice: Slice) -> (usize, usize, isize) { let Slice { start, end, step } = slice; let start = abs_index(axis_len, start); let mut end = abs_index(axis_len, end.unwrap_or(axis_len as isize)); if end < start { end = start; } - ndassert!( - start <= axis_len, - "Slice begin {} is past end of axis of length {}", - start, - axis_len, - ); - ndassert!( - end <= axis_len, - "Slice end {} is past end of axis of length {}", - end, - axis_len, - ); + ndassert!(start <= axis_len, "Slice begin {} is past end of axis of length {}", start, axis_len,); + ndassert!(end <= axis_len, "Slice end {} is past end of axis of length {}", end, axis_len,); ndassert!(step != 0, "Slice stride must not be zero"); (start, end, step) } /// This function computes the offset from the lowest address element to the /// logically first element. -pub fn offset_from_low_addr_ptr_to_logical_ptr(dim: &D, strides: &D) -> usize -{ +pub fn offset_from_low_addr_ptr_to_logical_ptr(dim: &D, strides: &D) -> usize { let offset = izip!(dim.slice(), strides.slice()).fold(0, |_offset, (&d, &s)| { let s = s as isize; if s < 0 && d > 1 { @@ -446,8 +421,7 @@ pub fn offset_from_low_addr_ptr_to_logical_ptr(dim: &D, strides: & /// /// **Panics** if stride is 0 or if any index is out of bounds. #[track_caller] -pub fn do_slice(dim: &mut usize, stride: &mut usize, slice: Slice) -> isize -{ +pub fn do_slice(dim: &mut usize, stride: &mut usize, slice: Slice) -> isize { let (start, end, step) = to_abs_slice(*dim, slice); let m = end - start; @@ -500,8 +474,7 @@ pub fn do_slice(dim: &mut usize, stride: &mut usize, slice: Slice) -> isize /// nonnegative. /// /// See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm -fn extended_gcd(a: isize, b: isize) -> (isize, (isize, isize)) -{ +fn extended_gcd(a: isize, b: isize) -> (isize, (isize, isize)) { if a == 0 { (b.abs(), (0, b.signum())) } else if b == 0 { @@ -537,8 +510,7 @@ fn extended_gcd(a: isize, b: isize) -> (isize, (isize, isize)) /// /// See https://en.wikipedia.org/wiki/Diophantine_equation#One_equation /// and https://math.stackexchange.com/questions/1656120#1656138 -fn solve_linear_diophantine_eq(a: isize, b: isize, c: isize) -> Option<(isize, isize)> -{ +fn solve_linear_diophantine_eq(a: isize, b: isize, c: isize) -> Option<(isize, isize)> { debug_assert_ne!(a, 0); debug_assert_ne!(b, 0); let (g, (u, _)) = extended_gcd(a, b); @@ -556,8 +528,7 @@ fn solve_linear_diophantine_eq(a: isize, b: isize, c: isize) -> Option<(isize, i /// consecutive elements (the sign is irrelevant). /// /// **Note** `step1` and `step2` must be nonzero. -fn arith_seq_intersect((min1, max1, step1): (isize, isize, isize), (min2, max2, step2): (isize, isize, isize)) -> bool -{ +fn arith_seq_intersect((min1, max1, step1): (isize, isize, isize), (min2, max2, step2): (isize, isize, isize)) -> bool { debug_assert!(max1 >= min1); debug_assert!(max2 >= min2); debug_assert_eq!((max1 - min1) % step1, 0); @@ -613,8 +584,7 @@ fn arith_seq_intersect((min1, max1, step1): (isize, isize, isize), (min2, max2, /// Returns the minimum and maximum values of the indices (inclusive). 
/// /// If the slice is empty, then returns `None`, otherwise returns `Some((min, max))`. -fn slice_min_max(axis_len: usize, slice: Slice) -> Option<(usize, usize)> -{ +fn slice_min_max(axis_len: usize, slice: Slice) -> Option<(usize, usize)> { let (start, end, step) = to_abs_slice(axis_len, slice); if start == end { None @@ -626,8 +596,7 @@ fn slice_min_max(axis_len: usize, slice: Slice) -> Option<(usize, usize)> } /// Returns `true` iff the slices intersect. -pub fn slices_intersect(dim: &D, indices1: impl SliceArg, indices2: impl SliceArg) -> bool -{ +pub fn slices_intersect(dim: &D, indices1: impl SliceArg, indices2: impl SliceArg) -> bool { debug_assert_eq!(indices1.in_ndim(), indices2.in_ndim()); for (&axis_len, &si1, &si2) in izip!( dim.slice(), @@ -684,8 +653,7 @@ pub fn slices_intersect(dim: &D, indices1: impl SliceArg, indic true } -pub(crate) fn is_layout_c(dim: &D, strides: &D) -> bool -{ +pub(crate) fn is_layout_c(dim: &D, strides: &D) -> bool { if let Some(1) = D::NDIM { return strides[0] == 1 || dim[0] <= 1; } @@ -710,8 +678,7 @@ pub(crate) fn is_layout_c(dim: &D, strides: &D) -> bool true } -pub(crate) fn is_layout_f(dim: &D, strides: &D) -> bool -{ +pub(crate) fn is_layout_f(dim: &D, strides: &D) -> bool { if let Some(1) = D::NDIM { return strides[0] == 1 || dim[0] <= 1; } @@ -737,7 +704,8 @@ pub(crate) fn is_layout_f(dim: &D, strides: &D) -> bool } pub fn merge_axes(dim: &mut D, strides: &mut D, take: Axis, into: Axis) -> bool -where D: Dimension +where + D: Dimension, { let into_len = dim.axis(into); let into_stride = strides.axis(into) as isize; @@ -765,16 +733,18 @@ where D: Dimension /// Move the axis which has the smallest absolute stride and a length /// greater than one to be the last axis. pub fn move_min_stride_axis_to_last(dim: &mut D, strides: &mut D) -where D: Dimension +where + D: Dimension, { debug_assert_eq!(dim.ndim(), strides.ndim()); match dim.ndim() { 0 | 1 => {} - 2 => + 2 => { if dim[1] <= 1 || dim[0] > 1 && (strides[0] as isize).abs() < (strides[1] as isize).abs() { dim.slice_mut().swap(0, 1); strides.slice_mut().swap(0, 1); - }, + } + } n => { if let Some(min_stride_axis) = (0..n) .filter(|&ax| dim[ax] > 1) @@ -789,19 +759,10 @@ where D: Dimension } #[cfg(test)] -mod test -{ +mod test { use super::{ - arith_seq_intersect, - can_index_slice, - can_index_slice_not_custom, - extended_gcd, - max_abs_offset_check_overflow, - slice_min_max, - slices_intersect, - solve_linear_diophantine_eq, - CanIndexCheckMode, - IntoDimension, + arith_seq_intersect, can_index_slice, can_index_slice_not_custom, extended_gcd, max_abs_offset_check_overflow, + slice_min_max, slices_intersect, solve_linear_diophantine_eq, CanIndexCheckMode, IntoDimension, }; use crate::error::{from_kind, ErrorKind}; use crate::slice::Slice; @@ -810,8 +771,7 @@ mod test use quickcheck::{quickcheck, TestResult}; #[test] - fn slice_indexing_uncommon_strides() - { + fn slice_indexing_uncommon_strides() { let v: alloc::vec::Vec<_> = (0..12).collect(); let dim = (2, 3, 2).into_dimension(); let strides = (1, 2, 6).into_dimension(); @@ -825,8 +785,7 @@ mod test } #[test] - fn overlapping_strides_dim() - { + fn overlapping_strides_dim() { let dim = (2, 3, 2).into_dimension(); let strides = (5, 2, 1).into_dimension(); assert!(super::dim_stride_overlap(&dim, &strides)); @@ -848,8 +807,7 @@ mod test } #[test] - fn max_abs_offset_check_overflow_examples() - { + fn max_abs_offset_check_overflow_examples() { let dim = (1, isize::MAX as usize, 1).into_dimension(); let strides = (1, 1, 1).into_dimension(); 
max_abs_offset_check_overflow::(&dim, &strides).unwrap(); @@ -865,15 +823,13 @@ mod test } #[test] - fn can_index_slice_ix0() - { + fn can_index_slice_ix0() { can_index_slice::(&[1], &Ix0(), &Ix0(), CanIndexCheckMode::OwnedMutable).unwrap(); can_index_slice::(&[], &Ix0(), &Ix0(), CanIndexCheckMode::OwnedMutable).unwrap_err(); } #[test] - fn can_index_slice_ix1() - { + fn can_index_slice_ix1() { let mode = CanIndexCheckMode::OwnedMutable; can_index_slice::(&[], &Ix1(0), &Ix1(0), mode).unwrap(); can_index_slice::(&[], &Ix1(0), &Ix1(1), mode).unwrap(); @@ -889,8 +845,7 @@ mod test } #[test] - fn can_index_slice_ix2() - { + fn can_index_slice_ix2() { let mode = CanIndexCheckMode::OwnedMutable; can_index_slice::(&[], &Ix2(0, 0), &Ix2(0, 0), mode).unwrap(); can_index_slice::(&[], &Ix2(0, 0), &Ix2(2, 1), mode).unwrap(); @@ -910,8 +865,7 @@ mod test } #[test] - fn can_index_slice_ix3() - { + fn can_index_slice_ix3() { let mode = CanIndexCheckMode::OwnedMutable; can_index_slice::(&[], &Ix3(0, 0, 1), &Ix3(2, 1, 3), mode).unwrap(); can_index_slice::(&[], &Ix3(1, 1, 1), &Ix3(2, 1, 3), mode).unwrap_err(); @@ -921,8 +875,7 @@ mod test } #[test] - fn can_index_slice_zero_size_elem() - { + fn can_index_slice_zero_size_elem() { let mode = CanIndexCheckMode::OwnedMutable; can_index_slice::<(), _>(&[], &Ix1(0), &Ix1(1), mode).unwrap(); can_index_slice::<(), _>(&[()], &Ix1(1), &Ix1(1), mode).unwrap(); @@ -973,8 +926,7 @@ mod test } #[test] - fn extended_gcd_zero() - { + fn extended_gcd_zero() { assert_eq!(extended_gcd(0, 0), (0, (0, 0))); assert_eq!(extended_gcd(0, 5), (5, (0, 1))); assert_eq!(extended_gcd(5, 0), (5, (1, 0))); @@ -1065,8 +1017,7 @@ mod test } #[test] - fn slice_min_max_empty() - { + fn slice_min_max_empty() { assert_eq!(slice_min_max(0, Slice::new(0, None, 3)), None); assert_eq!(slice_min_max(10, Slice::new(1, Some(1), 3)), None); assert_eq!(slice_min_max(10, Slice::new(-1, Some(-1), 3)), None); @@ -1075,8 +1026,7 @@ mod test } #[test] - fn slice_min_max_pos_step() - { + fn slice_min_max_pos_step() { assert_eq!(slice_min_max(10, Slice::new(1, Some(8), 3)), Some((1, 7))); assert_eq!(slice_min_max(10, Slice::new(1, Some(9), 3)), Some((1, 7))); assert_eq!(slice_min_max(10, Slice::new(-9, Some(8), 3)), Some((1, 7))); @@ -1092,22 +1042,15 @@ mod test } #[test] - fn slice_min_max_neg_step() - { + fn slice_min_max_neg_step() { assert_eq!(slice_min_max(10, Slice::new(1, Some(8), -3)), Some((1, 7))); assert_eq!(slice_min_max(10, Slice::new(2, Some(8), -3)), Some((4, 7))); assert_eq!(slice_min_max(10, Slice::new(-9, Some(8), -3)), Some((1, 7))); assert_eq!(slice_min_max(10, Slice::new(-8, Some(8), -3)), Some((4, 7))); assert_eq!(slice_min_max(10, Slice::new(1, Some(-2), -3)), Some((1, 7))); assert_eq!(slice_min_max(10, Slice::new(2, Some(-2), -3)), Some((4, 7))); - assert_eq!( - slice_min_max(10, Slice::new(-9, Some(-2), -3)), - Some((1, 7)) - ); - assert_eq!( - slice_min_max(10, Slice::new(-8, Some(-2), -3)), - Some((4, 7)) - ); + assert_eq!(slice_min_max(10, Slice::new(-9, Some(-2), -3)), Some((1, 7))); + assert_eq!(slice_min_max(10, Slice::new(-8, Some(-2), -3)), Some((4, 7))); assert_eq!(slice_min_max(9, Slice::new(2, None, -3)), Some((2, 8))); assert_eq!(slice_min_max(9, Slice::new(-7, None, -3)), Some((2, 8))); assert_eq!(slice_min_max(9, Slice::new(3, None, -3)), Some((5, 8))); @@ -1115,48 +1058,18 @@ mod test } #[test] - fn slices_intersect_true() - { - assert!(slices_intersect( - &Dim([4, 5]), - s![NewAxis, .., NewAxis, ..], - s![.., NewAxis, .., NewAxis] - )); - assert!(slices_intersect( - 
&Dim([4, 5]), - s![NewAxis, 0, ..], - s![0, ..] - )); - assert!(slices_intersect( - &Dim([4, 5]), - s![..;2, ..], - s![..;3, NewAxis, ..] - )); - assert!(slices_intersect( - &Dim([4, 5]), - s![.., ..;2], - s![.., 1..;3, NewAxis] - )); + fn slices_intersect_true() { + assert!(slices_intersect(&Dim([4, 5]), s![NewAxis, .., NewAxis, ..], s![.., NewAxis, .., NewAxis])); + assert!(slices_intersect(&Dim([4, 5]), s![NewAxis, 0, ..], s![0, ..])); + assert!(slices_intersect(&Dim([4, 5]), s![..;2, ..], s![..;3, NewAxis, ..])); + assert!(slices_intersect(&Dim([4, 5]), s![.., ..;2], s![.., 1..;3, NewAxis])); assert!(slices_intersect(&Dim([4, 10]), s![.., ..;9], s![.., 3..;6])); } #[test] - fn slices_intersect_false() - { - assert!(!slices_intersect( - &Dim([4, 5]), - s![..;2, ..], - s![NewAxis, 1..;2, ..] - )); - assert!(!slices_intersect( - &Dim([4, 5]), - s![..;2, NewAxis, ..], - s![1..;3, ..] - )); - assert!(!slices_intersect( - &Dim([4, 5]), - s![.., ..;9], - s![.., 3..;6, NewAxis] - )); + fn slices_intersect_false() { + assert!(!slices_intersect(&Dim([4, 5]), s![..;2, ..], s![NewAxis, 1..;2, ..])); + assert!(!slices_intersect(&Dim([4, 5]), s![..;2, NewAxis, ..], s![1..;3, ..])); + assert!(!slices_intersect(&Dim([4, 5]), s![.., ..;9], s![.., 3..;6, NewAxis])); } } diff --git a/src/dimension/ndindex.rs b/src/dimension/ndindex.rs index ca2a3ea6..be7fdafa 100644 --- a/src/dimension/ndindex.rs +++ b/src/dimension/ndindex.rs @@ -17,8 +17,7 @@ use crate::{Dim, Dimension, IntoDimension, Ix, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6 /// assert_eq!(a[(1, 1)], 4); /// ``` #[allow(clippy::missing_safety_doc)] // TODO: Add doc -pub unsafe trait NdIndex: Debug -{ +pub unsafe trait NdIndex: Debug { #[doc(hidden)] fn index_checked(&self, dim: &E, strides: &E) -> Option; #[doc(hidden)] @@ -26,134 +25,109 @@ pub unsafe trait NdIndex: Debug } unsafe impl NdIndex for D -where D: Dimension +where + D: Dimension, { - fn index_checked(&self, dim: &D, strides: &D) -> Option - { + fn index_checked(&self, dim: &D, strides: &D) -> Option { dim.stride_offset_checked(strides, self) } - fn index_unchecked(&self, strides: &D) -> isize - { + fn index_unchecked(&self, strides: &D) -> isize { D::stride_offset(self, strides) } } -unsafe impl NdIndex for () -{ +unsafe impl NdIndex for () { #[inline] - fn index_checked(&self, dim: &Ix0, strides: &Ix0) -> Option - { + fn index_checked(&self, dim: &Ix0, strides: &Ix0) -> Option { dim.stride_offset_checked(strides, &Ix0()) } #[inline(always)] - fn index_unchecked(&self, _strides: &Ix0) -> isize - { + fn index_unchecked(&self, _strides: &Ix0) -> isize { 0 } } -unsafe impl NdIndex for (Ix, Ix) -{ +unsafe impl NdIndex for (Ix, Ix) { #[inline] - fn index_checked(&self, dim: &Ix2, strides: &Ix2) -> Option - { + fn index_checked(&self, dim: &Ix2, strides: &Ix2) -> Option { dim.stride_offset_checked(strides, &Ix2(self.0, self.1)) } #[inline] - fn index_unchecked(&self, strides: &Ix2) -> isize - { + fn index_unchecked(&self, strides: &Ix2) -> isize { stride_offset(self.0, get!(strides, 0)) + stride_offset(self.1, get!(strides, 1)) } } -unsafe impl NdIndex for (Ix, Ix, Ix) -{ +unsafe impl NdIndex for (Ix, Ix, Ix) { #[inline] - fn index_checked(&self, dim: &Ix3, strides: &Ix3) -> Option - { + fn index_checked(&self, dim: &Ix3, strides: &Ix3) -> Option { dim.stride_offset_checked(strides, &self.into_dimension()) } #[inline] - fn index_unchecked(&self, strides: &Ix3) -> isize - { + fn index_unchecked(&self, strides: &Ix3) -> isize { stride_offset(self.0, get!(strides, 0)) + stride_offset(self.1, 
get!(strides, 1)) + stride_offset(self.2, get!(strides, 2)) } } -unsafe impl NdIndex for (Ix, Ix, Ix, Ix) -{ +unsafe impl NdIndex for (Ix, Ix, Ix, Ix) { #[inline] - fn index_checked(&self, dim: &Ix4, strides: &Ix4) -> Option - { + fn index_checked(&self, dim: &Ix4, strides: &Ix4) -> Option { dim.stride_offset_checked(strides, &self.into_dimension()) } #[inline] - fn index_unchecked(&self, strides: &Ix4) -> isize - { + fn index_unchecked(&self, strides: &Ix4) -> isize { zip(strides.ix(), self.into_dimension().ix()) .map(|(&s, &i)| stride_offset(i, s)) .sum() } } -unsafe impl NdIndex for (Ix, Ix, Ix, Ix, Ix) -{ +unsafe impl NdIndex for (Ix, Ix, Ix, Ix, Ix) { #[inline] - fn index_checked(&self, dim: &Ix5, strides: &Ix5) -> Option - { + fn index_checked(&self, dim: &Ix5, strides: &Ix5) -> Option { dim.stride_offset_checked(strides, &self.into_dimension()) } #[inline] - fn index_unchecked(&self, strides: &Ix5) -> isize - { + fn index_unchecked(&self, strides: &Ix5) -> isize { zip(strides.ix(), self.into_dimension().ix()) .map(|(&s, &i)| stride_offset(i, s)) .sum() } } -unsafe impl NdIndex for (Ix, Ix, Ix, Ix, Ix, Ix) -{ +unsafe impl NdIndex for (Ix, Ix, Ix, Ix, Ix, Ix) { #[inline] - fn index_checked(&self, dim: &Ix6, strides: &Ix6) -> Option - { + fn index_checked(&self, dim: &Ix6, strides: &Ix6) -> Option { dim.stride_offset_checked(strides, &self.into_dimension()) } #[inline] - fn index_unchecked(&self, strides: &Ix6) -> isize - { + fn index_unchecked(&self, strides: &Ix6) -> isize { zip(strides.ix(), self.into_dimension().ix()) .map(|(&s, &i)| stride_offset(i, s)) .sum() } } -unsafe impl NdIndex for Ix -{ +unsafe impl NdIndex for Ix { #[inline] - fn index_checked(&self, dim: &Ix1, strides: &Ix1) -> Option - { + fn index_checked(&self, dim: &Ix1, strides: &Ix1) -> Option { dim.stride_offset_checked(strides, &Ix1(*self)) } #[inline(always)] - fn index_unchecked(&self, strides: &Ix1) -> isize - { + fn index_unchecked(&self, strides: &Ix1) -> isize { stride_offset(*self, get!(strides, 0)) } } -unsafe impl NdIndex for Ix -{ +unsafe impl NdIndex for Ix { #[inline] - fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option - { + fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option { debug_assert_eq!(dim.ndim(), 1); stride_offset_checked(dim.ix(), strides.ix(), &[*self]) } #[inline(always)] - fn index_unchecked(&self, strides: &IxDyn) -> isize - { + fn index_unchecked(&self, strides: &IxDyn) -> isize { debug_assert_eq!(strides.ndim(), 1); stride_offset(*self, get!(strides, 0)) } @@ -192,31 +166,16 @@ ndindex_with_array! 
{ } // implement NdIndex for Dim<[Ix; 2]> and so on -unsafe impl NdIndex for Dim<[Ix; N]> -{ +unsafe impl NdIndex for Dim<[Ix; N]> { #[inline] - fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option - { - debug_assert_eq!( - strides.ndim(), - N, - "Attempted to index with {:?} in array with {} axes", - self, - strides.ndim() - ); + fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option { + debug_assert_eq!(strides.ndim(), N, "Attempted to index with {:?} in array with {} axes", self, strides.ndim()); stride_offset_checked(dim.ix(), strides.ix(), self.ix()) } #[inline] - fn index_unchecked(&self, strides: &IxDyn) -> isize - { - debug_assert_eq!( - strides.ndim(), - N, - "Attempted to index with {:?} in array with {} axes", - self, - strides.ndim() - ); + fn index_unchecked(&self, strides: &IxDyn) -> isize { + debug_assert_eq!(strides.ndim(), N, "Attempted to index with {:?} in array with {} axes", self, strides.ndim()); (0..N) .map(|i| stride_offset(get!(self, i), get!(strides, i))) .sum() @@ -224,66 +183,43 @@ unsafe impl NdIndex for Dim<[Ix; N]> } // implement NdIndex for [Ix; 2] and so on -unsafe impl NdIndex for [Ix; N] -{ +unsafe impl NdIndex for [Ix; N] { #[inline] - fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option - { - debug_assert_eq!( - strides.ndim(), - N, - "Attempted to index with {:?} in array with {} axes", - self, - strides.ndim() - ); + fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option { + debug_assert_eq!(strides.ndim(), N, "Attempted to index with {:?} in array with {} axes", self, strides.ndim()); stride_offset_checked(dim.ix(), strides.ix(), self) } #[inline] - fn index_unchecked(&self, strides: &IxDyn) -> isize - { - debug_assert_eq!( - strides.ndim(), - N, - "Attempted to index with {:?} in array with {} axes", - self, - strides.ndim() - ); + fn index_unchecked(&self, strides: &IxDyn) -> isize { + debug_assert_eq!(strides.ndim(), N, "Attempted to index with {:?} in array with {} axes", self, strides.ndim()); (0..N) .map(|i| stride_offset(self[i], get!(strides, i))) .sum() } } -impl IntoDimension for &[Ix] -{ +impl IntoDimension for &[Ix] { type Dim = IxDyn; - fn into_dimension(self) -> Self::Dim - { + fn into_dimension(self) -> Self::Dim { Dim(IxDynImpl::from(self)) } } -unsafe impl NdIndex for &IxDyn -{ - fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option - { +unsafe impl NdIndex for &IxDyn { + fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option { (**self).index_checked(dim, strides) } - fn index_unchecked(&self, strides: &IxDyn) -> isize - { + fn index_unchecked(&self, strides: &IxDyn) -> isize { (**self).index_unchecked(strides) } } -unsafe impl NdIndex for &[Ix] -{ - fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option - { +unsafe impl NdIndex for &[Ix] { + fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option { stride_offset_checked(dim.ix(), strides.ix(), self) } - fn index_unchecked(&self, strides: &IxDyn) -> isize - { + fn index_unchecked(&self, strides: &IxDyn) -> isize { zip(strides.ix(), *self) .map(|(&s, &i)| stride_offset(i, s)) .sum() diff --git a/src/dimension/ops.rs b/src/dimension/ops.rs index 1365ab48..dd23216f 100644 --- a/src/dimension/ops.rs +++ b/src/dimension/ops.rs @@ -1,8 +1,7 @@ use crate::imp_prelude::*; /// Adds the two dimensions at compile time. -pub trait DimAdd -{ +pub trait DimAdd { /// The sum of the two dimensions. type Output: Dimension; } @@ -28,8 +27,7 @@ macro_rules! 
impl_dimadd_const_out_dyn { }; } -impl DimAdd for Ix0 -{ +impl DimAdd for Ix0 { type Output = D; } @@ -87,7 +85,6 @@ impl_dimadd_const_out_dyn!(6, 5); impl_dimadd_const_out_dyn!(6, 6); impl_dimadd_const_out_dyn!(6, IxDyn); -impl DimAdd for IxDyn -{ +impl DimAdd for IxDyn { type Output = IxDyn; } diff --git a/src/dimension/remove_axis.rs b/src/dimension/remove_axis.rs index 7ba3b533..b56bb085 100644 --- a/src/dimension/remove_axis.rs +++ b/src/dimension/remove_axis.rs @@ -12,27 +12,22 @@ use crate::{Axis, Dim, Dimension, Ix, Ix0, Ix1}; /// /// `RemoveAxis` defines a larger-than relation for array shapes: /// removing one axis from *Self* gives smaller dimension *Smaller*. -pub trait RemoveAxis: Dimension -{ +pub trait RemoveAxis: Dimension { /// Remove the specified axis from a dimension. fn remove_axis(&self, axis: Axis) -> Self::Smaller; } -impl RemoveAxis for Dim<[Ix; 1]> -{ +impl RemoveAxis for Dim<[Ix; 1]> { #[inline] - fn remove_axis(&self, axis: Axis) -> Ix0 - { + fn remove_axis(&self, axis: Axis) -> Ix0 { debug_assert!(axis.index() < self.ndim()); Ix0() } } -impl RemoveAxis for Dim<[Ix; 2]> -{ +impl RemoveAxis for Dim<[Ix; 2]> { #[inline] - fn remove_axis(&self, axis: Axis) -> Ix1 - { + fn remove_axis(&self, axis: Axis) -> Ix1 { let axis = axis.index(); debug_assert!(axis < self.ndim()); if axis == 0 { diff --git a/src/dimension/reshape.rs b/src/dimension/reshape.rs index abcec499..b6f5d9c1 100644 --- a/src/dimension/reshape.rs +++ b/src/dimension/reshape.rs @@ -146,21 +146,18 @@ where #[cfg(feature = "std")] #[test] -fn test_reshape() -{ +fn test_reshape() { use crate::Dim; macro_rules! test_reshape { (fail $order:ident from $from:expr, $stride:expr, to $to:expr) => { let res = reshape_dim(&Dim($from), &Dim($stride), &Dim($to), Order::$order); - println!("Reshape {:?} {:?} to {:?}, order {:?}\n => {:?}", - $from, $stride, $to, Order::$order, res); + println!("Reshape {:?} {:?} to {:?}, order {:?}\n => {:?}", $from, $stride, $to, Order::$order, res); let _res = res.expect_err("Expected failed reshape"); }; (ok $order:ident from $from:expr, $stride:expr, to $to:expr, $to_stride:expr) => {{ let res = reshape_dim(&Dim($from), &Dim($stride), &Dim($to), Order::$order); - println!("Reshape {:?} {:?} to {:?}, order {:?}\n => {:?}", - $from, $stride, $to, Order::$order, res); + println!("Reshape {:?} {:?} to {:?}, order {:?}\n => {:?}", $from, $stride, $to, Order::$order, res); println!("default stride for from dim: {:?}", Dim($from).default_strides()); println!("default stride for to dim: {:?}", Dim($to).default_strides()); let res = res.expect("Expected successful reshape"); diff --git a/src/dimension/sequence.rs b/src/dimension/sequence.rs index ed3605d5..533e8b4f 100644 --- a/src/dimension/sequence.rs +++ b/src/dimension/sequence.rs @@ -7,77 +7,76 @@ pub(in crate::dimension) struct Forward(pub(crate) D); pub(in crate::dimension) struct Reverse(pub(crate) D); impl Index for Forward<&D> -where D: Dimension +where + D: Dimension, { type Output = usize; #[inline] - fn index(&self, index: usize) -> &usize - { + fn index(&self, index: usize) -> &usize { &self.0[index] } } impl Index for Forward<&mut D> -where D: Dimension +where + D: Dimension, { type Output = usize; #[inline] - fn index(&self, index: usize) -> &usize - { + fn index(&self, index: usize) -> &usize { &self.0[index] } } impl IndexMut for Forward<&mut D> -where D: Dimension +where + D: Dimension, { #[inline] - fn index_mut(&mut self, index: usize) -> &mut usize - { + fn index_mut(&mut self, index: usize) -> &mut usize { &mut 
self.0[index] } } impl Index for Reverse<&D> -where D: Dimension +where + D: Dimension, { type Output = usize; #[inline] - fn index(&self, index: usize) -> &usize - { + fn index(&self, index: usize) -> &usize { &self.0[self.len() - index - 1] } } impl Index for Reverse<&mut D> -where D: Dimension +where + D: Dimension, { type Output = usize; #[inline] - fn index(&self, index: usize) -> &usize - { + fn index(&self, index: usize) -> &usize { &self.0[self.len() - index - 1] } } impl IndexMut for Reverse<&mut D> -where D: Dimension +where + D: Dimension, { #[inline] - fn index_mut(&mut self, index: usize) -> &mut usize - { + fn index_mut(&mut self, index: usize) -> &mut usize { let len = self.len(); &mut self.0[len - index - 1] } } /// Indexable sequence with length -pub(in crate::dimension) trait Sequence: Index -{ +pub(in crate::dimension) trait Sequence: Index { fn len(&self) -> usize; } @@ -85,21 +84,21 @@ pub(in crate::dimension) trait Sequence: Index pub(in crate::dimension) trait SequenceMut: Sequence + IndexMut {} impl Sequence for Forward<&D> -where D: Dimension +where + D: Dimension, { #[inline] - fn len(&self) -> usize - { + fn len(&self) -> usize { self.0.ndim() } } impl Sequence for Forward<&mut D> -where D: Dimension +where + D: Dimension, { #[inline] - fn len(&self) -> usize - { + fn len(&self) -> usize { self.0.ndim() } } @@ -107,21 +106,21 @@ where D: Dimension impl SequenceMut for Forward<&mut D> where D: Dimension {} impl Sequence for Reverse<&D> -where D: Dimension +where + D: Dimension, { #[inline] - fn len(&self) -> usize - { + fn len(&self) -> usize { self.0.ndim() } } impl Sequence for Reverse<&mut D> -where D: Dimension +where + D: Dimension, { #[inline] - fn len(&self) -> usize - { + fn len(&self) -> usize { self.0.ndim() } } diff --git a/src/error.rs b/src/error.rs index eb7395ad..c4549614 100644 --- a/src/error.rs +++ b/src/error.rs @@ -12,24 +12,20 @@ use std::fmt; /// An error related to array shape or layout. #[derive(Clone)] -pub struct ShapeError -{ +pub struct ShapeError { // we want to be able to change this representation later repr: ErrorKind, } -impl ShapeError -{ +impl ShapeError { /// Return the `ErrorKind` of this error. #[inline] - pub fn kind(&self) -> ErrorKind - { + pub fn kind(&self) -> ErrorKind { self.repr } /// Create a new `ShapeError` - pub fn from_kind(error: ErrorKind) -> Self - { + pub fn from_kind(error: ErrorKind) -> Self { from_kind(error) } } @@ -40,8 +36,7 @@ impl ShapeError /// is not guaranteed. 
#[non_exhaustive] #[derive(Copy, Clone, Debug)] -pub enum ErrorKind -{ +pub enum ErrorKind { /// incompatible shape IncompatibleShape = 1, /// incompatible memory layout @@ -57,25 +52,20 @@ pub enum ErrorKind } #[inline(always)] -pub fn from_kind(k: ErrorKind) -> ShapeError -{ +pub fn from_kind(k: ErrorKind) -> ShapeError { ShapeError { repr: k } } -impl PartialEq for ErrorKind -{ +impl PartialEq for ErrorKind { #[inline(always)] - fn eq(&self, rhs: &Self) -> bool - { + fn eq(&self, rhs: &Self) -> bool { *self as u8 == *rhs as u8 } } -impl PartialEq for ShapeError -{ +impl PartialEq for ShapeError { #[inline(always)] - fn eq(&self, rhs: &Self) -> bool - { + fn eq(&self, rhs: &Self) -> bool { self.repr == rhs.repr } } @@ -83,10 +73,8 @@ impl PartialEq for ShapeError #[cfg(feature = "std")] impl Error for ShapeError {} -impl fmt::Display for ShapeError -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result - { +impl fmt::Display for ShapeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let description = match self.kind() { ErrorKind::IncompatibleShape => "incompatible shapes", ErrorKind::IncompatibleLayout => "incompatible memory layout", @@ -99,10 +87,8 @@ impl fmt::Display for ShapeError } } -impl fmt::Debug for ShapeError -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result - { +impl fmt::Debug for ShapeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self) } } diff --git a/src/extension/nonnull.rs b/src/extension/nonnull.rs index 08f80927..043abee1 100644 --- a/src/extension/nonnull.rs +++ b/src/extension/nonnull.rs @@ -3,8 +3,7 @@ use alloc::vec::Vec; use std::ptr::NonNull; /// Return a NonNull pointer to the vector's data -pub(crate) fn nonnull_from_vec_data(v: &mut Vec) -> NonNull -{ +pub(crate) fn nonnull_from_vec_data(v: &mut Vec) -> NonNull { // this pointer is guaranteed to be non-null unsafe { NonNull::new_unchecked(v.as_mut_ptr()) } } @@ -15,8 +14,7 @@ pub(crate) fn nonnull_from_vec_data(v: &mut Vec) -> NonNull /// This is checked with a debug assertion, and will panic if this is not true, /// but treat this as an unconditional conversion. #[inline] -pub(crate) unsafe fn nonnull_debug_checked_from_ptr(ptr: *mut T) -> NonNull -{ +pub(crate) unsafe fn nonnull_debug_checked_from_ptr(ptr: *mut T) -> NonNull { debug_assert!(!ptr.is_null()); NonNull::new_unchecked(ptr) } diff --git a/src/finite_bounds.rs b/src/finite_bounds.rs index 565fe2bc..b2cd2aa6 100644 --- a/src/finite_bounds.rs +++ b/src/finite_bounds.rs @@ -1,42 +1,38 @@ use num_traits::Float; -pub enum Bound -{ +pub enum Bound { Included(F), Excluded(F), } /// A version of std::ops::RangeBounds that only implements a..b and a..=b ranges. -pub trait FiniteBounds -{ +pub trait FiniteBounds { fn start_bound(&self) -> F; fn end_bound(&self) -> Bound; } impl FiniteBounds for std::ops::Range -where F: Float +where + F: Float, { - fn start_bound(&self) -> F - { + fn start_bound(&self) -> F { self.start } - fn end_bound(&self) -> Bound - { + fn end_bound(&self) -> Bound { Bound::Excluded(self.end) } } impl FiniteBounds for std::ops::RangeInclusive -where F: Float +where + F: Float, { - fn start_bound(&self) -> F - { + fn start_bound(&self) -> F { *self.start() } - fn end_bound(&self) -> Bound - { + fn end_bound(&self) -> Bound { Bound::Included(*self.end()) } } diff --git a/src/free_functions.rs b/src/free_functions.rs index 4ad69f2c..4db3d39a 100644 --- a/src/free_functions.rs +++ b/src/free_functions.rs @@ -87,26 +87,22 @@ macro_rules! 
array { } /// Create a zero-dimensional array with the element `x`. -pub fn arr0(x: A) -> Array0 -{ +pub fn arr0(x: A) -> Array0 { unsafe { ArrayBase::from_shape_vec_unchecked((), vec![x]) } } /// Create a one-dimensional array with elements from `xs`. -pub fn arr1(xs: &[A]) -> Array1 -{ +pub fn arr1(xs: &[A]) -> Array1 { ArrayBase::from(xs.to_vec()) } /// Create a one-dimensional array with elements from `xs`. -pub fn rcarr1(xs: &[A]) -> ArcArray1 -{ +pub fn rcarr1(xs: &[A]) -> ArcArray1 { arr1(xs).into_shared() } /// Create a zero-dimensional array view borrowing `x`. -pub const fn aview0(x: &A) -> ArrayView0<'_, A> -{ +pub const fn aview0(x: &A) -> ArrayView0<'_, A> { ArrayBase { data: ViewRepr::new(), parts: ArrayPartsSized::new( @@ -139,13 +135,9 @@ pub const fn aview0(x: &A) -> ArrayView0<'_, A> /// /// assert_eq!(C.sum(), 6.); /// ``` -pub const fn aview1(xs: &[A]) -> ArrayView1<'_, A> -{ +pub const fn aview1(xs: &[A]) -> ArrayView1<'_, A> { if size_of::() == 0 { - assert!( - xs.len() <= isize::MAX as usize, - "Slice length must fit in `isize`.", - ); + assert!(xs.len() <= isize::MAX as usize, "Slice length must fit in `isize`.",); } ArrayBase { data: ViewRepr::new(), @@ -176,26 +168,20 @@ pub const fn aview1(xs: &[A]) -> ArrayView1<'_, A> /// const C: ArrayView2<'static, f64> = aview2(&[[1., 2., 3.], [4., 5., 6.]]); /// assert_eq!(C.sum(), 21.); /// ``` -pub const fn aview2(xs: &[[A; N]]) -> ArrayView2<'_, A> -{ +pub const fn aview2(xs: &[[A; N]]) -> ArrayView2<'_, A> { let cols = N; let rows = xs.len(); if size_of::() == 0 { if let Some(n_elems) = rows.checked_mul(cols) { assert!( - rows <= isize::MAX as usize - && cols <= isize::MAX as usize - && n_elems <= isize::MAX as usize, + rows <= isize::MAX as usize && cols <= isize::MAX as usize && n_elems <= isize::MAX as usize, "Product of non-zero axis lengths must not overflow isize.", ); } else { panic!("Overflow in number of elements."); } } else if N == 0 { - assert!( - rows <= isize::MAX as usize, - "Product of non-zero axis lengths must not overflow isize.", - ); + assert!(rows <= isize::MAX as usize, "Product of non-zero axis lengths must not overflow isize.",); } // Safe because references are always non-null. let ptr = unsafe { NonNull::new_unchecked(xs.as_ptr() as *mut A) }; @@ -223,8 +209,7 @@ pub const fn aview2(xs: &[[A; N]]) -> ArrayView2<'_, A> /// } /// assert_eq!(&data[..10], [5, 0, 0, 5, 0, 0, 5, 0, 0, 5]); /// ``` -pub fn aview_mut1(xs: &mut [A]) -> ArrayViewMut1<'_, A> -{ +pub fn aview_mut1(xs: &mut [A]) -> ArrayViewMut1<'_, A> { ArrayViewMut::from(xs) } @@ -250,8 +235,7 @@ pub fn aview_mut1(xs: &mut [A]) -> ArrayViewMut1<'_, A> /// // look at the start of the result /// assert_eq!(&data[..3], [[1., -1.], [1., -1.], [1., -1.]]); /// ``` -pub fn aview_mut2(xs: &mut [[A; N]]) -> ArrayViewMut2<'_, A> -{ +pub fn aview_mut2(xs: &mut [[A; N]]) -> ArrayViewMut2<'_, A> { ArrayViewMut2::from(xs) } @@ -266,8 +250,7 @@ pub fn aview_mut2(xs: &mut [[A; N]]) -> ArrayViewMut2<'_, A> /// a.shape() == [2, 3] /// ); /// ``` -pub fn arr2(xs: &[[A; N]]) -> Array2 -{ +pub fn arr2(xs: &[[A; N]]) -> Array2 { Array2::from(xs.to_vec()) } @@ -307,8 +290,7 @@ impl_from_nested_vec!([[[[[A; J]; K]; L]; M]; N], Ix6, N, M, L, K, J); /// Create a two-dimensional array with elements from `xs`. 
/// -pub fn rcarr2(xs: &[[A; N]]) -> ArcArray2 -{ +pub fn rcarr2(xs: &[[A; N]]) -> ArcArray2 { arr2(xs).into_shared() } @@ -329,14 +311,12 @@ pub fn rcarr2(xs: &[[A; N]]) -> ArcArray2 /// a.shape() == [3, 2, 2] /// ); /// ``` -pub fn arr3(xs: &[[[A; M]; N]]) -> Array3 -{ +pub fn arr3(xs: &[[[A; M]; N]]) -> Array3 { Array3::from(xs.to_vec()) } /// Create a three-dimensional array with elements from `xs`. -pub fn rcarr3(xs: &[[[A; M]; N]]) -> ArcArray -{ +pub fn rcarr3(xs: &[[[A; M]; N]]) -> ArcArray { arr3(xs).into_shared() } @@ -344,8 +324,7 @@ pub fn rcarr3(xs: &[[[A; M]; N]]) -> A /// /// Controls whether the first argument to `meshgrid` will fill the rows or columns of the outputs. #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum MeshIndex -{ +pub enum MeshIndex { /// Cartesian indexing. /// /// The first argument of `meshgrid` will repeat over the columns of the output. @@ -358,32 +337,20 @@ pub enum MeshIndex IJ, } -mod meshgrid_impl -{ +mod meshgrid_impl { use super::MeshIndex; use crate::extension::nonnull::nonnull_debug_checked_from_ptr; use crate::{ - ArrayBase, - ArrayRef1, - ArrayView, - ArrayView2, - ArrayView3, - ArrayView4, - ArrayView5, - ArrayView6, - Axis, - Data, - Dim, - IntoDimension, - Ix1, - LayoutRef1, + ArrayBase, ArrayRef1, ArrayView, ArrayView2, ArrayView3, ArrayView4, ArrayView5, ArrayView6, Axis, Data, Dim, + IntoDimension, Ix1, LayoutRef1, }; /// Construct the correct strides for the `idx`-th entry into meshgrid fn construct_strides( arr: &LayoutRef1, idx: usize, indexing: MeshIndex, ) -> <[usize; N] as IntoDimension>::Dim - where [usize; N]: IntoDimension + where + [usize; N]: IntoDimension, { let mut ret = [0; N]; if idx < 2 && indexing == MeshIndex::XY { @@ -398,7 +365,8 @@ mod meshgrid_impl fn construct_shape( arrays: [&LayoutRef1; N], indexing: MeshIndex, ) -> <[usize; N] as IntoDimension>::Dim - where [usize; N]: IntoDimension + where + [usize; N]: IntoDimension, { let mut ret = arrays.map(|a| a.len()); if indexing == MeshIndex::XY { @@ -413,8 +381,7 @@ mod meshgrid_impl /// The outputs should always be ND arrays where N is the number of inputs. /// /// Where possible, this trait tries to return array views rather than allocating additional memory. 
- pub trait Meshgrid - { + pub trait Meshgrid { type Output; fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output; @@ -434,12 +401,10 @@ mod meshgrid_impl }; } - impl<'a, 'b, A> Meshgrid for (&'a ArrayRef1, &'b ArrayRef1) - { + impl<'a, 'b, A> Meshgrid for (&'a ArrayRef1, &'b ArrayRef1) { type Output = (ArrayView2<'a, A>, ArrayView2<'b, A>); - fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output - { + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output { meshgrid_body!(2, indexing, (arrays.0, 0), (arrays.1, 1)) } } @@ -451,18 +416,15 @@ mod meshgrid_impl { type Output = (ArrayView2<'a, A>, ArrayView2<'b, A>); - fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output - { + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output { Meshgrid::meshgrid((&**arrays.0, &**arrays.1), indexing) } } - impl<'a, 'b, 'c, A> Meshgrid for (&'a ArrayRef1, &'b ArrayRef1, &'c ArrayRef1) - { + impl<'a, 'b, 'c, A> Meshgrid for (&'a ArrayRef1, &'b ArrayRef1, &'c ArrayRef1) { type Output = (ArrayView3<'a, A>, ArrayView3<'b, A>, ArrayView3<'c, A>); - fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output - { + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output { meshgrid_body!(3, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2)) } } @@ -476,18 +438,15 @@ mod meshgrid_impl { type Output = (ArrayView3<'a, A>, ArrayView3<'b, A>, ArrayView3<'c, A>); - fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output - { + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output { Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2), indexing) } } - impl<'a, 'b, 'c, 'd, A> Meshgrid for (&'a ArrayRef1, &'b ArrayRef1, &'c ArrayRef1, &'d ArrayRef1) - { + impl<'a, 'b, 'c, 'd, A> Meshgrid for (&'a ArrayRef1, &'b ArrayRef1, &'c ArrayRef1, &'d ArrayRef1) { type Output = (ArrayView4<'a, A>, ArrayView4<'b, A>, ArrayView4<'c, A>, ArrayView4<'d, A>); - fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output - { + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output { meshgrid_body!(4, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2), (arrays.3, 3)) } } @@ -502,8 +461,7 @@ mod meshgrid_impl { type Output = (ArrayView4<'a, A>, ArrayView4<'b, A>, ArrayView4<'c, A>, ArrayView4<'d, A>); - fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output - { + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output { Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2, &**arrays.3), indexing) } } @@ -513,8 +471,7 @@ mod meshgrid_impl { type Output = (ArrayView5<'a, A>, ArrayView5<'b, A>, ArrayView5<'c, A>, ArrayView5<'d, A>, ArrayView5<'e, A>); - fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output - { + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output { meshgrid_body!(5, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2), (arrays.3, 3), (arrays.4, 4)) } } @@ -536,8 +493,7 @@ mod meshgrid_impl { type Output = (ArrayView5<'a, A>, ArrayView5<'b, A>, ArrayView5<'c, A>, ArrayView5<'d, A>, ArrayView5<'e, A>); - fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output - { + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output { Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2, &**arrays.3, &**arrays.4), indexing) } } @@ -561,9 +517,17 @@ mod meshgrid_impl ArrayView6<'f, A>, ); - fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output - { - meshgrid_body!(6, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2), (arrays.3, 3), (arrays.4, 4), (arrays.5, 5)) + fn 
meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output { + meshgrid_body!( + 6, + indexing, + (arrays.0, 0), + (arrays.1, 1), + (arrays.2, 2), + (arrays.3, 3), + (arrays.4, 4), + (arrays.5, 5) + ) } } @@ -593,8 +557,7 @@ mod meshgrid_impl ArrayView6<'f, A>, ); - fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output - { + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output { Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2, &**arrays.3, &**arrays.4, &**arrays.5), indexing) } } @@ -649,22 +612,19 @@ mod meshgrid_impl /// [5, 6]], /// ]); /// ``` -pub fn meshgrid(arrays: T, indexing: MeshIndex) -> T::Output -{ +pub fn meshgrid(arrays: T, indexing: MeshIndex) -> T::Output { Meshgrid::meshgrid(arrays, indexing) } #[cfg(test)] -mod tests -{ +mod tests { use super::s; use crate::{meshgrid, Axis, MeshIndex}; #[cfg(not(feature = "std"))] use alloc::vec; #[test] - fn test_meshgrid2() - { + fn test_meshgrid2() { let x = array![1, 2, 3]; let y = array![4, 5, 6, 7]; let (xx, yy) = meshgrid((&x, &y), MeshIndex::XY); @@ -677,52 +637,68 @@ mod tests } #[test] - fn test_meshgrid3() - { + fn test_meshgrid3() { let x = array![1, 2, 3]; let y = array![4, 5, 6, 7]; let z = array![-1, -2]; let (xx, yy, zz) = meshgrid((&x, &y, &z), MeshIndex::XY); - assert_eq!(xx, array![ - [[1, 1], [2, 2], [3, 3]], - [[1, 1], [2, 2], [3, 3]], - [[1, 1], [2, 2], [3, 3]], - [[1, 1], [2, 2], [3, 3]], - ]); - assert_eq!(yy, array![ - [[4, 4], [4, 4], [4, 4]], - [[5, 5], [5, 5], [5, 5]], - [[6, 6], [6, 6], [6, 6]], - [[7, 7], [7, 7], [7, 7]], - ]); - assert_eq!(zz, array![ - [[-1, -2], [-1, -2], [-1, -2]], - [[-1, -2], [-1, -2], [-1, -2]], - [[-1, -2], [-1, -2], [-1, -2]], - [[-1, -2], [-1, -2], [-1, -2]], - ]); + assert_eq!( + xx, + array![ + [[1, 1], [2, 2], [3, 3]], + [[1, 1], [2, 2], [3, 3]], + [[1, 1], [2, 2], [3, 3]], + [[1, 1], [2, 2], [3, 3]], + ] + ); + assert_eq!( + yy, + array![ + [[4, 4], [4, 4], [4, 4]], + [[5, 5], [5, 5], [5, 5]], + [[6, 6], [6, 6], [6, 6]], + [[7, 7], [7, 7], [7, 7]], + ] + ); + assert_eq!( + zz, + array![ + [[-1, -2], [-1, -2], [-1, -2]], + [[-1, -2], [-1, -2], [-1, -2]], + [[-1, -2], [-1, -2], [-1, -2]], + [[-1, -2], [-1, -2], [-1, -2]], + ] + ); let (xx, yy, zz) = meshgrid((&x, &y, &z), MeshIndex::IJ); - assert_eq!(xx, array![ - [[1, 1], [1, 1], [1, 1], [1, 1]], - [[2, 2], [2, 2], [2, 2], [2, 2]], - [[3, 3], [3, 3], [3, 3], [3, 3]], - ]); - assert_eq!(yy, array![ - [[4, 4], [5, 5], [6, 6], [7, 7]], - [[4, 4], [5, 5], [6, 6], [7, 7]], - [[4, 4], [5, 5], [6, 6], [7, 7]], - ]); - assert_eq!(zz, array![ - [[-1, -2], [-1, -2], [-1, -2], [-1, -2]], - [[-1, -2], [-1, -2], [-1, -2], [-1, -2]], - [[-1, -2], [-1, -2], [-1, -2], [-1, -2]], - ]); + assert_eq!( + xx, + array![ + [[1, 1], [1, 1], [1, 1], [1, 1]], + [[2, 2], [2, 2], [2, 2], [2, 2]], + [[3, 3], [3, 3], [3, 3], [3, 3]], + ] + ); + assert_eq!( + yy, + array![ + [[4, 4], [5, 5], [6, 6], [7, 7]], + [[4, 4], [5, 5], [6, 6], [7, 7]], + [[4, 4], [5, 5], [6, 6], [7, 7]], + ] + ); + assert_eq!( + zz, + array![ + [[-1, -2], [-1, -2], [-1, -2], [-1, -2]], + [[-1, -2], [-1, -2], [-1, -2], [-1, -2]], + [[-1, -2], [-1, -2], [-1, -2], [-1, -2]], + ] + ); } #[test] - fn test_meshgrid_from_offset() - { + fn test_meshgrid_from_offset() { let x = array![1, 2, 3]; let x = x.slice(s![1..]); let y = array![4, 5, 6]; @@ -733,8 +709,7 @@ mod tests } #[test] - fn test_meshgrid_neg_stride() - { + fn test_meshgrid_neg_stride() { let x = array![1, 2, 3]; let x = x.slice(s![..;-1]); assert!(x.stride_of(Axis(0)) < 0); // Setup for test diff 
--git a/src/geomspace.rs b/src/geomspace.rs index 26a44f82..f13cdc9b 100644 --- a/src/geomspace.rs +++ b/src/geomspace.rs @@ -11,8 +11,7 @@ use num_traits::Float; /// An iterator of a sequence of geometrically spaced floats. /// /// Iterator element type is `F`. -pub struct Geomspace -{ +pub struct Geomspace { sign: F, start: F, step: F, @@ -21,13 +20,13 @@ pub struct Geomspace } impl Iterator for Geomspace -where F: Float +where + F: Float, { type Item = F; #[inline] - fn next(&mut self) -> Option - { + fn next(&mut self) -> Option { if self.index >= self.len { None } else { @@ -40,19 +39,18 @@ where F: Float } #[inline] - fn size_hint(&self) -> (usize, Option) - { + fn size_hint(&self) -> (usize, Option) { let n = self.len - self.index; (n, Some(n)) } } impl DoubleEndedIterator for Geomspace -where F: Float +where + F: Float, { #[inline] - fn next_back(&mut self) -> Option - { + fn next_back(&mut self) -> Option { if self.index >= self.len { None } else { @@ -82,7 +80,8 @@ impl ExactSizeIterator for Geomspace where Geomspace: Iterator {} /// **Panics** if converting `n - 1` to type `F` fails. #[inline] pub fn geomspace(a: F, b: F, n: usize) -> Option> -where F: Float +where + F: Float, { if a == F::zero() || b == F::zero() || a.is_sign_negative() != b.is_sign_negative() { return None; @@ -105,14 +104,12 @@ where F: Float } #[cfg(test)] -mod tests -{ +mod tests { use super::geomspace; #[test] #[cfg(feature = "approx")] - fn valid() - { + fn valid() { use crate::{arr1, Array1}; use approx::assert_abs_diff_eq; @@ -130,8 +127,7 @@ mod tests } #[test] - fn iter_forward() - { + fn iter_forward() { let mut iter = geomspace(1.0f64, 1e3, 4).unwrap(); assert!(iter.size_hint() == (4, Some(4))); @@ -146,8 +142,7 @@ mod tests } #[test] - fn iter_backward() - { + fn iter_backward() { let mut iter = geomspace(1.0f64, 1e3, 4).unwrap(); assert!(iter.size_hint() == (4, Some(4))); @@ -162,20 +157,17 @@ mod tests } #[test] - fn zero_lower() - { + fn zero_lower() { assert!(geomspace(0.0, 1.0, 4).is_none()); } #[test] - fn zero_upper() - { + fn zero_upper() { assert!(geomspace(1.0, 0.0, 4).is_none()); } #[test] - fn zero_included() - { + fn zero_included() { assert!(geomspace(-1.0, 1.0, 4).is_none()); } } diff --git a/src/hpc/aabb.rs b/src/hpc/aabb.rs index 91eaaf90..43770fe8 100644 --- a/src/hpc/aabb.rs +++ b/src/hpc/aabb.rs @@ -107,11 +107,7 @@ impl Ray { pub fn new(origin: [f32; 3], direction: [f32; 3]) -> Self { Self { origin, - inv_dir: [ - 1.0 / direction[0], - 1.0 / direction[1], - 1.0 / direction[2], - ], + inv_dir: [1.0 / direction[0], 1.0 / direction[1], 1.0 / direction[2]], } } @@ -177,7 +173,7 @@ fn aabb_intersect_batch_scalar(query: &Aabb, candidates: &[Aabb]) -> Vec { #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f")] unsafe fn aabb_intersect_batch_avx512(query: &Aabb, candidates: &[Aabb]) -> Vec { - use crate::simd::{F32x16}; + use crate::simd::F32x16; let mut result = Vec::with_capacity(candidates.len()); @@ -485,11 +481,7 @@ pub fn aabb_squared_distance_batch(point: [f32; 3], aabbs: &[Aabb]) -> Vec /// Filter AABBs by maximum squared distance from a point. Returns indices /// of AABBs whose nearest point is within `max_sq_dist` of `point`. 
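// The distance filter documented just above (and defined in the hunk that follows) ranks boxes by
// point-to-AABB squared distance, the same quantity aabb_squared_distance_batch computes earlier in
// this file. A minimal scalar sketch of that computation, independent of the SIMD batch kernels;
// the struct mirrors the Aabb fields shown in the diff, everything else is illustrative only:
struct AabbSketch { min: [f32; 3], max: [f32; 3] }

fn point_aabb_sq_dist(p: [f32; 3], b: &AabbSketch) -> f32 {
    // Clamp the point to the box on each axis; the per-axis residual is the shortest
    // vector from the point to the box, so its squared norm is the squared distance.
    let mut d2 = 0.0f32;
    for axis in 0..3 {
        let c = p[axis].clamp(b.min[axis], b.max[axis]);
        let d = p[axis] - c;
        d2 += d * d;
    }
    d2 // 0.0 when the point is inside or on the surface
}

fn main() {
    let unit = AabbSketch { min: [0.0; 3], max: [1.0; 3] };
    assert_eq!(point_aabb_sq_dist([0.5, 0.5, 0.5], &unit), 0.0); // inside
    assert_eq!(point_aabb_sq_dist([2.0, 0.5, 0.5], &unit), 1.0); // 1.0 beyond max x
}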
-pub fn aabb_filter_by_distance( - point: [f32; 3], - aabbs: &[Aabb], - max_sq_dist: f32, -) -> Vec { +pub fn aabb_filter_by_distance(point: [f32; 3], aabbs: &[Aabb], max_sq_dist: f32) -> Vec { let distances = aabb_squared_distance_batch(point, aabbs); distances .iter() @@ -586,10 +578,10 @@ mod tests { fn test_intersect_batch() { let query = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]); let candidates = vec![ - Aabb::new([0.5, 0.5, 0.5], [1.5, 1.5, 1.5]), // yes - Aabb::new([2.0, 2.0, 2.0], [3.0, 3.0, 3.0]), // no + Aabb::new([0.5, 0.5, 0.5], [1.5, 1.5, 1.5]), // yes + Aabb::new([2.0, 2.0, 2.0], [3.0, 3.0, 3.0]), // no Aabb::new([-1.0, -1.0, -1.0], [0.5, 0.5, 0.5]), // yes - Aabb::new([1.0, 1.0, 1.0], [2.0, 2.0, 2.0]), // yes (touching) + Aabb::new([1.0, 1.0, 1.0], [2.0, 2.0, 2.0]), // yes (touching) ]; let results = aabb_intersect_batch(&query, &candidates); assert_eq!(results, vec![true, false, true, true]); @@ -619,10 +611,7 @@ mod tests { #[test] fn test_expand_batch() { - let mut aabbs = vec![ - Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]), - Aabb::new([5.0, 5.0, 5.0], [6.0, 6.0, 6.0]), - ]; + let mut aabbs = vec![Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]), Aabb::new([5.0, 5.0, 5.0], [6.0, 6.0, 6.0])]; aabb_expand_batch(&mut aabbs, 0.5, 1.0, 1.5); assert!(approx_eq(aabbs[0].min[0], -0.5)); @@ -646,18 +635,8 @@ mod tests { for (i, orig) in base.iter().enumerate() { let expected = orig.expand(0.25, 0.5, 0.75); for axis in 0..3 { - assert!( - approx_eq(batch[i].min[axis], expected.min[axis]), - "min mismatch at [{},{}]", - i, - axis - ); - assert!( - approx_eq(batch[i].max[axis], expected.max[axis]), - "max mismatch at [{},{}]", - i, - axis - ); + assert!(approx_eq(batch[i].min[axis], expected.min[axis]), "min mismatch at [{},{}]", i, axis); + assert!(approx_eq(batch[i].max[axis], expected.max[axis]), "max mismatch at [{},{}]", i, axis); } } } @@ -689,21 +668,19 @@ mod tests { #[test] fn test_squared_distance_batch() { - let aabbs = vec![ - Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]), - Aabb::new([10.0, 10.0, 10.0], [11.0, 11.0, 11.0]), - ]; + let aabbs = + vec![Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]), Aabb::new([10.0, 10.0, 10.0], [11.0, 11.0, 11.0])]; let dists = aabb_squared_distance_batch([0.5, 0.5, 0.5], &aabbs); - assert!(approx_eq(dists[0], 0.0)); // inside - assert!(dists[1] > 200.0); // far away + assert!(approx_eq(dists[0], 0.0)); // inside + assert!(dists[1] > 200.0); // far away } #[test] fn test_filter_by_distance() { let aabbs = vec![ - Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]), // 0: dist=0 - Aabb::new([2.0, 0.0, 0.0], [3.0, 1.0, 1.0]), // 1: nearest pt (2,0.5,0.5), dist=1.5, sq=2.25 - Aabb::new([10.0, 10.0, 10.0], [11.0, 11.0, 11.0]),// 2: far + Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]), // 0: dist=0 + Aabb::new([2.0, 0.0, 0.0], [3.0, 1.0, 1.0]), // 1: nearest pt (2,0.5,0.5), dist=1.5, sq=2.25 + Aabb::new([10.0, 10.0, 10.0], [11.0, 11.0, 11.0]), // 2: far ]; let indices = aabb_filter_by_distance([0.5, 0.5, 0.5], &aabbs, 5.0); assert_eq!(indices, vec![0, 1]); @@ -711,19 +688,14 @@ mod tests { #[test] fn test_filter_by_distance_none() { - let aabbs = vec![ - Aabb::new([100.0, 100.0, 100.0], [101.0, 101.0, 101.0]), - ]; + let aabbs = vec![Aabb::new([100.0, 100.0, 100.0], [101.0, 101.0, 101.0])]; let indices = aabb_filter_by_distance([0.0, 0.0, 0.0], &aabbs, 1.0); assert!(indices.is_empty()); } #[test] fn test_filter_by_distance_all() { - let aabbs = vec![ - Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]), - Aabb::new([0.5, 0.5, 0.5], [1.5, 1.5, 1.5]), - ]; + let aabbs = 
vec![Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]), Aabb::new([0.5, 0.5, 0.5], [1.5, 1.5, 1.5])]; let indices = aabb_filter_by_distance([0.7, 0.7, 0.7], &aabbs, 100.0); assert_eq!(indices, vec![0, 1]); } @@ -846,7 +818,8 @@ mod tests { assert!( approx_eq(ts_batch[i], ts_scalar[i]), "ray AVX-512 t parity at {i}: {} vs {}", - ts_batch[i], ts_scalar[i] + ts_batch[i], + ts_scalar[i] ); } } diff --git a/src/hpc/activations.rs b/src/hpc/activations.rs index fb548a11..a4ea3f97 100644 --- a/src/hpc/activations.rs +++ b/src/hpc/activations.rs @@ -2,8 +2,8 @@ //! //! Generic trait impl via `mapv` + standalone SIMD-accelerated f32 slice functions. -use crate::simd::{simd_exp_f32, F32x16}; use crate::imp_prelude::*; +use crate::simd::{simd_exp_f32, F32x16}; use num_traits::Float; /// Neural network activation functions. @@ -52,7 +52,11 @@ where fn log_softmax(&self) -> Array { let max_val = self.iter().fold(A::neg_infinity(), |a, &b| a.max(b)); let shifted = self.mapv(|v| v - max_val); - let log_sum_exp = shifted.mapv(|v| v.exp()).iter().fold(A::zero(), |acc, &v| acc + v).ln(); + let log_sum_exp = shifted + .mapv(|v| v.exp()) + .iter() + .fold(A::zero(), |acc, &v| acc + v) + .ln(); shifted.mapv(|v| v - log_sum_exp) } } @@ -90,7 +94,9 @@ pub fn sigmoid_f32(x: &[f32], out: &mut [f32]) { /// Numerically stable. Uses F32x16 for exp and reduce_sum. pub fn softmax_f32(x: &[f32], out: &mut [f32]) { let n = x.len().min(out.len()); - if n == 0 { return; } + if n == 0 { + return; + } // Pass 1: find max (SIMD reduce_max) let mut max_acc = F32x16::splat(f32::NEG_INFINITY); @@ -143,7 +149,9 @@ pub fn softmax_f32(x: &[f32], out: &mut [f32]) { /// Numerically stable. Single pass for max, single pass for sum-exp. pub fn log_softmax_f32(x: &[f32], out: &mut [f32]) { let n = x.len().min(out.len()); - if n == 0 { return; } + if n == 0 { + return; + } // Pass 1: find max let mut max_acc = F32x16::splat(f32::NEG_INFINITY); @@ -273,8 +281,13 @@ mod tests { log_softmax_f32(&x, &mut logsm_out); for i in 0..10 { let expected = softmax_out[i].ln(); - assert!((logsm_out[i] - expected).abs() < 1e-2, - "log_softmax[{}] = {}, expected {}", i, logsm_out[i], expected); + assert!( + (logsm_out[i] - expected).abs() < 1e-2, + "log_softmax[{}] = {}, expected {}", + i, + logsm_out[i], + expected + ); } } diff --git a/src/hpc/amx_matmul.rs b/src/hpc/amx_matmul.rs index 2ce323ac..a6838f0f 100644 --- a/src/hpc/amx_matmul.rs +++ b/src/hpc/amx_matmul.rs @@ -207,7 +207,7 @@ pub fn vnni_pack_bf16(src: &[u16], dst: &mut [u16], k: usize, n: usize) { // strided (e.g. `view.slice(s![.., ..;2])`). Strided inputs are repacked // into contiguous staging buffers before the kernel runs. -use crate::hpc::quantized::{BF16, bf16_gemm_f32, int8_gemm_i32}; +use crate::hpc::quantized::{bf16_gemm_f32, int8_gemm_i32, BF16}; use crate::{ArrayView2, ArrayViewMut2}; /// Errors returned by the public AMX matmul API. @@ -470,7 +470,7 @@ mod tests { // ── Public matmul API tests (sprint A4) ──────────────────────────────── use crate::hpc::quantized::BF16; - use crate::{Array2, s}; + use crate::{s, Array2}; /// Reference f32 matmul, fully scalar. 
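// The softmax_f32 / log_softmax_f32 hunks in the activations.rs diff above follow the standard
// numerically stable recipe: subtract the running max before exponentiating so exp() cannot
// overflow. A scalar sketch of those three passes; the crate's versions vectorize the same
// structure with F32x16, this standalone form is illustrative only:
fn softmax_scalar(x: &[f32], out: &mut [f32]) {
    let n = x.len().min(out.len());
    if n == 0 {
        return;
    }
    // Pass 1: max
    let max = x[..n].iter().copied().fold(f32::NEG_INFINITY, f32::max);
    // Pass 2: exp(x - max) and running sum
    let mut sum = 0.0f32;
    for i in 0..n {
        out[i] = (x[i] - max).exp();
        sum += out[i];
    }
    // Pass 3: normalize
    let inv = 1.0 / sum;
    for v in &mut out[..n] {
        *v *= inv;
    }
}

fn main() {
    let x = [1.0f32, 2.0, 3.0];
    let mut y = [0.0f32; 3];
    softmax_scalar(&x, &mut y);
    assert!((y.iter().sum::<f32>() - 1.0).abs() < 1e-6);
}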
fn ref_matmul_f32(a: &Array2, b: &Array2) -> Array2 { diff --git a/src/hpc/arrow_bridge.rs b/src/hpc/arrow_bridge.rs index 7ff1a7cf..e1502e4d 100644 --- a/src/hpc/arrow_bridge.rs +++ b/src/hpc/arrow_bridge.rs @@ -382,11 +382,7 @@ impl BindNodeV2 { } // Recompute SPO XOR - self.spo_binary = Self::compute_spo_xor( - &self.subject_binary, - &self.predicate_binary, - &self.object_binary, - ); + self.spo_binary = Self::compute_spo_xor(&self.subject_binary, &self.predicate_binary, &self.object_binary); // Null soaking self.subject_soaking = None; @@ -559,9 +555,7 @@ impl BindNodeV2 { /// Compute the XOR of three plane binaries. fn compute_spo_xor( - s: &[u8; PLANE_BINARY_BYTES], - p: &[u8; PLANE_BINARY_BYTES], - o: &[u8; PLANE_BINARY_BYTES], + s: &[u8; PLANE_BINARY_BYTES], p: &[u8; PLANE_BINARY_BYTES], o: &[u8; PLANE_BINARY_BYTES], ) -> [u8; PLANE_BINARY_BYTES] { let mut result = [0u8; PLANE_BINARY_BYTES]; for i in 0..PLANE_BINARY_BYTES { @@ -817,7 +811,8 @@ mod tests { #[test] fn crystallize_positive() { let mut buf = SoakingBuffer::new(1, 8); - buf.entry_mut(0).copy_from_slice(&[1, -1, 1, -1, 1, -1, 1, -1]); + buf.entry_mut(0) + .copy_from_slice(&[1, -1, 1, -1, 1, -1, 1, -1]); let bits = buf.crystallize(0); assert_eq!(bits[0], 0b01010101); } @@ -923,11 +918,7 @@ mod tests { fn bind_node_v2_spo_xor_is_correct() { let (mut s, mut p, mut o) = make_test_planes(); let node = BindNodeV2::new(&mut s, &mut p, &mut o, "test"); - let expected = BindNodeV2::compute_spo_xor( - &node.subject_binary, - &node.predicate_binary, - &node.object_binary, - ); + let expected = BindNodeV2::compute_spo_xor(&node.subject_binary, &node.predicate_binary, &node.object_binary); assert_eq!(node.spo_binary, expected); assert_eq!(node.spo_xor(), expected); } @@ -1124,11 +1115,7 @@ mod tests { let spo_after = node.spo_binary; // After crystallize, SPO should be recomputed from updated binaries - let expected = BindNodeV2::compute_spo_xor( - &node.subject_binary, - &node.predicate_binary, - &node.object_binary, - ); + let expected = BindNodeV2::compute_spo_xor(&node.subject_binary, &node.predicate_binary, &node.object_binary); assert_eq!(spo_after, expected); // SPO may or may not change depending on soaking content; // what matters is consistency @@ -1261,7 +1248,10 @@ mod tests { #[test] fn soaking_row_buffer_crystallize() { let mut buf = SoakingRowBuffer::new(8); - buf.data.as_mut().unwrap().copy_from_slice(&[1, -1, 1, -1, 1, -1, 1, -1]); + buf.data + .as_mut() + .unwrap() + .copy_from_slice(&[1, -1, 1, -1, 1, -1, 1, -1]); let bits = buf.crystallize(); assert_eq!(bits[0], 0b01010101); assert!(!buf.is_active()); // should be nulled after crystallize diff --git a/src/hpc/audio/bands.rs b/src/hpc/audio/bands.rs index 97526bd4..9ccc9626 100644 --- a/src/hpc/audio/bands.rs +++ b/src/hpc/audio/bands.rs @@ -9,10 +9,8 @@ /// Opus CELT band boundaries at 48kHz, 960-sample frames (480 MDCT bins). /// 22 boundaries define 21 bands. Bin index = frequency / (48000 / 960). /// Band 0: bins 0-3 (~0-200 Hz), Band 20: bins 400-480 (~20-24 kHz). -pub const CELT_BANDS_48K: [usize; 22] = [ - 0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 44, 52, 60, 68, 80, 96, - 112, 136, 160, 200, 256, 480, -]; +pub const CELT_BANDS_48K: [usize; 22] = + [0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 44, 52, 60, 68, 80, 96, 112, 136, 160, 200, 256, 480]; /// Number of critical bands. 
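// The CELT_BANDS_48K table above partitions the 480 MDCT bins into 21 critical bands. A hedged
// sketch of how per-band energy would be computed from such a boundary table; the helper name
// band_energies is illustrative and the actual routines in bands.rs are not shown in this hunk:
const BOUNDS: [usize; 22] = [
    0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 44, 52, 60, 68, 80, 96, 112, 136, 160, 200, 256, 480,
];

fn band_energies(coeffs: &[f32]) -> [f32; 21] {
    let mut e = [0.0f32; 21];
    for band in 0..21 {
        let lo = BOUNDS[band].min(coeffs.len());
        let hi = BOUNDS[band + 1].min(coeffs.len());
        // Band energy = sum of squared MDCT coefficients inside the band's bin range.
        e[band] = coeffs[lo..hi].iter().map(|c| c * c).sum();
    }
    e
}

fn main() {
    let mut coeffs = vec![0.0f32; 480];
    coeffs[5] = 1.0; // bin 5 falls in band 1 (bins 4..8)
    let e = band_energies(&coeffs);
    assert_eq!(e[1], 1.0);
    assert_eq!(e[0], 0.0);
}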
pub const N_BANDS: usize = 21; @@ -121,15 +119,15 @@ mod tests { let recovered = denormalize_bands(&shape, &e); for (orig, rec) in coeffs.iter().zip(recovered.iter()) { - assert!((orig - rec).abs() < 0.01, - "Roundtrip mismatch: {} vs {}", orig, rec); + assert!((orig - rec).abs() < 0.01, "Roundtrip mismatch: {} vs {}", orig, rec); } } #[test] fn bf16_energy_roundtrip() { - let e = [1.0, 0.5, 2.0, 0.001, 100.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]; + let e = [ + 1.0, 0.5, 2.0, 0.001, 100.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + ]; let bf16 = energies_to_bf16(&e); let recovered = bf16_to_energies(&bf16); for i in 0..5 { diff --git a/src/hpc/audio/codec.rs b/src/hpc/audio/codec.rs index 02526415..58a0c8d4 100644 --- a/src/hpc/audio/codec.rs +++ b/src/hpc/audio/codec.rs @@ -6,8 +6,8 @@ //! //! One AudioFrame = one graph node in lance-graph. 48 bytes = CAM-compatible. -use super::mdct; use super::bands; +use super::mdct; use super::pvq; /// One audio frame: 42 bytes gain + 6 bytes shape = 48 bytes. @@ -103,7 +103,10 @@ impl AudioFrame { } let mut pvq_summary = [0u8; 6]; pvq_summary.copy_from_slice(&bytes[42..48]); - AudioFrame { band_energies, pvq_summary } + AudioFrame { + band_energies, + pvq_summary, + } } } @@ -127,7 +130,9 @@ mod tests { let frame = AudioFrame::encode(&pcm, 8); // Band energies should be nonzero (at least the band containing 440Hz) - let total_energy: f32 = frame.band_energies.iter() + let total_energy: f32 = frame + .band_energies + .iter() .map(|&b| f32::from_bits((b as u32) << 16)) .sum(); assert!(total_energy > 0.01, "Encoded frame has no energy: {}", total_energy); diff --git a/src/hpc/audio/codec_map.rs b/src/hpc/audio/codec_map.rs index df664327..526eabf1 100644 --- a/src/hpc/audio/codec_map.rs +++ b/src/hpc/audio/codec_map.rs @@ -116,7 +116,6 @@ pub const PROVENANCE: &[Provenance] = &[ source_concept: "CELT MDCT: 960-sample window → 480 frequency bins", what_it_replaces: "FFT+windowing (all codecs use some form)", }, - // ═══ From Whisper ═══ Provenance { our_type: "mel::log_mel_spectrogram", @@ -126,7 +125,6 @@ pub const PROVENANCE: &[Provenance] = &[ source_concept: "80-channel mel filterbank at 16kHz, Hann STFT", what_it_replaces: "Transformer encoder (150M params → 80 f32 per frame)", }, - // ═══ From MP3 ═══ Provenance { our_type: "HhtlCache::route() → Skip", @@ -144,7 +142,6 @@ pub const PROVENANCE: &[Provenance] = &[ source_concept: "32-subband polyphase filterbank (octave-spaced)", what_it_replaces: "Per-subband quantization + Huffman (MP3 granules)", }, - // ═══ From Ogg Vorbis ═══ Provenance { our_type: "CompiledLinear (ndarray burn)", @@ -154,7 +151,6 @@ pub const PROVENANCE: &[Provenance] = &[ source_concept: "VQ codebook: precomputed centroids, lookup-based decode", what_it_replaces: "Huffman trees (MP3) / arithmetic coding (Opus range coder)", }, - // ═══ From Bark (Suno) ═══ Provenance { our_type: "RvqFrame.archetype (HEEL)", @@ -180,7 +176,6 @@ pub const PROVENANCE: &[Provenance] = &[ source_concept: "Stage 3: non-autoregressive fine acoustic tokens", what_it_replaces: "Fine model (smaller network, fills spectral detail)", }, - // ═══ From ElevenLabs ═══ Provenance { our_type: "VoiceArchetype", @@ -190,7 +185,6 @@ pub const PROVENANCE: &[Provenance] = &[ source_concept: "Speaker embedding (voice cloning conditioning vector)", what_it_replaces: "512-dim speaker embedding (2KB → 16 bytes)", }, - // ═══ Phase (novel — no codec stores this) ═══ Provenance { 
our_type: "PhaseDescriptor", @@ -200,7 +194,6 @@ pub const PROVENANCE: &[Provenance] = &[ source_concept: "STFT phase (discarded by all codecs except Griffin-Lim)", what_it_replaces: "Nothing — all codecs discard phase. We keep it as relative pressure.", }, - // ═══ Qualia (novel — derived from QPL musical calibration) ═══ Provenance { our_type: "Qualia17D", @@ -232,21 +225,22 @@ pub const FRAME_BUDGET_WITH_TTS: usize = 69; /// These are approximate — our codec is lossy in a fundamentally /// different way (palette quantization, not psychoacoustic masking). pub const BITRATE_COMPARISON: &[(&str, u32, &str)] = &[ - ("MP3 128k", 128_000, "psychoacoustic masking, Huffman"), - ("Opus 64k", 64_000, "CELT+SILK hybrid, range coder"), - ("Vorbis 128k", 128_000, "MDCT, floor+residue, VQ codebook"), - ("Bark tokens", 25_600, "3-stage RVQ, ~100 tokens/sec × 256 bits"), - ("Ours (48kHz)", 20_800, "52 bytes × 50 fps × 8 bits = 20.8 kbps"), - ("Ours (24kHz)", 10_400, "52 bytes × 25 fps × 8 bits = 10.4 kbps"), + ("MP3 128k", 128_000, "psychoacoustic masking, Huffman"), + ("Opus 64k", 64_000, "CELT+SILK hybrid, range coder"), + ("Vorbis 128k", 128_000, "MDCT, floor+residue, VQ codebook"), + ("Bark tokens", 25_600, "3-stage RVQ, ~100 tokens/sec × 256 bits"), + ("Ours (48kHz)", 20_800, "52 bytes × 50 fps × 8 bits = 20.8 kbps"), + ("Ours (24kHz)", 10_400, "52 bytes × 25 fps × 8 bits = 10.4 kbps"), ]; /// Verify every AudioAspect is covered by at least one primitive. /// If an aspect is missing, we have a hole in our codec design. pub fn verify_aspect_coverage() -> Vec { use AudioAspect::*; - let all = [SpectralEnvelope, SpectralShape, PerceptualMapping, - PhaseRelationship, SpeakerIdentity, SemanticContent, - MaskingDecision, CodebookLookup]; + let all = [ + SpectralEnvelope, SpectralShape, PerceptualMapping, PhaseRelationship, SpeakerIdentity, SemanticContent, + MaskingDecision, CodebookLookup, + ]; all.iter() .filter(|&&aspect| !PROVENANCE.iter().any(|p| p.aspect == aspect)) @@ -275,36 +269,64 @@ mod tests { #[test] fn provenance_byte_sizes_consistent() { // AudioFrame = 42 (energies) + 6 (pvq) = 48 - let af_energies = PROVENANCE.iter().find(|p| p.our_type == "AudioFrame.band_energies").unwrap(); - let af_pvq = PROVENANCE.iter().find(|p| p.our_type == "AudioFrame.pvq_summary").unwrap(); + let af_energies = PROVENANCE + .iter() + .find(|p| p.our_type == "AudioFrame.band_energies") + .unwrap(); + let af_pvq = PROVENANCE + .iter() + .find(|p| p.our_type == "AudioFrame.pvq_summary") + .unwrap(); assert_eq!(af_energies.byte_size + af_pvq.byte_size, 48); // RvqFrame = 1 (HEEL) + 8 (HIP) + 8 (TWIG) = 17 - let rvq_heel = PROVENANCE.iter().find(|p| p.our_type == "RvqFrame.archetype (HEEL)").unwrap(); - let rvq_hip = PROVENANCE.iter().find(|p| p.our_type == "RvqFrame.coarse (HIP)").unwrap(); - let rvq_twig = PROVENANCE.iter().find(|p| p.our_type == "RvqFrame.fine (TWIG)").unwrap(); + let rvq_heel = PROVENANCE + .iter() + .find(|p| p.our_type == "RvqFrame.archetype (HEEL)") + .unwrap(); + let rvq_hip = PROVENANCE + .iter() + .find(|p| p.our_type == "RvqFrame.coarse (HIP)") + .unwrap(); + let rvq_twig = PROVENANCE + .iter() + .find(|p| p.our_type == "RvqFrame.fine (TWIG)") + .unwrap(); assert_eq!(rvq_heel.byte_size + rvq_hip.byte_size + rvq_twig.byte_size, 17); } #[test] fn every_source_codec_represented() { // All 6 source codecs should appear at least once - for source in [CodecSource::Opus, CodecSource::Whisper, CodecSource::Mp3, - CodecSource::OggVorbis, CodecSource::Bark, CodecSource::ElevenLabs] { - 
assert!(PROVENANCE.iter().any(|p| p.source == source), - "Codec {:?} not represented in provenance table", source); + for source in [ + CodecSource::Opus, + CodecSource::Whisper, + CodecSource::Mp3, + CodecSource::OggVorbis, + CodecSource::Bark, + CodecSource::ElevenLabs, + ] { + assert!( + PROVENANCE.iter().any(|p| p.source == source), + "Codec {:?} not represented in provenance table", + source + ); } } #[test] fn our_bitrate_competitive() { // Our codec should be lower bitrate than all traditional codecs - let ours_24k = BITRATE_COMPARISON.iter() + let ours_24k = BITRATE_COMPARISON + .iter() .find(|&&(name, _, _)| name == "Ours (24kHz)") - .unwrap().1; - let mp3 = BITRATE_COMPARISON.iter() + .unwrap() + .1; + let mp3 = BITRATE_COMPARISON + .iter() .find(|&&(name, _, _)| name == "MP3 128k") - .unwrap().1; + .unwrap() + .1; assert!(ours_24k < mp3, "Our codec should be lower bitrate than MP3"); } } diff --git a/src/hpc/audio/mdct.rs b/src/hpc/audio/mdct.rs index 970a969b..e80c9dc4 100644 --- a/src/hpc/audio/mdct.rs +++ b/src/hpc/audio/mdct.rs @@ -20,7 +20,9 @@ pub const MDCT_SIZE: usize = FRAME_SIZE / 2; /// Sine window for MDCT (Opus uses a sine window for CELT mode). /// w[n] = sin(π/N × (n + 0.5)) pub fn sine_window(n: usize) -> Vec { - (0..n).map(|i| (PI / n as f32 * (i as f32 + 0.5)).sin()).collect() + (0..n) + .map(|i| (PI / n as f32 * (i as f32 + 0.5)).sin()) + .collect() } /// Forward MDCT: time-domain → frequency-domain. @@ -135,13 +137,19 @@ pub fn mdct_backward(coeffs: &[f32]) -> Vec { // Unfold to symmetric positions let idx_a = 2 * k; let idx_b = n - 1 - 2 * k; - if idx_a < n { output[idx_a] = y_re * window[idx_a]; } - if idx_b < n { output[idx_b] = y_im * window[idx_b]; } + if idx_a < n { + output[idx_a] = y_re * window[idx_a]; + } + if idx_b < n { + output[idx_b] = y_im * window[idx_b]; + } } // Scale (MDCT normalization: 2/N) let scale = 2.0 / n as f32; - for s in &mut output { *s *= scale; } + for s in &mut output { + *s *= scale; + } output } diff --git a/src/hpc/audio/mel.rs b/src/hpc/audio/mel.rs index d45c3e4f..1c035f36 100644 --- a/src/hpc/audio/mel.rs +++ b/src/hpc/audio/mel.rs @@ -65,7 +65,8 @@ pub fn build_mel_filters(sample_rate: usize, n_fft: usize, n_mels: usize) -> Vec // Convert mel points back to Hz, then to FFT bin indices let hz_points: Vec = mel_points.iter().map(|&m| mel_to_hz(m)).collect(); - let bin_points: Vec = hz_points.iter() + let bin_points: Vec = hz_points + .iter() .map(|&h| h * n_fft as f32 / sample_rate as f32) .collect(); @@ -95,7 +96,9 @@ pub fn build_mel_filters(sample_rate: usize, n_fft: usize, n_mels: usize) -> Vec /// Hann window for STFT. pub fn hann_window(n: usize) -> Vec { - (0..n).map(|i| 0.5 * (1.0 - (2.0 * PI * i as f32 / n as f32).cos())).collect() + (0..n) + .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f32 / n as f32).cos())) + .collect() } /// Compute magnitude spectrogram via STFT. 
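// build_mel_filters above maps FFT bins onto a mel-spaced triangular filterbank, which rests on a
// Hz↔mel conversion. The usual convention is the 2595·log10(1 + f/700) formula sketched here; the
// exact constants in the crate's hz_to_mel / mel_to_hz are not visible in the hunk, so treat this
// as illustrative rather than the crate's definition:
fn hz_to_mel(hz: f32) -> f32 {
    2595.0 * (1.0 + hz / 700.0).log10()
}

fn mel_to_hz(mel: f32) -> f32 {
    700.0 * (10f32.powf(mel / 2595.0) - 1.0)
}

fn main() {
    // Round-trip should be close to identity.
    let hz = 440.0f32;
    let back = mel_to_hz(hz_to_mel(hz));
    assert!((back - hz).abs() < 1e-2);
    // Mel spacing compresses high frequencies: equal mel steps span wider Hz ranges up top.
    let lo_step = mel_to_hz(200.0) - mel_to_hz(100.0);
    let hi_step = mel_to_hz(2100.0) - mel_to_hz(2000.0);
    assert!(hi_step > lo_step);
}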
@@ -124,7 +127,7 @@ pub fn stft_magnitude(pcm: &[f32], window_size: usize, hop_size: usize) -> Vec u32 { match self { - Mode::Ionian => 8, // Gate: broad, confident - Mode::Dorian => 5, // V: warm content - Mode::Phrygian => 3, // QK: tight, exotic - Mode::Lydian => 2, // Up: fine, dreamy - Mode::Mixolydian => 4, // Down: driving compression - Mode::Aeolian => 3, // QK: minor, offset start - Mode::Locrian => 8, // Gate: unstable, offset start + Mode::Ionian => 8, // Gate: broad, confident + Mode::Dorian => 5, // V: warm content + Mode::Phrygian => 3, // QK: tight, exotic + Mode::Lydian => 2, // Up: fine, dreamy + Mode::Mixolydian => 4, // Down: driving compression + Mode::Aeolian => 3, // QK: minor, offset start + Mode::Locrian => 8, // Gate: unstable, offset start } } @@ -68,13 +68,13 @@ impl Mode { /// transposing the key. pub fn start_offset(&self) -> u32 { match self { - Mode::Ionian => 0, - Mode::Dorian => 2, - Mode::Phrygian => 4, - Mode::Lydian => 5, + Mode::Ionian => 0, + Mode::Dorian => 2, + Mode::Phrygian => 4, + Mode::Lydian => 5, Mode::Mixolydian => 7, - Mode::Aeolian => 9, - Mode::Locrian => 11, + Mode::Aeolian => 9, + Mode::Locrian => 11, } } @@ -84,13 +84,13 @@ impl Mode { /// This is more accurate than 12-EDO for both fifths and thirds. pub fn intervals_17edo(&self) -> [u8; 7] { match self { - Mode::Ionian => [3, 3, 2, 3, 3, 3, 0], // W W H W W W (last H implicit) - Mode::Dorian => [3, 2, 3, 3, 3, 2, 1], // W H W W W H W-1 - Mode::Phrygian => [2, 3, 3, 3, 2, 3, 1], // H W W W H W W-1 - Mode::Lydian => [3, 3, 3, 2, 3, 3, 0], // W W W H W W (last H implicit) + Mode::Ionian => [3, 3, 2, 3, 3, 3, 0], // W W H W W W (last H implicit) + Mode::Dorian => [3, 2, 3, 3, 3, 2, 1], // W H W W W H W-1 + Mode::Phrygian => [2, 3, 3, 3, 2, 3, 1], // H W W W H W W-1 + Mode::Lydian => [3, 3, 3, 2, 3, 3, 0], // W W W H W W (last H implicit) Mode::Mixolydian => [3, 3, 2, 3, 3, 2, 1], // W W H W W H W-1 - Mode::Aeolian => [3, 2, 3, 3, 2, 3, 1], // W H W W H W W-1 - Mode::Locrian => [2, 3, 3, 2, 3, 3, 1], // H W W H W W W-1 + Mode::Aeolian => [3, 2, 3, 3, 2, 3, 1], // W H W W H W W-1 + Mode::Locrian => [2, 3, 3, 2, 3, 3, 1], // H W W H W W W-1 } } @@ -101,13 +101,13 @@ impl Mode { /// high tension → less skipping (preserve detail). 
pub fn tension(&self) -> f32 { match self { - Mode::Ionian => 0.1, // most resolved - Mode::Lydian => 0.2, // floating but stable - Mode::Mixolydian => 0.3, // dominant tension - Mode::Dorian => 0.4, // warm but minor - Mode::Aeolian => 0.6, // sad minor - Mode::Phrygian => 0.8, // dark, exotic - Mode::Locrian => 1.0, // maximum instability + Mode::Ionian => 0.1, // most resolved + Mode::Lydian => 0.2, // floating but stable + Mode::Mixolydian => 0.3, // dominant tension + Mode::Dorian => 0.4, // warm but minor + Mode::Aeolian => 0.6, // sad minor + Mode::Phrygian => 0.8, // dark, exotic + Mode::Locrian => 1.0, // maximum instability } } } @@ -127,35 +127,53 @@ pub fn mode_band_weights(mode: Mode) -> [f32; bands::N_BANDS] { match mode { Mode::Ionian => { // Bright: boost presence (bands 10-14, ~2-5 kHz) - for i in 10..=14 { weights[i] = 1.3; } + for i in 10..=14 { + weights[i] = 1.3; + } } Mode::Dorian => { // Warm: boost low-mid (bands 4-8, ~800-1800 Hz) - for i in 4..=8 { weights[i] = 1.2; } + for i in 4..=8 { + weights[i] = 1.2; + } } Mode::Phrygian => { // Dark: boost sub-bass (bands 0-3), cut presence - for i in 0..=3 { weights[i] = 1.4; } - for i in 10..=14 { weights[i] = 0.7; } + for i in 0..=3 { + weights[i] = 1.4; + } + for i in 10..=14 { + weights[i] = 0.7; + } } Mode::Lydian => { // Shimmering: boost harmonics (bands 14-18, ~5-13 kHz) - for i in 14..=18 { weights[i] = 1.3; } + for i in 14..=18 { + weights[i] = 1.3; + } } Mode::Mixolydian => { // Driving: boost fundamental + mid (bands 2-6, ~400-1400 Hz) - for i in 2..=6 { weights[i] = 1.25; } + for i in 2..=6 { + weights[i] = 1.25; + } } Mode::Aeolian => { // Sad: slight low emphasis, gentle roll-off - for i in 0..=5 { weights[i] = 1.15; } - for i in 16..=20 { weights[i] = 0.85; } + for i in 0..=5 { + weights[i] = 1.15; + } + for i in 16..=20 { + weights[i] = 0.85; + } } Mode::Locrian => { // Unstable: emphasize dissonant regions weights[6] = 1.4; // ~1400 Hz tritone region weights[13] = 1.3; // ~3400 Hz - for i in 0..=2 { weights[i] = 0.8; } // weaken root + for i in 0..=2 { + weights[i] = 0.8; + } // weaken root } } @@ -182,21 +200,21 @@ pub fn apply_mode(energies: &mut [f32; bands::N_BANDS], mode: Mode) { /// create natural-sounding prosody contours. 
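// mode_band_weights above returns a per-band multiplier centred on 1.0: a few bands are boosted or
// cut, the rest stay flat. Applying it downstream is a plain element-wise scale; a sketch of that
// step with an explicitly built "Ionian"-style weight vector (the function here is illustrative —
// how the crate consumes the weights is not shown in this hunk):
const N_BANDS: usize = 21;

fn apply_weights(energies: &mut [f32; N_BANDS], weights: &[f32; N_BANDS]) {
    for (e, w) in energies.iter_mut().zip(weights.iter()) {
        *e *= w;
    }
}

fn main() {
    // Boost presence bands 10..=14 by 1.3, keep the others at 1.0.
    let mut weights = [1.0f32; N_BANDS];
    for i in 10..=14 {
        weights[i] = 1.3;
    }
    let mut energies = [1.0f32; N_BANDS];
    apply_weights(&mut energies, &weights);
    assert!((energies[12] - 1.3).abs() < 1e-6);
    assert!((energies[0] - 1.0).abs() < 1e-6);
    // The average stays near 1.0, which is what the band_weights_centered test later asserts.
    let avg: f32 = energies.iter().sum::<f32>() / N_BANDS as f32;
    assert!(avg > 0.8 && avg < 1.3);
}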
pub fn circle_of_fifths_progression() -> Vec<(Mode, u32)> { vec![ - (Mode::Ionian, 0), // I (tonic, resolved) - (Mode::Lydian, 5), // IV (subdominant, floating) - (Mode::Mixolydian, 7), // V (dominant, driving) - (Mode::Ionian, 0), // I (return to tonic) + (Mode::Ionian, 0), // I (tonic, resolved) + (Mode::Lydian, 5), // IV (subdominant, floating) + (Mode::Mixolydian, 7), // V (dominant, driving) + (Mode::Ionian, 0), // I (return to tonic) ] } /// Minor progression: i → iv → VI → V → i pub fn minor_progression() -> Vec<(Mode, u32)> { vec![ - (Mode::Aeolian, 0), // i (tonic minor) - (Mode::Dorian, 5), // iv (subdominant, warm) - (Mode::Ionian, 8), // VI (relative major, bright) - (Mode::Mixolydian, 7), // V (dominant, driving) - (Mode::Aeolian, 0), // i (return) + (Mode::Aeolian, 0), // i (tonic minor) + (Mode::Dorian, 5), // iv (subdominant, warm) + (Mode::Ionian, 8), // VI (relative major, bright) + (Mode::Mixolydian, 7), // V (dominant, driving) + (Mode::Aeolian, 0), // i (return) ] } @@ -286,8 +304,8 @@ impl OctaveBand { // Build harmonic pattern with given decay rate let pattern = [ - 1.0, // fundamental (always 1.0) - harmonic_decay, // 2nd harmonic + 1.0, // fundamental (always 1.0) + harmonic_decay, // 2nd harmonic harmonic_decay * harmonic_decay, // 3rd harmonic ]; @@ -295,7 +313,10 @@ impl OctaveBand { let sum: f32 = pattern.iter().sum(); let norm = [pattern[0] / sum * 3.0, pattern[1] / sum * 3.0, pattern[2] / sum * 3.0]; - OctaveBand { pattern: norm, octave: octave.min(6) } + OctaveBand { + pattern: norm, + octave: octave.min(6), + } } /// Compress a full 21-band energy vector to octave bands. @@ -308,7 +329,10 @@ impl OctaveBand { /// BUT: if many frames share the same pattern (same pitch class), /// store pattern ONCE + per-frame octave offset = massive savings. 
pub fn compress_to_octaves(energies: &[f32; bands::N_BANDS]) -> [OctaveBand; 7] { - let mut result = [OctaveBand { pattern: [1.0; 3], octave: 0 }; 7]; + let mut result = [OctaveBand { + pattern: [1.0; 3], + octave: 0, + }; 7]; for oct in 0..7 { let start = oct * Self::BANDS_PER_OCTAVE; let mut pattern = [0.0f32; 3]; @@ -321,9 +345,15 @@ impl OctaveBand { } // Normalize if sum > 1e-10 { - for p in &mut pattern { *p /= sum; *p *= 3.0; } + for p in &mut pattern { + *p /= sum; + *p *= 3.0; + } } - result[oct] = OctaveBand { pattern, octave: oct as u8 }; + result[oct] = OctaveBand { + pattern, + octave: oct as u8, + }; } result } @@ -373,11 +403,11 @@ mod tests { #[test] fn mode_stride_matches_highheelbgz() { // Verify stride→role mapping is consistent with highheelbgz::TensorRole - assert_eq!(Mode::Ionian.stride(), 8); // Gate - assert_eq!(Mode::Dorian.stride(), 5); // V - assert_eq!(Mode::Phrygian.stride(), 3); // QK - assert_eq!(Mode::Lydian.stride(), 2); // Up - assert_eq!(Mode::Mixolydian.stride(), 4); // Down + assert_eq!(Mode::Ionian.stride(), 8); // Gate + assert_eq!(Mode::Dorian.stride(), 5); // V + assert_eq!(Mode::Phrygian.stride(), 3); // QK + assert_eq!(Mode::Lydian.stride(), 2); // Up + assert_eq!(Mode::Mixolydian.stride(), 4); // Down } #[test] @@ -390,12 +420,18 @@ mod tests { #[test] fn band_weights_centered() { // All mode weights should average close to 1.0 - for mode in [Mode::Ionian, Mode::Dorian, Mode::Phrygian, - Mode::Lydian, Mode::Mixolydian, Mode::Aeolian, Mode::Locrian] { + for mode in [ + Mode::Ionian, + Mode::Dorian, + Mode::Phrygian, + Mode::Lydian, + Mode::Mixolydian, + Mode::Aeolian, + Mode::Locrian, + ] { let weights = mode_band_weights(mode); let avg: f32 = weights.iter().sum::() / bands::N_BANDS as f32; - assert!(avg > 0.8 && avg < 1.3, - "Mode {:?} weights avg {:.2} — should be ~1.0", mode, avg); + assert!(avg > 0.8 && avg < 1.3, "Mode {:?} weights avg {:.2} — should be ~1.0", mode, avg); } } @@ -410,8 +446,15 @@ mod tests { #[test] fn intervals_sum_to_17() { // Each mode's intervals should sum close to 17 (one octave in 17-EDO) - for mode in [Mode::Ionian, Mode::Dorian, Mode::Phrygian, - Mode::Lydian, Mode::Mixolydian, Mode::Aeolian, Mode::Locrian] { + for mode in [ + Mode::Ionian, + Mode::Dorian, + Mode::Phrygian, + Mode::Lydian, + Mode::Mixolydian, + Mode::Aeolian, + Mode::Locrian, + ] { let intervals = mode.intervals_17edo(); let sum: u8 = intervals.iter().sum(); // 7 intervals sum to 17 (W=3, H=2): 5W+2H = 5×3+2×2 = 19? @@ -451,8 +494,11 @@ mod tests { energies[11] = 0.25; let octaves = OctaveBand::compress_to_octaves(&energies); // Octave 3 (bands 9-11) should have the most energy in pattern[0] - assert!(octaves[3].pattern[0] > octaves[3].pattern[2], - "Octave 3 pattern should peak at fundamental: {:?}", octaves[3].pattern); + assert!( + octaves[3].pattern[0] > octaves[3].pattern[2], + "Octave 3 pattern should peak at fundamental: {:?}", + octaves[3].pattern + ); // The fundamental (1.0) should have ~57% of the energy (1.0 / 1.75 × 3) assert!(octaves[3].pattern[0] > 1.5, "Fundamental weight should be > 1.5: {}", octaves[3].pattern[0]); } diff --git a/src/hpc/audio/phase.rs b/src/hpc/audio/phase.rs index 348ed800..e0fb362d 100644 --- a/src/hpc/audio/phase.rs +++ b/src/hpc/audio/phase.rs @@ -23,9 +23,9 @@ //! Uses the same STFT from mel.rs but keeps phase info instead of //! discarding it (which is what magnitude spectrograms do). 
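// band_phase_coherence below measures how consistently phase advances from one bin to the next: it
// accumulates cos/sin of adjacent-bin phase differences and appears to take the magnitude of that
// mean vector (a circular resultant length, 1.0 = phase-locked, ~0.0 = random); the closing lines
// of the function are not visible in this hunk. A self-contained sketch of that statistic on a
// plain phase array — names here are illustrative, not the crate's API:
fn phase_lock(phases: &[f32]) -> f32 {
    if phases.len() < 2 {
        return 0.0;
    }
    let (mut c, mut s) = (0.0f64, 0.0f64);
    for w in phases.windows(2) {
        let d = (w[1] - w[0]) as f64;
        c += d.cos();
        s += d.sin();
    }
    let n = (phases.len() - 1) as f64;
    (((c / n).powi(2) + (s / n).powi(2)).sqrt()) as f32
}

fn main() {
    // A constant phase increment (clean harmonic structure) is maximally locked.
    let locked: Vec<f32> = (0..64).map(|i| 0.3 * i as f32).collect();
    assert!(phase_lock(&locked) > 0.99);

    // Pseudo-random phases from a small LCG are far less coherent.
    let mut rng: u64 = 0x1234_5678;
    let noisy: Vec<f32> = (0..64)
        .map(|_| {
            rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
            (rng >> 40) as f32 / (1u64 << 24) as f32 * core::f32::consts::TAU
        })
        .collect();
    assert!(phase_lock(&noisy) < phase_lock(&locked));
}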
+use super::bands; use crate::hpc::fft; use core::f32::consts::PI; -use super::bands; /// Phase coherence between adjacent harmonics within one frame. /// @@ -34,16 +34,15 @@ use super::bands; /// Noise: random phase relationships (coherence ≈ 0.0). /// /// Returns per-band coherence values [0.0, 1.0]. -pub fn band_phase_coherence( - real: &[f32], - imag: &[f32], -) -> [f32; bands::N_BANDS] { +pub fn band_phase_coherence(real: &[f32], imag: &[f32]) -> [f32; bands::N_BANDS] { let mut coherence = [0.0f32; bands::N_BANDS]; for band in 0..bands::N_BANDS { let lo = bands::CELT_BANDS_48K[band]; let hi = bands::CELT_BANDS_48K[band + 1].min(real.len().min(imag.len())); - if hi <= lo + 1 { continue; } + if hi <= lo + 1 { + continue; + } // Phase differences between adjacent bins within this band let mut cos_sum = 0.0f64; @@ -51,7 +50,9 @@ pub fn band_phase_coherence( let mut count = 0u32; for i in lo..(hi - 1) { - if i >= real.len() || i + 1 >= real.len() { break; } + if i >= real.len() || i + 1 >= real.len() { + break; + } let phase_i = imag[i].atan2(real[i]); let phase_next = imag[i + 1].atan2(real[i + 1]); let diff = phase_next - phase_i; @@ -79,8 +80,7 @@ pub fn band_phase_coherence( /// /// Returns per-band gradient in radians/frame. pub fn phase_gradient( - prev_real: &[f32], prev_imag: &[f32], - curr_real: &[f32], curr_imag: &[f32], + prev_real: &[f32], prev_imag: &[f32], curr_real: &[f32], curr_imag: &[f32], ) -> [f32; bands::N_BANDS] { let mut gradient = [0.0f32; bands::N_BANDS]; @@ -89,19 +89,27 @@ pub fn phase_gradient( let hi = bands::CELT_BANDS_48K[band + 1] .min(prev_real.len()) .min(curr_real.len()); - if hi <= lo { continue; } + if hi <= lo { + continue; + } let mut total_diff = 0.0f64; let mut count = 0u32; for i in lo..hi { - if i >= prev_real.len() || i >= curr_real.len() { break; } + if i >= prev_real.len() || i >= curr_real.len() { + break; + } let prev_phase = prev_imag[i].atan2(prev_real[i]); let curr_phase = curr_imag[i].atan2(curr_real[i]); // Unwrap phase difference to [-π, π] let mut diff = curr_phase - prev_phase; - while diff > PI { diff -= 2.0 * PI; } - while diff < -PI { diff += 2.0 * PI; } + while diff > PI { + diff -= 2.0 * PI; + } + while diff < -PI { + diff += 2.0 * PI; + } total_diff += diff.abs() as f64; count += 1; } @@ -162,9 +170,11 @@ impl PhaseDescriptor { // Gradient stability: std dev of gradients (high = changing pitch) let grad_mean = gradient.iter().sum::() / bands::N_BANDS as f32; - let grad_var = gradient.iter() + let grad_var = gradient + .iter() .map(|g| (g - grad_mean) * (g - grad_mean)) - .sum::() / bands::N_BANDS as f32; + .sum::() + / bands::N_BANDS as f32; let grad_std = grad_var.sqrt(); PhaseDescriptor { @@ -188,11 +198,11 @@ impl PhaseDescriptor { let stability = 1.0 - self.bytes[3] as f32 / 255.0; [ - (9, coherence), // coherence: phase-locked = unified - (4, coherence), // clarity: locked harmonics = clear - (7, gradient), // velocity: phase rotation = movement - (8, coh_entropy), // entropy: mixed voiced/unvoiced - (14, stability), // groundedness: steady pitch = rooted + (9, coherence), // coherence: phase-locked = unified + (4, coherence), // clarity: locked harmonics = clear + (7, gradient), // velocity: phase rotation = movement + (8, coh_entropy), // entropy: mixed voiced/unvoiced + (14, stability), // groundedness: steady pitch = rooted ] } @@ -212,9 +222,7 @@ impl PhaseDescriptor { /// Returns (magnitude_per_frame, real_per_frame, imag_per_frame). /// Each frame has n_fft/2+1 bins. 
pub fn stft_with_phase( - pcm: &[f32], - window_size: usize, - hop_size: usize, + pcm: &[f32], window_size: usize, hop_size: usize, ) -> (Vec>, Vec>, Vec>) { let n_fft = window_size.next_power_of_two(); let n_bins = n_fft / 2 + 1; @@ -274,7 +282,9 @@ mod tests { .collect(); let (_mags, reals, imags) = stft_with_phase(&pcm, 512, 256); - if reals.is_empty() { return; } + if reals.is_empty() { + return; + } let coh = band_phase_coherence(&reals[0], &imags[0]); // At least one band should have high coherence (the one with 440Hz) @@ -287,13 +297,19 @@ mod tests { // White noise → random phases → low coherence let n = 1024; let mut rng = 0x12345678u64; - let pcm: Vec = (0..n).map(|_| { - rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407); - ((rng >> 33) as f32 / (1u64 << 31) as f32) * 2.0 - 1.0 - }).collect(); + let pcm: Vec = (0..n) + .map(|_| { + rng = rng + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + ((rng >> 33) as f32 / (1u64 << 31) as f32) * 2.0 - 1.0 + }) + .collect(); let (_mags, reals, imags) = stft_with_phase(&pcm, 512, 256); - if reals.is_empty() { return; } + if reals.is_empty() { + return; + } let coh = band_phase_coherence(&reals[0], &imags[0]); let mean_coh: f32 = coh.iter().sum::() / bands::N_BANDS as f32; @@ -321,7 +337,9 @@ mod tests { #[test] fn phase_to_qualia_dims_valid() { - let desc = PhaseDescriptor { bytes: [200, 50, 100, 30] }; + let desc = PhaseDescriptor { + bytes: [200, 50, 100, 30], + }; let dims = desc.to_qualia_dims(); for (dim_idx, value) in dims { assert!(dim_idx < 17, "Invalid dim index: {}", dim_idx); diff --git a/src/hpc/audio/pvq.rs b/src/hpc/audio/pvq.rs index 1733a055..ae715d81 100644 --- a/src/hpc/audio/pvq.rs +++ b/src/hpc/audio/pvq.rs @@ -87,7 +87,9 @@ pub fn pvq_summary(pulses: &[i32]) -> [u8; 6] { // HEEL (bytes 0-1): sign pattern of first 16 dims → 16 bits let mut sign_bits = 0u16; for i in 0..n.min(16) { - if pulses[i] > 0 { sign_bits |= 1 << i; } + if pulses[i] > 0 { + sign_bits |= 1 << i; + } } summary[0] = sign_bits as u8; summary[1] = (sign_bits >> 8) as u8; @@ -102,12 +104,17 @@ pub fn pvq_summary(pulses: &[i32]) -> [u8; 6] { let total = quarter_energy.iter().sum::().max(1); for i in 0..4 { let frac = (quarter_energy[i] * 255 / total) as u8; - if i < 2 { summary[2] |= frac >> (4 * (1 - i)); } - else { summary[3] |= frac >> (4 * (3 - i)); } + if i < 2 { + summary[2] |= frac >> (4 * (1 - i)); + } else { + summary[3] |= frac >> (4 * (3 - i)); + } } // TWIG (bytes 4-5): max pulse position + magnitude - let (max_pos, max_val) = pulses.iter().enumerate() + let (max_pos, max_val) = pulses + .iter() + .enumerate() .max_by_key(|(_, &p)| p.unsigned_abs()) .map(|(i, &p)| (i, p.unsigned_abs())) .unwrap_or((0, 0)); @@ -135,8 +142,14 @@ mod tests { // Dominant pulse signs should match input signs for i in 0..band.len() { if band[i].abs() > 0.3 { - assert_eq!(pulses[i].signum(), band[i].signum() as i32, - "Sign mismatch at dim {}: pulse={}, band={}", i, pulses[i], band[i]); + assert_eq!( + pulses[i].signum(), + band[i].signum() as i32, + "Sign mismatch at dim {}: pulse={}, band={}", + i, + pulses[i], + band[i] + ); } } } diff --git a/src/hpc/audio/synth.rs b/src/hpc/audio/synth.rs index c72c406e..5691cd6b 100644 --- a/src/hpc/audio/synth.rs +++ b/src/hpc/audio/synth.rs @@ -18,11 +18,11 @@ //! applied at step 3: band energies are scaled by the QPL family's //! spectral EQ before synthesis. 
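// The synthesize routine in the hunk that follows reconstructs PCM frame by frame; the MDCT inverse
// in mdct.rs already windows its output, so the usual decoder step is to sum overlapping windowed
// frames (overlap-add). Whether synthesize stitches frames exactly this way is not visible in the
// hunk, so the sketch below is the generic technique with illustrative hop and frame sizes:
fn overlap_add(frames: &[Vec<f32>], hop: usize) -> Vec<f32> {
    if frames.is_empty() {
        return Vec::new();
    }
    let frame_len = frames[0].len();
    let mut out = vec![0.0f32; hop * (frames.len() - 1) + frame_len];
    for (f_idx, frame) in frames.iter().enumerate() {
        let offset = f_idx * hop;
        for (i, &s) in frame.iter().enumerate() {
            // Each output sample sums every (already windowed) frame that overlaps it.
            out[offset + i] += s;
        }
    }
    out
}

fn main() {
    // Two constant frames of length 4 with hop 2: the middle two samples get both contributions.
    let frames = vec![vec![1.0f32; 4], vec![1.0f32; 4]];
    let out = overlap_add(&frames, 2);
    assert_eq!(out, vec![1.0, 1.0, 2.0, 2.0, 1.0, 1.0]);
}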
-use super::codec::AudioFrame; use super::bands; -use super::voice::{VoiceArchetype, VoiceCodebook, VoiceFrame, RvqFrame}; -use super::phase::PhaseDescriptor; +use super::codec::AudioFrame; use super::modes; +use super::phase::PhaseDescriptor; +use super::voice::{RvqFrame, VoiceArchetype, VoiceCodebook, VoiceFrame}; /// Decode a sequence of VoiceFrames into PCM audio. /// @@ -35,12 +35,11 @@ use super::modes; /// /// Returns mono f32 PCM samples. pub fn synthesize( - frames: &[VoiceFrame], - codebook: &VoiceCodebook, - coarse_centroids: &[[u16; bands::N_BANDS]; 256], - sample_rate: u32, + frames: &[VoiceFrame], codebook: &VoiceCodebook, coarse_centroids: &[[u16; bands::N_BANDS]; 256], sample_rate: u32, ) -> Vec { - if frames.is_empty() { return vec![]; } + if frames.is_empty() { + return vec![]; + } // Frame parameters (Opus CELT compatible) let frame_samples = 960; // 20ms at 48kHz @@ -93,7 +92,8 @@ pub fn synthesize( // Resample if needed (our MDCT produces at 48kHz, caller may want 24kHz) if sample_rate == 24000 { // Simple 2:1 decimation with averaging - output = output.chunks(2) + output = output + .chunks(2) .map(|c| if c.len() == 2 { (c[0] + c[1]) * 0.5 } else { c[0] }) .collect(); } @@ -113,22 +113,19 @@ pub fn synthesize( /// code[5]: bands 15-17 (brilliance) /// code[6]: bands 18-20 (air) /// code[7]: global gain (scales all bands) -fn reconstruct_band_energies( - rvq: &RvqFrame, - centroids: &[[u16; bands::N_BANDS]; 256], -) -> [u16; bands::N_BANDS] { +fn reconstruct_band_energies(rvq: &RvqFrame, centroids: &[[u16; bands::N_BANDS]; 256]) -> [u16; bands::N_BANDS] { // Start with the centroid pointed to by code[0] (base spectral shape) let base = centroids[rvq.coarse[0] as usize]; let mut energies = base; // Blend in contributions from other coarse codes per band group - let band_groups: [(usize, usize); 7] = [ - (0, 3), (3, 6), (6, 9), (9, 12), (12, 15), (15, 18), (18, 21), - ]; + let band_groups: [(usize, usize); 7] = [(0, 3), (3, 6), (6, 9), (9, 12), (12, 15), (15, 18), (18, 21)]; for (group_idx, &(lo, hi)) in band_groups.iter().enumerate() { let code_idx = group_idx + 1; - if code_idx >= 8 { break; } + if code_idx >= 8 { + break; + } let centroid = ¢roids[rvq.coarse[code_idx] as usize]; for band in lo..hi.min(bands::N_BANDS) { // Weighted blend: 60% base + 40% group-specific centroid @@ -159,12 +156,12 @@ fn reconstruct_band_energies( /// bytes 4-5: harmonic detail (from fine[5..8]) fn fine_to_pvq_summary(fine: &[u8; 8]) -> [u8; 6] { [ - fine[0] ^ fine[1], // sign pattern XOR - fine[1] ^ fine[2], // sign pattern continuation - fine[2], // temporal gradient - fine[3] ^ fine[4], // temporal modulation - fine[5], // harmonic detail - fine[6] ^ fine[7], // harmonic modulation + fine[0] ^ fine[1], // sign pattern XOR + fine[1] ^ fine[2], // sign pattern continuation + fine[2], // temporal gradient + fine[3] ^ fine[4], // temporal modulation + fine[5], // harmonic detail + fine[6] ^ fine[7], // harmonic modulation ] } @@ -223,7 +220,7 @@ pub fn write_wav(pcm: &[f32], sample_rate: u32) -> Vec { // fmt sub-chunk wav.extend_from_slice(b"fmt "); wav.extend_from_slice(&16u32.to_le_bytes()); // sub-chunk size - wav.extend_from_slice(&1u16.to_le_bytes()); // PCM format + wav.extend_from_slice(&1u16.to_le_bytes()); // PCM format wav.extend_from_slice(&n_channels.to_le_bytes()); wav.extend_from_slice(&sample_rate.to_le_bytes()); wav.extend_from_slice(&byte_rate.to_le_bytes()); @@ -235,7 +232,11 @@ pub fn write_wav(pcm: &[f32], sample_rate: u32) -> Vec { 
wav.extend_from_slice(&data_size.to_le_bytes()); // Normalize and convert to i16 - let max_abs = pcm.iter().map(|s| s.abs()).fold(0.0f32, f32::max).max(1e-10); + let max_abs = pcm + .iter() + .map(|s| s.abs()) + .fold(0.0f32, f32::max) + .max(1e-10); let scale = 32767.0 / max_abs; for &sample in pcm { @@ -248,10 +249,18 @@ pub fn write_wav(pcm: &[f32], sample_rate: u32) -> Vec { /// Validate a WAV byte buffer (basic sanity check). pub fn validate_wav(wav: &[u8]) -> Result<(u32, usize), &'static str> { - if wav.len() < 44 { return Err("WAV too short"); } - if &wav[0..4] != b"RIFF" { return Err("Missing RIFF header"); } - if &wav[8..12] != b"WAVE" { return Err("Missing WAVE format"); } - if &wav[12..16] != b"fmt " { return Err("Missing fmt chunk"); } + if wav.len() < 44 { + return Err("WAV too short"); + } + if &wav[0..4] != b"RIFF" { + return Err("Missing RIFF header"); + } + if &wav[8..12] != b"WAVE" { + return Err("Missing WAVE format"); + } + if &wav[12..16] != b"fmt " { + return Err("Missing fmt chunk"); + } let sample_rate = u32::from_le_bytes([wav[24], wav[25], wav[26], wav[27]]); let data_start = 44; // standard PCM WAV @@ -288,7 +297,9 @@ mod tests { #[test] fn synthesize_empty_returns_empty() { - let codebook = VoiceCodebook { entries: vec![VoiceArchetype::zero()] }; + let codebook = VoiceCodebook { + entries: vec![VoiceArchetype::zero()], + }; let centroids = [[0u16; bands::N_BANDS]; 256]; let pcm = synthesize(&[], &codebook, ¢roids, 48000); assert!(pcm.is_empty()); @@ -296,7 +307,9 @@ mod tests { #[test] fn synthesize_single_frame() { - let codebook = VoiceCodebook { entries: vec![VoiceArchetype::zero(); 256] }; + let codebook = VoiceCodebook { + entries: vec![VoiceArchetype::zero(); 256], + }; // Create centroids with some energy in mid-bands let mut centroids = [[0u16; bands::N_BANDS]; 256]; for c in centroids.iter_mut() { @@ -307,8 +320,14 @@ mod tests { } let frame = VoiceFrame { - rvq: RvqFrame { archetype: 0, coarse: [0, 0, 0, 0, 0, 0, 0, 128], fine: [128; 8] }, - phase: PhaseDescriptor { bytes: [200, 30, 128, 50] }, // voiced, steady + rvq: RvqFrame { + archetype: 0, + coarse: [0, 0, 0, 0, 0, 0, 0, 128], + fine: [128; 8], + }, + phase: PhaseDescriptor { + bytes: [200, 30, 128, 50], + }, // voiced, steady }; let pcm = synthesize(&[frame], &codebook, ¢roids, 48000); @@ -331,14 +350,23 @@ mod tests { for band in 0..bands::N_BANDS { energies[band] = (0.5f32.to_bits() >> 16) as u16; } - let frame = AudioFrame { band_energies: energies, pvq_summary: [0; 6] }; - let voiced = PhaseDescriptor { bytes: [255, 30, 128, 50] }; // high coherence + let frame = AudioFrame { + band_energies: energies, + pvq_summary: [0; 6], + }; + let voiced = PhaseDescriptor { + bytes: [255, 30, 128, 50], + }; // high coherence let modulated = phase_modulate_frame(&frame, &voiced); // Mid-bands (4-14) should be boosted - let mid_orig: f32 = (4..=14).map(|b| f32::from_bits((frame.band_energies[b] as u32) << 16)).sum(); - let mid_mod: f32 = (4..=14).map(|b| f32::from_bits((modulated.band_energies[b] as u32) << 16)).sum(); + let mid_orig: f32 = (4..=14) + .map(|b| f32::from_bits((frame.band_energies[b] as u32) << 16)) + .sum(); + let mid_mod: f32 = (4..=14) + .map(|b| f32::from_bits((modulated.band_energies[b] as u32) << 16)) + .sum(); assert!(mid_mod > mid_orig, "Voiced phase should boost mid-bands: {} vs {}", mid_mod, mid_orig); } @@ -352,13 +380,21 @@ mod tests { let audio_frame = AudioFrame::encode(&pcm, 8); // Build a codebook with this frame's energies as the only centroid - let codebook = 
VoiceCodebook { entries: vec![VoiceArchetype::zero(); 256] }; + let codebook = VoiceCodebook { + entries: vec![VoiceArchetype::zero(); 256], + }; let mut centroids = [[0u16; bands::N_BANDS]; 256]; centroids[0] = audio_frame.band_energies; let voice_frame = VoiceFrame { - rvq: RvqFrame { archetype: 0, coarse: [0, 0, 0, 0, 0, 0, 0, 128], fine: [0; 8] }, - phase: PhaseDescriptor { bytes: [200, 30, 128, 50] }, + rvq: RvqFrame { + archetype: 0, + coarse: [0, 0, 0, 0, 0, 0, 0, 128], + fine: [0; 8], + }, + phase: PhaseDescriptor { + bytes: [200, 30, 128, 50], + }, }; let decoded = synthesize(&[voice_frame], &codebook, ¢roids, 48000); diff --git a/src/hpc/audio/voice.rs b/src/hpc/audio/voice.rs index c5cba037..6539f541 100644 --- a/src/hpc/audio/voice.rs +++ b/src/hpc/audio/voice.rs @@ -42,7 +42,9 @@ impl VoiceArchetype { /// Zero archetype (neutral voice). pub fn zero() -> Self { - VoiceArchetype { channels: [0i8; N_VOICE_CHANNELS] } + VoiceArchetype { + channels: [0i8; N_VOICE_CHANNELS], + } } /// L1 distance between two archetypes. @@ -68,7 +70,11 @@ impl VoiceArchetype { nb += b * b; } let denom = ((na as f64) * (nb as f64)).sqrt(); - if denom < 1e-12 { 0.0 } else { dot as f64 / denom } + if denom < 1e-12 { + 0.0 + } else { + dot as f64 / denom + } } /// Extract archetype from raw embedding by quantizing to 16 channels. @@ -83,7 +89,8 @@ impl VoiceArchetype { let mut channels = [0i8; N_VOICE_CHANNELS]; // Find scale factor for quantization to i8 range - let max_abs = embedding.iter() + let max_abs = embedding + .iter() .map(|v| v.abs()) .fold(0.0f32, f32::max) .max(1e-10); @@ -129,12 +136,16 @@ impl VoiceArchetype { /// Articulation quality (channels 8-11 magnitude). pub fn articulation_energy(&self) -> u32 { - (8..12).map(|i| self.channels[i].unsigned_abs() as u32).sum() + (8..12) + .map(|i| self.channels[i].unsigned_abs() as u32) + .sum() } /// Prosody quality (channels 12-15 magnitude). pub fn prosody_energy(&self) -> u32 { - (12..16).map(|i| self.channels[i].unsigned_abs() as u32).sum() + (12..16) + .map(|i| self.channels[i].unsigned_abs() as u32) + .sum() } /// Modulate archetype with phase dynamics. @@ -186,7 +197,8 @@ pub struct VoiceCodebook { impl VoiceCodebook { /// Build from raw embeddings (e.g., from Bark speaker prompts). pub fn build(embeddings: &[Vec], stride: usize) -> Self { - let entries: Vec = embeddings.iter() + let entries: Vec = embeddings + .iter() .map(|e| VoiceArchetype::from_embedding(e, stride)) .collect(); VoiceCodebook { entries } @@ -266,7 +278,11 @@ impl RvqFrame { let mut fine = [0u8; 8]; coarse.copy_from_slice(&bytes[1..9]); fine.copy_from_slice(&bytes[9..17]); - RvqFrame { archetype: bytes[0], coarse, fine } + RvqFrame { + archetype: bytes[0], + coarse, + fine, + } } /// HEEL check: same voice archetype? 
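// RvqFrame::from_bytes above unpacks the 17-byte layout (1 archetype + 8 coarse + 8 fine) that the
// provenance table earlier budgets for the three Bark-derived RVQ stages. A sketch of the same
// pack/unpack pair on a standalone struct; the field names mirror the diff, while to_bytes here is
// an assumed counterpart since only from_bytes is visible in this hunk:
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
struct RvqSketch {
    archetype: u8,
    coarse: [u8; 8],
    fine: [u8; 8],
}

impl RvqSketch {
    fn to_bytes(&self) -> [u8; 17] {
        let mut b = [0u8; 17];
        b[0] = self.archetype;
        b[1..9].copy_from_slice(&self.coarse);
        b[9..17].copy_from_slice(&self.fine);
        b
    }

    fn from_bytes(b: &[u8; 17]) -> Self {
        let mut coarse = [0u8; 8];
        let mut fine = [0u8; 8];
        coarse.copy_from_slice(&b[1..9]);
        fine.copy_from_slice(&b[9..17]);
        RvqSketch { archetype: b[0], coarse, fine }
    }
}

fn main() {
    let f = RvqSketch { archetype: 7, coarse: [1; 8], fine: [2; 8] };
    assert_eq!(RvqSketch::from_bytes(&f.to_bytes()), f);
}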
@@ -343,15 +359,17 @@ mod tests { #[test] fn archetype_self_distance_zero() { - let a = VoiceArchetype { channels: [10, -20, 30, -40, 50, -60, 70, -80, - 90, -100, 110, -120, 5, -15, 25, -35] }; + let a = VoiceArchetype { + channels: [10, -20, 30, -40, 50, -60, 70, -80, 90, -100, 110, -120, 5, -15, 25, -35], + }; assert_eq!(a.l1(&a), 0); } #[test] fn archetype_self_cosine_one() { - let a = VoiceArchetype { channels: [10, -20, 30, -40, 50, -60, 70, -80, - 1, 2, 3, 4, 5, 6, 7, 8] }; + let a = VoiceArchetype { + channels: [10, -20, 30, -40, 50, -60, 70, -80, 1, 2, 3, 4, 5, 6, 7, 8], + }; let c = a.cosine(&a); assert!((c - 1.0).abs() < 1e-10, "Self cosine should be 1.0: {}", c); } @@ -367,8 +385,9 @@ mod tests { #[test] fn archetype_serialize_roundtrip() { - let a = VoiceArchetype { channels: [1, -2, 3, -4, 5, -6, 7, -8, - 9, -10, 11, -12, 13, -14, 15, -16] }; + let a = VoiceArchetype { + channels: [1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12, 13, -14, 15, -16], + }; let bytes = a.to_bytes(); let recovered = VoiceArchetype::from_bytes(&bytes); assert_eq!(a, recovered); @@ -402,23 +421,36 @@ mod tests { #[test] fn phase_modulation_changes_articulation() { - let base = VoiceArchetype { channels: [0, 0, 0, 0, 0, 0, 0, 0, - 50, 50, 50, 50, 0, 0, 0, 0] }; + let base = VoiceArchetype { + channels: [0, 0, 0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 0, 0, 0, 0], + }; // High coherence → should boost articulation channels - let high_coh = super::super::phase::PhaseDescriptor { bytes: [255, 128, 128, 128] }; + let high_coh = super::super::phase::PhaseDescriptor { + bytes: [255, 128, 128, 128], + }; let modulated = base.modulate_with_phase(&high_coh); // Articulation channels (8-11) should be boosted - let base_art: i32 = (8..12).map(|i| base.channels[i].unsigned_abs() as i32).sum(); - let mod_art: i32 = (8..12).map(|i| modulated.channels[i].unsigned_abs() as i32).sum(); + let base_art: i32 = (8..12) + .map(|i| base.channels[i].unsigned_abs() as i32) + .sum(); + let mod_art: i32 = (8..12) + .map(|i| modulated.channels[i].unsigned_abs() as i32) + .sum(); assert!(mod_art >= base_art, "High coherence should boost articulation: {} vs {}", mod_art, base_art); } #[test] fn voice_frame_roundtrip() { let frame = VoiceFrame { - rvq: RvqFrame { archetype: 7, coarse: [1; 8], fine: [2; 8] }, - phase: super::super::phase::PhaseDescriptor { bytes: [200, 50, 100, 30] }, + rvq: RvqFrame { + archetype: 7, + coarse: [1; 8], + fine: [2; 8], + }, + phase: super::super::phase::PhaseDescriptor { + bytes: [200, 50, 100, 30], + }, }; let bytes = frame.to_bytes(); assert_eq!(bytes.len(), VoiceFrame::BYTE_SIZE); @@ -443,8 +475,7 @@ mod tests { let k = 3; for i in 0..k { for j in 0..k { - assert_eq!(table[i * k + j], table[j * k + i], - "Distance table not symmetric at ({}, {})", i, j); + assert_eq!(table[i * k + j], table[j * k + i], "Distance table not symmetric at ({}, {})", i, j); } } } diff --git a/src/hpc/bf16_tile_gemm.rs b/src/hpc/bf16_tile_gemm.rs index 1a9a1e18..59429391 100644 --- a/src/hpc/bf16_tile_gemm.rs +++ b/src/hpc/bf16_tile_gemm.rs @@ -18,10 +18,10 @@ //! 
``` use crate::hpc::amx_matmul::{ - amx_available, TileConfig, tile_loadconfig, tile_zero, - tile_load, tile_store, tile_release, tile_dpbf16ps, vnni_pack_bf16, + amx_available, tile_dpbf16ps, tile_load, tile_loadconfig, tile_release, tile_store, tile_zero, vnni_pack_bf16, + TileConfig, }; -use crate::simd::{F32x16, bf16_to_f32_batch}; +use crate::simd::{bf16_to_f32_batch, F32x16}; // ═════════════════════════════════════════════════════════════════════ // Public API — safe dispatching wrapper @@ -47,7 +47,9 @@ pub fn bf16_tile_gemm_16x16(a_bf16: &[u16], b_bf16: &[u16], c: &mut [f32], k: us let mut b_vnni = vec![0u16; k * 16]; vnni_pack_bf16(b_bf16, &mut b_vnni, k, 16); // SAFETY: amx_available() just confirmed CPUID + XCR0 + prctl. - unsafe { amx_path(a_bf16, &b_vnni, c, k); } + unsafe { + amx_path(a_bf16, &b_vnni, c, k); + } } else { fallback_path(a_bf16, b_bf16, c, k); } @@ -69,8 +71,8 @@ unsafe fn amx_path(a_bf16: &[u16], b_vnni: &[u16], c: &mut [f32], k: usize) { // Accumulate over K/32 tile blocks let k_blocks = k / 32; - let a_stride = (k * 2) as usize; // full A row stride in bytes (bf16 = 2B) - let b_stride = 64usize; // VNNI row stride in bytes + let a_stride = (k * 2) as usize; // full A row stride in bytes (bf16 = 2B) + let b_stride = 64usize; // VNNI row stride in bytes for kb in 0..k_blocks { let a_ptr = a_bf16.as_ptr().add(kb * 32) as *const u8; @@ -103,11 +105,13 @@ fn fallback_path(a_bf16: &[u16], b_bf16: &[u16], c: &mut [f32], k: usize) { // We gather the column into a stack-sized buffer once per (i,j) pair to hit // the chunks_exact(16) + mul_add fast path on contiguous memory. for i in 0..16 { - let a_row = &a_f32[i * k .. i * k + k]; + let a_row = &a_f32[i * k..i * k + k]; for j in 0..16 { // Stream the column into a contiguous buffer let mut col = vec![0.0f32; k]; - for kk in 0..k { col[kk] = b_f32[kk * 16 + j]; } + for kk in 0..k { + col[kk] = b_f32[kk * 16 + j]; + } // Accumulate via F32x16::mul_add (FMA) let mut acc = F32x16::splat(0.0); @@ -128,7 +132,7 @@ fn fallback_path(a_bf16: &[u16], b_bf16: &[u16], c: &mut [f32], k: usize) { #[cfg(test)] mod tests { use super::*; - use crate::simd::{f32_to_bf16_batch, bf16_to_f32_batch}; + use crate::simd::{bf16_to_f32_batch, f32_to_bf16_batch}; /// Scalar BF16 reference (f32-accumulated) — ground truth. 
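// The fallback path above widens BF16 inputs to f32 before the FMA loop. BF16 is just the top
// 16 bits of an IEEE-754 f32, so the scalar conversion is a shift; a sketch using truncation on
// the f32→bf16 side — the crate's f32_to_bf16_batch may round to nearest-even instead, which this
// does not show:
fn bf16_to_f32(x: u16) -> f32 {
    f32::from_bits((x as u32) << 16)
}

fn f32_to_bf16_truncate(x: f32) -> u16 {
    (x.to_bits() >> 16) as u16
}

fn main() {
    // 1.0f32 = 0x3F80_0000 → BF16 0x3F80, and back exactly.
    assert_eq!(f32_to_bf16_truncate(1.0), 0x3F80);
    assert_eq!(bf16_to_f32(0x3F80), 1.0);
    // The round-trip keeps ~7 mantissa bits, so a small relative error is expected.
    let x = 3.14159f32;
    let y = bf16_to_f32(f32_to_bf16_truncate(x));
    assert!((x - y).abs() / x < 1e-2);
}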
fn ref_gemm(a: &[f32], b: &[f32], c: &mut [f32], k: usize) { @@ -150,12 +154,11 @@ mod tests { let mut a_f32 = vec![0.0f32; 16 * k]; let mut b_f32 = vec![0.0f32; k * 16]; for i in 0..a_f32.len() { - a_f32[i] = (((i as i32).wrapping_mul(1103515245).wrapping_add(12345) >> 8) as f32 - / 2147483648.0).clamp(-1.0, 1.0); + a_f32[i] = + (((i as i32).wrapping_mul(1103515245).wrapping_add(12345) >> 8) as f32 / 2147483648.0).clamp(-1.0, 1.0); } for i in 0..b_f32.len() { - b_f32[i] = (((i as i32).wrapping_mul(69069).wrapping_add(1) >> 8) as f32 - / 2147483648.0).clamp(-1.0, 1.0); + b_f32[i] = (((i as i32).wrapping_mul(69069).wrapping_add(1) >> 8) as f32 / 2147483648.0).clamp(-1.0, 1.0); } let mut a_bf16 = vec![0u16; a_f32.len()]; let mut b_bf16 = vec![0u16; b_f32.len()]; @@ -178,7 +181,9 @@ mod tests { let mut max_err = 0.0f32; for i in 0..(16 * 16) { let e = (c_fb[i] - c_ref[i]).abs(); - if e > max_err { max_err = e; } + if e > max_err { + max_err = e; + } } assert!(max_err < 1e-3, "fallback vs scalar ref max_err = {}", max_err); } @@ -193,6 +198,8 @@ mod tests { let mut c = vec![0.0f32; 16 * 16]; bf16_tile_gemm_16x16(&a, &b, &mut c, k); // All zeros × all zeros = 0 - for v in c.iter() { assert_eq!(*v, 0.0); } + for v in c.iter() { + assert_eq!(*v, 0.0); + } } } diff --git a/src/hpc/bf16_truth.rs b/src/hpc/bf16_truth.rs index f25cb13e..de2313a0 100644 --- a/src/hpc/bf16_truth.rs +++ b/src/hpc/bf16_truth.rs @@ -44,15 +44,26 @@ impl BF16Weights { assert!( max_per_elem <= 65535, "BF16Weights overflow: sign({}) + 8*exp({}) + 7*man({}) = {} > 65535", - sign, exponent, mantissa, max_per_elem + sign, + exponent, + mantissa, + max_per_elem ); - Self { sign, exponent, mantissa } + Self { + sign, + exponent, + mantissa, + } } } impl Default for BF16Weights { fn default() -> Self { - Self { sign: 256, exponent: 16, mantissa: 1 } + Self { + sign: 256, + exponent: 16, + mantissa: 1, + } } } @@ -146,7 +157,10 @@ pub struct AwarenessThresholds { impl Default for AwarenessThresholds { fn default() -> Self { - Self { exp_spread_limit: 2, noise_mantissa_bits: 5 } + Self { + exp_spread_limit: 2, + noise_mantissa_bits: 5, + } } } @@ -187,7 +201,10 @@ impl PackedQualia { /// Zero-initialized qualia point. pub fn zero() -> Self { - Self { resonance: [0i8; 16], scalar: [0u8; 2] } + Self { + resonance: [0i8; 16], + scalar: [0u8; 2], + } } /// Decode the BF16 scalar back to f32. @@ -245,9 +262,7 @@ pub fn bf16_hamming_scalar(a: &[u8], b: &[u8], weights: &BF16Weights) -> u64 { let man_bits = xor & 0x7F; let man_pop = man_bits.count_ones() as u16; - let dist = sign_flip * weights.sign - + exp_pop * weights.exponent - + man_pop * weights.mantissa; + let dist = sign_flip * weights.sign + exp_pop * weights.exponent + man_pop * weights.mantissa; total += dist as u64; } @@ -269,12 +284,7 @@ pub fn bf16_hamming_scalar(a: &[u8], b: &[u8], weights: &BF16Weights) -> u64 { /// # Returns /// /// A [`SuperpositionState`] summarizing all dimensions. 
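// bf16_hamming_scalar above splits each BF16 lane's XOR into sign (1 bit), exponent (8 bits) and
// mantissa (7 bits) and weights the three popcounts separately, so a sign flip costs far more than
// mantissa noise. A sketch of one lane of that distance with the weights shown in the diff's
// Default impl (sign = 256, exponent = 16, mantissa = 1); the numbers are read off the diff, not a
// spec:
fn weighted_bf16_dist(a: u16, b: u16) -> u64 {
    let xor = a ^ b;
    let sign_flip = ((xor >> 15) & 1) as u64; // bit 15
    let exp_pop = ((xor >> 7) & 0xFF).count_ones() as u64; // bits 7..=14
    let man_pop = (xor & 0x7F).count_ones() as u64; // bits 0..=6
    sign_flip * 256 + exp_pop * 16 + man_pop * 1
}

fn main() {
    // Identical values: distance 0.
    assert_eq!(weighted_bf16_dist(0x3F80, 0x3F80), 0);
    // Same magnitude, opposite sign (1.0 vs -1.0): only the sign bit differs.
    assert_eq!(weighted_bf16_dist(0x3F80, 0xBF80), 256);
    // One low mantissa bit of noise costs 1.
    assert_eq!(weighted_bf16_dist(0x3F80, 0x3F81), 1);
}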
-pub fn awareness_classify( - a: &[u8], - b: &[u8], - n_dims: usize, - thresholds: &AwarenessThresholds, -) -> SuperpositionState { +pub fn awareness_classify(a: &[u8], b: &[u8], n_dims: usize, thresholds: &AwarenessThresholds) -> SuperpositionState { assert_eq!(a.len(), b.len(), "awareness_classify: length mismatch"); assert!(a.len() >= n_dims * 2, "awareness_classify: not enough bytes for n_dims"); @@ -388,9 +398,7 @@ pub fn awareness_classify( /// assert_ne!(packed, 0); /// ``` pub fn bf16_from_projections( - bands: &[super::cascade::Band; 7], - finest_distance: u32, - finest_max: u32, + bands: &[super::cascade::Band; 7], finest_distance: u32, finest_max: u32, direction: super::causality::CausalityDirection, ) -> u16 { use super::cascade::Band; @@ -441,9 +449,7 @@ pub fn bf16_from_projections( /// assert_eq!(exp, 0x7F); /// assert_eq!(man, 0); /// ``` -pub fn bf16_unpack_projections( - packed: u16, -) -> (super::causality::CausalityDirection, u8, u8) { +pub fn bf16_unpack_projections(packed: u16) -> (super::causality::CausalityDirection, u8, u8) { use super::causality::CausalityDirection; let direction = if packed & 0x8000 != 0 { @@ -558,7 +564,10 @@ mod tests { let s = awareness_classify(&data, &data, 4, &t); assert_eq!(s.n_dims, 4); assert!((s.crystallized_pct - 1.0).abs() < 1e-6); - assert!(s.states.iter().all(|st| *st == AwarenessState::Crystallized)); + assert!(s + .states + .iter() + .all(|st| *st == AwarenessState::Crystallized)); } #[test] @@ -609,7 +618,7 @@ mod tests { let (dir, exp, man) = bf16_unpack_projections(packed); assert_eq!(dir, CausalityDirection::Backward); assert_eq!(exp, 0); // no close projections - // mantissa: 500/1000 * 127 = 63 + // mantissa: 500/1000 * 127 = 63 assert_eq!(man, 63); } diff --git a/src/hpc/bgz17_bridge.rs b/src/hpc/bgz17_bridge.rs index de7fe3e8..cf44dd8e 100644 --- a/src/hpc/bgz17_bridge.rs +++ b/src/hpc/bgz17_bridge.rs @@ -124,7 +124,13 @@ fn l1_weighted_scalar(a: &[i16; 17], b: &[i16; 17]) -> u32 { let mut d = 0u32; for i in 0..17 { let diff = (a[i] as i32 - b[i] as i32).unsigned_abs(); - let weight = if i == 0 { 20 } else if i < 7 { 3 } else { 1 }; + let weight = if i == 0 { + 20 + } else if i < 7 { + 3 + } else { + 1 + }; d += diff * weight; } d @@ -185,19 +191,18 @@ fn sign_agreement_scalar(a: &[i16; 17], b: &[i16; 17]) -> u32 { count } -static SIGN_AGREEMENT_KERNEL: std::sync::LazyLock = - std::sync::LazyLock::new(|| { - #[cfg(target_arch = "x86_64")] - { - if is_x86_feature_detected!("avx512f") { - return sign_agreement_avx512 as SignAgreementFn; - } - if is_x86_feature_detected!("avx2") { - return sign_agreement_avx2 as SignAgreementFn; - } +static SIGN_AGREEMENT_KERNEL: std::sync::LazyLock = std::sync::LazyLock::new(|| { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx512f") { + return sign_agreement_avx512 as SignAgreementFn; + } + if is_x86_feature_detected!("avx2") { + return sign_agreement_avx2 as SignAgreementFn; } - sign_agreement_scalar as SignAgreementFn - }); + } + sign_agreement_scalar as SignAgreementFn +}); // ============================================================================ // Multi-versioned xor_bind kernel: AVX-512 → AVX2 → scalar. 
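`SIGN_AGREEMENT_KERNEL` and `INJECT_NOISE_KERNEL` share one shape: a `LazyLock` holding a function pointer (`LazyLock<SignAgreementFn>`, `LazyLock<InjectNoiseFn>`) that probes CPU features once and caches the best kernel. A self-contained sketch of that dispatch pattern; `KernelFn`, `kernel_scalar` and `KERNEL` are illustrative names, not crate items:

```rust
use std::sync::LazyLock;

type KernelFn = fn(&[i16; 17], &[i16; 17]) -> u32;

// Portable fallback; here it just counts equal lanes as a stand-in body.
fn kernel_scalar(a: &[i16; 17], b: &[i16; 17]) -> u32 {
    a.iter().zip(b).filter(|(x, y)| x == y).count() as u32
}

static KERNEL: LazyLock<KernelFn> = LazyLock::new(|| {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // The real dispatcher returns an AVX-512/AVX2 kernel built with
            // #[target_feature]; this sketch keeps the portable one either way.
            return kernel_scalar as KernelFn;
        }
    }
    kernel_scalar as KernelFn
});

// Call sites pay one lazy init, then a plain indirect call: (*KERNEL)(&a, &b).
```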
@@ -338,19 +343,18 @@ fn inject_noise_scalar(dims: &[i16; 17], scale: i16, seed: u64) -> [i16; 17] { result } -static INJECT_NOISE_KERNEL: std::sync::LazyLock = - std::sync::LazyLock::new(|| { - #[cfg(target_arch = "x86_64")] - { - if is_x86_feature_detected!("avx512f") { - return inject_noise_avx512 as InjectNoiseFn; - } - if is_x86_feature_detected!("avx2") { - return inject_noise_avx2 as InjectNoiseFn; - } +static INJECT_NOISE_KERNEL: std::sync::LazyLock = std::sync::LazyLock::new(|| { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx512f") { + return inject_noise_avx512 as InjectNoiseFn; } - inject_noise_scalar as InjectNoiseFn - }); + if is_x86_feature_detected!("avx2") { + return inject_noise_avx2 as InjectNoiseFn; + } + } + inject_noise_scalar as InjectNoiseFn +}); /// SPO triple of Base17 patterns. 102 bytes. #[derive(Clone, Debug, PartialEq, Eq)] @@ -535,19 +539,13 @@ impl SpoBase17 { /// Combined L1 distance (sum of three planes). #[inline] pub fn l1(&self, other: &SpoBase17) -> u32 { - self.subject.l1(&other.subject) - + self.predicate.l1(&other.predicate) - + self.object.l1(&other.object) + self.subject.l1(&other.subject) + self.predicate.l1(&other.predicate) + self.object.l1(&other.object) } /// Per-plane L1 distances. #[inline] pub fn l1_per_plane(&self, other: &SpoBase17) -> (u32, u32, u32) { - ( - self.subject.l1(&other.subject), - self.predicate.l1(&other.predicate), - self.object.l1(&other.object), - ) + (self.subject.l1(&other.subject), self.predicate.l1(&other.predicate), self.object.l1(&other.object)) } } @@ -559,7 +557,11 @@ impl PaletteEdge { /// Deserialize from 3 bytes. pub fn from_bytes(b: &[u8; 3]) -> Self { - PaletteEdge { s_idx: b[0], p_idx: b[1], o_idx: b[2] } + PaletteEdge { + s_idx: b[0], + p_idx: b[1], + o_idx: b[2], + } } } @@ -570,13 +572,17 @@ mod tests { #[test] fn test_golden_coverage() { let mut seen = [false; BASE_DIM]; - for &p in &GOLDEN_POS { seen[p as usize] = true; } + for &p in &GOLDEN_POS { + seen[p as usize] = true; + } assert!(seen.iter().all(|&s| s)); } #[test] fn test_l1_self_zero() { - let a = Base17 { dims: [100, -50, 0, 127, -128, 1, -1, 50, 25, -25, 0, 0, 0, 0, 0, 0, 0] }; + let a = Base17 { + dims: [100, -50, 0, 127, -128, 1, -1, 50, 25, -25, 0, 0, 0, 0, 0, 0, 0], + }; assert_eq!(a.l1(&a), 0); } @@ -589,8 +595,12 @@ mod tests { #[test] fn test_xor_bind_self_inverse() { - let a = Base17 { dims: [100, -200, 300, -400, 500, -600, 700, -800, 900, -1000, 1100, -1200, 1300, -1400, 1500, -1600, 1700] }; - let b = Base17 { dims: [-50, 150, -250, 350, -450, 550, -650, 750, -850, 950, -1050, 1150, -1250, 1350, -1450, 1550, -1650] }; + let a = Base17 { + dims: [100, -200, 300, -400, 500, -600, 700, -800, 900, -1000, 1100, -1200, 1300, -1400, 1500, -1600, 1700], + }; + let b = Base17 { + dims: [-50, 150, -250, 350, -450, 550, -650, 750, -850, 950, -1050, 1150, -1250, 1350, -1450, 1550, -1650], + }; let bound = a.xor_bind(&b); let recovered = bound.xor_bind(&b); assert_eq!(a, recovered, "xor_bind must be its own inverse"); @@ -598,7 +608,9 @@ mod tests { #[test] fn test_xor_bind_identity() { - let a = Base17 { dims: [100, -200, 300, -400, 500, -600, 700, -800, 900, -1000, 1100, -1200, 1300, -1400, 1500, -1600, 1700] }; + let a = Base17 { + dims: [100, -200, 300, -400, 500, -600, 700, -800, 900, -1000, 1100, -1200, 1300, -1400, 1500, -1600, 1700], + }; let zero = Base17::zero(); assert_eq!(a.xor_bind(&zero), a, "xor_bind with zero must be identity"); } @@ -622,14 +634,18 @@ mod tests { #[test] fn test_permute_identity() { - 
let a = Base17 { dims: [1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12, 13, -14, 15, -16, 17] }; + let a = Base17 { + dims: [1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12, 13, -14, 15, -16, 17], + }; assert_eq!(a.permute(0), a, "permute(0) must be identity"); assert_eq!(a.permute(BASE_DIM), a, "permute(17) must wrap to identity"); } #[test] fn test_permute_cyclic() { - let a = Base17 { dims: [1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12, 13, -14, 15, -16, 17] }; + let a = Base17 { + dims: [1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12, 13, -14, 15, -16, 17], + }; let shifted = a.permute(1); for i in 0..BASE_DIM { assert_eq!(shifted.dims[i], a.dims[(i + 1) % BASE_DIM]); @@ -638,7 +654,9 @@ mod tests { #[test] fn test_byte_roundtrip() { - let a = Base17 { dims: [1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12, 13, -14, 15, -16, 17] }; + let a = Base17 { + dims: [1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12, 13, -14, 15, -16, 17], + }; let bytes = a.to_bytes(); let b = Base17::from_bytes(&bytes); assert_eq!(a, b); @@ -684,7 +702,11 @@ mod tests { #[test] fn test_palette_edge_roundtrip() { - let pe = PaletteEdge { s_idx: 42, p_idx: 128, o_idx: 255 }; + let pe = PaletteEdge { + s_idx: 42, + p_idx: 128, + o_idx: 255, + }; let bytes = pe.to_bytes(); let pe2 = PaletteEdge::from_bytes(&bytes); assert_eq!(pe, pe2); @@ -718,7 +740,9 @@ mod tests { #[test] fn test_sign_agreement_self() { - let a = Base17 { dims: [100, -50, 30, 0, 10, -20, 40, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] }; + let a = Base17 { + dims: [100, -50, 30, 0, 10, -20, 40, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + }; assert_eq!(a.sign_agreement(&a), BASE_DIM as u32); } diff --git a/src/hpc/binding_matrix.rs b/src/hpc/binding_matrix.rs index 574841c9..d16b9c73 100644 --- a/src/hpc/binding_matrix.rs +++ b/src/hpc/binding_matrix.rs @@ -79,17 +79,10 @@ fn permute(fp: &Fingerprint, offset: usize) -> Fingerprint /// assert_eq!(matrix.len(), 4 * 4 * 4); /// ``` pub fn binding_popcount_3d( - x: &Fingerprint<256>, - y: &Fingerprint<256>, - z: &Fingerprint<256>, - resolution: usize, + x: &Fingerprint<256>, y: &Fingerprint<256>, z: &Fingerprint<256>, resolution: usize, ) -> Vec { let total_bits = 256 * 64; // 16384 - let step = if resolution > 1 { - total_bits / resolution - } else { - 0 - }; + let step = if resolution > 1 { total_bits / resolution } else { 0 }; // Pre-compute all permutations (avoid recomputation in inner loops) let x_perms: Vec> = (0..resolution).map(|i| permute(x, i * step)).collect(); @@ -120,9 +113,7 @@ pub fn binding_popcount_3d( /// Returns `(i, j, k, z_score)` sorted by z_score (ascending). /// Points with z_score < 1.0 are in the holographic basin. pub fn find_holographic_sweet_spot( - matrix: &[u32], - resolution: usize, - total_bits: usize, + matrix: &[u32], resolution: usize, total_bits: usize, ) -> Vec<(usize, usize, usize, f64)> { let target = total_bits as f64 / 2.0; let sigma = (total_bits as f64 / 4.0).sqrt(); @@ -152,9 +143,7 @@ pub fn find_holographic_sweet_spot( /// Returns `(i, j, k, z_score)` sorted by z_score (descending). /// Points with z_score > 2.0 are in the discriminative zone. 
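Both sweet-spot finders score each cell of the binding matrix against the random-binding baseline. A sketch of that z-score, using the mean `total_bits / 2` and `sigma = sqrt(total_bits / 4)` visible in the hunks:

```rust
// Popcount of a binding of independent random bits is ~Binomial(total_bits, 1/2),
// so mean = total_bits / 2 and sigma = sqrt(total_bits / 4).
fn popcount_z_score(count: u32, total_bits: usize) -> f64 {
    let target = total_bits as f64 / 2.0;
    let sigma = (total_bits as f64 / 4.0).sqrt();
    ((count as f64 - target) / sigma).abs()
}

// For the 16384-bit Fingerprint<256>, sigma = 64: popcount 8192 is dead centre
// (z = 0, holographic basin), while 8320 sits at z = 2.0 (discriminative zone).
```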
pub fn find_discriminative_spots( - matrix: &[u32], - resolution: usize, - total_bits: usize, + matrix: &[u32], resolution: usize, total_bits: usize, ) -> Vec<(usize, usize, usize, f64)> { let target = total_bits as f64 / 2.0; let sigma = (total_bits as f64 / 4.0).sqrt(); @@ -269,7 +258,8 @@ mod tests { Self(seed) } fn next_u64(&mut self) -> u64 { - self.0 = self.0 + self.0 = self + .0 .wrapping_mul(6364136223846793005) .wrapping_add(1442695040888963407); self.0 >> 1 @@ -354,10 +344,7 @@ mod tests { let matrix = binding_popcount_3d(&x, &y, &z, res); let sweet_spots = find_holographic_sweet_spot(&matrix, res, total_bits); - assert!( - !sweet_spots.is_empty(), - "Should find holographic sweet spots" - ); + assert!(!sweet_spots.is_empty(), "Should find holographic sweet spots"); if sweet_spots.len() > 1 { assert!(sweet_spots[0].3 <= sweet_spots[1].3); @@ -378,11 +365,7 @@ mod tests { assert!(mean > 0.0); assert!(std > 0.0); assert!(min < max); - assert!( - holo_frac > 0.05, - "Expected >5% holographic, got {}", - holo_frac - ); + assert!(holo_frac > 0.05, "Expected >5% holographic, got {}", holo_frac); } #[test] diff --git a/src/hpc/bitwise.rs b/src/hpc/bitwise.rs index 35396574..870314e8 100644 --- a/src/hpc/bitwise.rs +++ b/src/hpc/bitwise.rs @@ -199,11 +199,7 @@ pub fn hamming_batch_raw(query: &[u8], database: &[u8], num_rows: usize, row_byt /// Returns (indices, distances) of the k closest rows in the database. /// Uses `select_nth_unstable` for O(n) partial sort instead of O(n log n). pub fn hamming_top_k_raw( - query: &[u8], - database: &[u8], - num_rows: usize, - row_bytes: usize, - k: usize, + query: &[u8], database: &[u8], num_rows: usize, row_bytes: usize, k: usize, ) -> (Vec, Vec) { let distances = dispatch_hamming_batch(query, database, num_rows, row_bytes); let k = k.min(num_rows); @@ -297,7 +293,8 @@ pub fn masked_popcount_total(words: &[u64], mask: u64) -> u64 { } impl BitwiseOps for ArrayBase -where S: Data +where + S: Data, { fn hamming_distance(&self, other: &Self) -> u64 { if let (Some(a), Some(b)) = (self.as_slice(), other.as_slice()) { @@ -438,12 +435,23 @@ mod tests { /// Generate deterministic pseudo-random test data. fn test_data(n: usize, seed: u8) -> Vec { - (0..n).map(|i| ((i as u8).wrapping_mul(7).wrapping_add(seed).wrapping_mul(13)) ^ (i as u8)).collect() + (0..n) + .map(|i| { + ((i as u8) + .wrapping_mul(7) + .wrapping_add(seed) + .wrapping_mul(13)) + ^ (i as u8) + }) + .collect() } /// Scalar reference — always correct, used to verify SIMD tiers. 
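`hamming_top_k_raw` above documents an O(n) partial selection via `select_nth_unstable` instead of a full sort. A standalone sketch of that strategy over a distance vector (names are illustrative):

```rust
// Partition so the k smallest distances land in front, then sort only those k survivors.
fn top_k_smallest(distances: &[u64], k: usize) -> (Vec<usize>, Vec<u64>) {
    let k = k.min(distances.len());
    let mut pairs: Vec<(u64, usize)> = distances.iter().copied().zip(0..).collect();
    if k > 0 && k < pairs.len() {
        pairs.select_nth_unstable(k - 1); // O(n): everything before index k is <= pairs[k - 1]
    }
    pairs.truncate(k);
    pairs.sort_unstable(); // O(k log k) on the shortlist only
    (
        pairs.iter().map(|&(_, i)| i).collect(),
        pairs.iter().map(|&(d, _)| d).collect(),
    )
}
```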
fn reference_hamming(a: &[u8], b: &[u8]) -> u64 { - a.iter().zip(b.iter()).map(|(&x, &y)| (x ^ y).count_ones() as u64).sum() + a.iter() + .zip(b.iter()) + .map(|(&x, &y)| (x ^ y).count_ones() as u64) + .sum() } fn reference_popcount(a: &[u8]) -> u64 { @@ -553,8 +561,7 @@ mod tests { #[cfg(target_arch = "x86_64")] #[test] fn test_all_tiers_agree() { - let sizes = [0, 1, 3, 7, 15, 16, 31, 32, 33, 63, 64, 65, - 127, 128, 129, 255, 256, 512, 1024, 2048, 4096, 8192]; + let sizes = [0, 1, 3, 7, 15, 16, 31, 32, 33, 63, 64, 65, 127, 128, 129, 255, 256, 512, 1024, 2048, 4096, 8192]; for &n in &sizes { let a = test_data(n, 0x42); @@ -563,18 +570,15 @@ mod tests { if is_x86_feature_detected!("avx2") { let avx2 = unsafe { hamming_avx2(&a, &b) }; - assert_eq!(scalar, avx2, - "scalar≠avx2 at n={}: {} vs {}", n, scalar, avx2); + assert_eq!(scalar, avx2, "scalar≠avx2 at n={}: {} vs {}", n, scalar, avx2); } if is_x86_feature_detected!("avx512bw") { let bw = unsafe { hamming_avx512bw(&a, &b) }; - assert_eq!(scalar, bw, - "scalar≠avx512bw at n={}: {} vs {}", n, scalar, bw); + assert_eq!(scalar, bw, "scalar≠avx512bw at n={}: {} vs {}", n, scalar, bw); } if is_x86_feature_detected!("avx512vpopcntdq") && is_x86_feature_detected!("avx512bw") { let vpc = unsafe { crate::backend::kernels_avx512::hamming_distance(&a, &b) }; - assert_eq!(scalar, vpc, - "scalar≠vpopcntdq at n={}: {} vs {}", n, scalar, vpc); + assert_eq!(scalar, vpc, "scalar≠vpopcntdq at n={}: {} vs {}", n, scalar, vpc); } } } @@ -591,13 +595,11 @@ mod tests { if is_x86_feature_detected!("avx512bw") { let bw = unsafe { popcount_avx512bw(&a) }; - assert_eq!(scalar, bw, - "popcount scalar≠avx512bw at n={}: {} vs {}", n, scalar, bw); + assert_eq!(scalar, bw, "popcount scalar≠avx512bw at n={}: {} vs {}", n, scalar, bw); } if is_x86_feature_detected!("avx512vpopcntdq") { let vpc = unsafe { crate::backend::kernels_avx512::popcount(&a) }; - assert_eq!(scalar, vpc, - "popcount scalar≠vpopcntdq at n={}: {} vs {}", n, scalar, vpc); + assert_eq!(scalar, vpc, "popcount scalar≠vpopcntdq at n={}: {} vs {}", n, scalar, vpc); } } } @@ -623,7 +625,8 @@ mod tests { if is_x86_feature_detected!("avx512vpopcntdq") && is_x86_feature_detected!("avx512bw") { assert_eq!( unsafe { crate::backend::kernels_avx512::hamming_distance(&a, &b) }, - expected, "vpopcntdq large" + expected, + "vpopcntdq large" ); } } @@ -667,7 +670,9 @@ mod tests { if is_x86_feature_detected!("avx512vpopcntdq") && is_x86_feature_detected!("avx512bw") { assert_eq!( unsafe { crate::backend::kernels_avx512::hamming_distance(&a, &b) }, - 0, "vpc identical n={}", n + 0, + "vpc identical n={}", + n ); } } diff --git a/src/hpc/blackboard.rs b/src/hpc/blackboard.rs index e1f1ac2a..82b2ca15 100644 --- a/src/hpc/blackboard.rs +++ b/src/hpc/blackboard.rs @@ -39,9 +39,7 @@ pub struct Blackboard { impl Blackboard { /// Create an empty blackboard. pub fn new() -> Self { - Self { - slots: HashMap::new(), - } + Self { slots: HashMap::new() } } /// Create a blackboard with pre-allocated capacity. @@ -87,16 +85,16 @@ impl Blackboard { /// Allocate (insert or replace) a value of any `'static` type. pub fn alloc(&mut self, key: &str, value: T) { - self.slots.insert( - key.to_string(), - Slot { data: Box::new(value) }, - ); + self.slots + .insert(key.to_string(), Slot { data: Box::new(value) }); } /// Get an immutable reference to a value, returning `None` if the key is /// missing or the type does not match. 
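For orientation, the slot map behind `Blackboard::alloc` / `get` boils down to a `HashMap<String, Box<dyn Any>>` plus a typed downcast. A minimal sketch with the generic parameters written out; `Board` and its field names are illustrative:

```rust
use std::any::Any;
use std::collections::HashMap;

struct Slot {
    data: Box<dyn Any>,
}

struct Board {
    slots: HashMap<String, Slot>,
}

impl Board {
    // Insert or replace a value of any 'static type under a string key.
    fn alloc<T: 'static>(&mut self, key: &str, value: T) {
        self.slots.insert(key.to_string(), Slot { data: Box::new(value) });
    }

    // None if the key is missing or the stored type is not T.
    fn get<T: 'static>(&self, key: &str) -> Option<&T> {
        self.slots.get(key).and_then(|slot| slot.data.downcast_ref::<T>())
    }
}
```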
pub fn get(&self, key: &str) -> Option<&T> { - self.slots.get(key).and_then(|slot| slot.data.downcast_ref::()) + self.slots + .get(key) + .and_then(|slot| slot.data.downcast_ref::()) } /// Get a mutable reference to a value, returning `None` if the key is @@ -290,11 +288,7 @@ impl Blackboard { /// Uses raw pointers internally to circumvent the borrow checker for /// distinct HashMap entries. This is safe because the two keys are verified /// to be different, guaranteeing non-overlapping memory. - pub fn borrow_2_mut_vec_f32( - &mut self, - key_a: &str, - key_b: &str, - ) -> (&mut Vec, &mut Vec) { + pub fn borrow_2_mut_vec_f32(&mut self, key_a: &str, key_b: &str) -> (&mut Vec, &mut Vec) { assert_ne!(key_a, key_b, "borrow_2_mut: keys must be distinct"); // SAFETY: key_a != key_b, so the two mutable references point to // different HashMap entries and cannot alias. @@ -321,11 +315,7 @@ impl Blackboard { /// # Panics /// /// Panics if `key_a == key_b`, or if either key is missing or wrong type. - pub fn borrow_2_mut_vec_f64( - &mut self, - key_a: &str, - key_b: &str, - ) -> (&mut Vec, &mut Vec) { + pub fn borrow_2_mut_vec_f64(&mut self, key_a: &str, key_b: &str) -> (&mut Vec, &mut Vec) { assert_ne!(key_a, key_b, "borrow_2_mut: keys must be distinct"); // SAFETY: key_a != key_b => distinct entries, no aliasing. unsafe { @@ -351,11 +341,7 @@ impl Blackboard { /// # Panics /// /// Panics if `key_a == key_b`, or if either key is missing or wrong type. - pub fn borrow_2_mut_vec_u8( - &mut self, - key_a: &str, - key_b: &str, - ) -> (&mut Vec, &mut Vec) { + pub fn borrow_2_mut_vec_u8(&mut self, key_a: &str, key_b: &str) -> (&mut Vec, &mut Vec) { assert_ne!(key_a, key_b, "borrow_2_mut: keys must be distinct"); // SAFETY: key_a != key_b => distinct entries, no aliasing. unsafe { @@ -382,10 +368,7 @@ impl Blackboard { /// /// Panics if any two keys are equal, or if any key is missing or wrong type. pub fn borrow_3_mut_vec_f32( - &mut self, - key_a: &str, - key_b: &str, - key_c: &str, + &mut self, key_a: &str, key_b: &str, key_c: &str, ) -> (&mut Vec, &mut Vec, &mut Vec) { assert_ne!(key_a, key_b, "borrow_3_mut: keys must be distinct (a == b)"); assert_ne!(key_b, key_c, "borrow_3_mut: keys must be distinct (b == c)"); @@ -421,10 +404,7 @@ impl Blackboard { /// /// Panics if any two keys are equal, or if any key is missing or wrong type. pub fn borrow_3_mut_vec_f64( - &mut self, - key_a: &str, - key_b: &str, - key_c: &str, + &mut self, key_a: &str, key_b: &str, key_c: &str, ) -> (&mut Vec, &mut Vec, &mut Vec) { assert_ne!(key_a, key_b, "borrow_3_mut: keys must be distinct (a == b)"); assert_ne!(key_b, key_c, "borrow_3_mut: keys must be distinct (b == c)"); @@ -749,7 +729,10 @@ mod tests { #[test] fn test_generic_alloc_custom_type() { #[derive(Debug, PartialEq)] - struct Point { x: f32, y: f32 } + struct Point { + x: f32, + y: f32, + } let mut bb = Blackboard::new(); bb.alloc("origin", Point { x: 0.0, y: 0.0 }); diff --git a/src/hpc/blas_level1.rs b/src/hpc/blas_level1.rs index 3eddb5a6..7bdafc07 100644 --- a/src/hpc/blas_level1.rs +++ b/src/hpc/blas_level1.rs @@ -3,8 +3,8 @@ //! Provides vector-vector operations: dot, axpy, scal, nrm2, asum, //! iamax, copy, swap, and element-wise scalar/vector arithmetic. -use crate::imp_prelude::*; use crate::backend::BlasFloat; +use crate::imp_prelude::*; /// BLAS Level 1 operations on 1-D arrays. /// diff --git a/src/hpc/blas_level2.rs b/src/hpc/blas_level2.rs index 0f33f740..27a223bb 100644 --- a/src/hpc/blas_level2.rs +++ b/src/hpc/blas_level2.rs @@ -3,8 +3,8 @@ //! 
Provides gemv, ger (rank-1 update), symv (symmetric matrix-vector), //! trmv/trsv (triangular multiply/solve). -use crate::imp_prelude::*; use crate::backend::BlasFloat; +use crate::imp_prelude::*; /// Upper or lower triangle specification. #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -41,68 +41,38 @@ pub enum Diag { pub trait BlasLevel2 { /// General matrix-vector multiply: y = alpha * A * x + beta * y_init fn blas_gemv( - &self, - alpha: A, - x: &ArrayBase, Ix1>, - beta: A, - y_init: &ArrayBase, Ix1>, + &self, alpha: A, x: &ArrayBase, Ix1>, beta: A, y_init: &ArrayBase, Ix1>, ) -> Array; /// Rank-1 update: A = alpha * x * y^T + A (returns new array) fn blas_ger( - &self, - alpha: A, - x: &ArrayBase, Ix1>, - y: &ArrayBase, Ix1>, + &self, alpha: A, x: &ArrayBase, Ix1>, y: &ArrayBase, Ix1>, ) -> Array; /// Symmetric matrix-vector multiply: y = alpha * A * x + beta * y_init /// /// Only reads the triangle specified by `uplo`. fn blas_symv( - &self, - uplo: Uplo, - alpha: A, - x: &ArrayBase, Ix1>, - beta: A, + &self, uplo: Uplo, alpha: A, x: &ArrayBase, Ix1>, beta: A, y_init: &ArrayBase, Ix1>, ) -> Array; /// Triangular matrix-vector multiply: x = A * x - fn blas_trmv( - &self, - uplo: Uplo, - diag: Diag, - x: &ArrayBase, Ix1>, - ) -> Array; + fn blas_trmv(&self, uplo: Uplo, diag: Diag, x: &ArrayBase, Ix1>) -> Array; /// Triangular solve: solve A * result = x for result - fn blas_trsv( - &self, - uplo: Uplo, - diag: Diag, - x: &ArrayBase, Ix1>, - ) -> Array; + fn blas_trsv(&self, uplo: Uplo, diag: Diag, x: &ArrayBase, Ix1>) -> Array; /// Symmetric rank-1 update: A = alpha * x * x^T + A /// /// Only updates the triangle specified by `uplo`. - fn blas_syr( - &self, - uplo: Uplo, - alpha: A, - x: &ArrayBase, Ix1>, - ) -> Array; + fn blas_syr(&self, uplo: Uplo, alpha: A, x: &ArrayBase, Ix1>) -> Array; /// Symmetric rank-2 update: A = alpha * x * y^T + alpha * y * x^T + A /// /// Only updates the triangle specified by `uplo`. fn blas_syr2( - &self, - uplo: Uplo, - alpha: A, - x: &ArrayBase, Ix1>, - y: &ArrayBase, Ix1>, + &self, uplo: Uplo, alpha: A, x: &ArrayBase, Ix1>, y: &ArrayBase, Ix1>, ) -> Array; /// General banded matrix-vector multiply: y = alpha * A * x + beta * y_init @@ -110,13 +80,7 @@ pub trait BlasLevel2 { /// `kl` is the number of sub-diagonals, `ku` is the number of super-diagonals. /// The matrix `A` is stored in band storage with `kl + ku + 1` rows and `n` columns. fn blas_gbmv( - &self, - m: usize, - kl: usize, - ku: usize, - alpha: A, - x: &ArrayBase, Ix1>, - beta: A, + &self, m: usize, kl: usize, ku: usize, alpha: A, x: &ArrayBase, Ix1>, beta: A, y_init: &ArrayBase, Ix1>, ) -> Array; @@ -125,12 +89,7 @@ pub trait BlasLevel2 { /// `k` is the number of super-diagonals. The matrix is stored in band storage /// with `k + 1` rows and `n` columns. Only the triangle specified by `uplo` is read. 
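The trait documents `blas_gemv` as `y = alpha * A * x + beta * y_init`. A plain row-major reference loop with those semantics, for orientation only (the actual impl is generic over `BlasFloat` and ndarray storage):

```rust
// Dense, row-major reference gemv: y = alpha * A * x + beta * y.
fn gemv_ref(a: &[f64], m: usize, n: usize, alpha: f64, x: &[f64], beta: f64, y: &mut [f64]) {
    assert_eq!(a.len(), m * n);
    assert_eq!(x.len(), n, "x length must equal number of columns");
    assert_eq!(y.len(), m);
    for i in 0..m {
        let mut acc = 0.0;
        for j in 0..n {
            acc += a[i * n + j] * x[j];
        }
        y[i] = alpha * acc + beta * y[i];
    }
}
```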
fn blas_sbmv( - &self, - uplo: Uplo, - k: usize, - alpha: A, - x: &ArrayBase, Ix1>, - beta: A, + &self, uplo: Uplo, k: usize, alpha: A, x: &ArrayBase, Ix1>, beta: A, y_init: &ArrayBase, Ix1>, ) -> Array; } @@ -141,11 +100,7 @@ where S: Data, { fn blas_gemv( - &self, - alpha: A, - x: &ArrayBase, Ix1>, - beta: A, - y_init: &ArrayBase, Ix1>, + &self, alpha: A, x: &ArrayBase, Ix1>, beta: A, y_init: &ArrayBase, Ix1>, ) -> Array { let (m, n) = (self.nrows(), self.ncols()); assert_eq!(x.len(), n, "x length must equal number of columns"); @@ -173,10 +128,7 @@ where } fn blas_ger( - &self, - alpha: A, - x: &ArrayBase, Ix1>, - y: &ArrayBase, Ix1>, + &self, alpha: A, x: &ArrayBase, Ix1>, y: &ArrayBase, Ix1>, ) -> Array { let (m, n) = (self.nrows(), self.ncols()); assert_eq!(x.len(), m, "x length must equal number of rows"); @@ -192,11 +144,7 @@ where } fn blas_symv( - &self, - uplo: Uplo, - alpha: A, - x: &ArrayBase, Ix1>, - beta: A, + &self, uplo: Uplo, alpha: A, x: &ArrayBase, Ix1>, beta: A, y_init: &ArrayBase, Ix1>, ) -> Array { let n = self.nrows(); @@ -210,10 +158,18 @@ where for j in 0..n { let a_ij = match uplo { Uplo::Upper => { - if j >= i { self[[i, j]] } else { self[[j, i]] } + if j >= i { + self[[i, j]] + } else { + self[[j, i]] + } } Uplo::Lower => { - if j <= i { self[[i, j]] } else { self[[j, i]] } + if j <= i { + self[[i, j]] + } else { + self[[j, i]] + } } }; sum = sum + a_ij * x[j]; @@ -223,12 +179,7 @@ where y } - fn blas_trmv( - &self, - uplo: Uplo, - diag: Diag, - x: &ArrayBase, Ix1>, - ) -> Array { + fn blas_trmv(&self, uplo: Uplo, diag: Diag, x: &ArrayBase, Ix1>) -> Array { let n = self.nrows(); assert_eq!(self.ncols(), n, "Matrix must be square for trmv"); assert_eq!(x.len(), n); @@ -263,12 +214,7 @@ where result } - fn blas_trsv( - &self, - uplo: Uplo, - diag: Diag, - x: &ArrayBase, Ix1>, - ) -> Array { + fn blas_trsv(&self, uplo: Uplo, diag: Diag, x: &ArrayBase, Ix1>) -> Array { let n = self.nrows(); assert_eq!(self.ncols(), n, "Matrix must be square for trsv"); assert_eq!(x.len(), n); @@ -299,12 +245,7 @@ where result } - fn blas_syr( - &self, - uplo: Uplo, - alpha: A, - x: &ArrayBase, Ix1>, - ) -> Array { + fn blas_syr(&self, uplo: Uplo, alpha: A, x: &ArrayBase, Ix1>) -> Array { let n = self.nrows(); assert_eq!(self.ncols(), n, "Matrix must be square for syr"); assert_eq!(x.len(), n); @@ -323,11 +264,7 @@ where } fn blas_syr2( - &self, - uplo: Uplo, - alpha: A, - x: &ArrayBase, Ix1>, - y: &ArrayBase, Ix1>, + &self, uplo: Uplo, alpha: A, x: &ArrayBase, Ix1>, y: &ArrayBase, Ix1>, ) -> Array { let n = self.nrows(); assert_eq!(self.ncols(), n, "Matrix must be square for syr2"); @@ -348,13 +285,7 @@ where } fn blas_gbmv( - &self, - m: usize, - kl: usize, - ku: usize, - alpha: A, - x: &ArrayBase, Ix1>, - beta: A, + &self, m: usize, kl: usize, ku: usize, alpha: A, x: &ArrayBase, Ix1>, beta: A, y_init: &ArrayBase, Ix1>, ) -> Array { let n = x.len(); @@ -380,12 +311,7 @@ where } fn blas_sbmv( - &self, - uplo: Uplo, - k: usize, - alpha: A, - x: &ArrayBase, Ix1>, - beta: A, + &self, uplo: Uplo, k: usize, alpha: A, x: &ArrayBase, Ix1>, beta: A, y_init: &ArrayBase, Ix1>, ) -> Array { let n = x.len(); @@ -529,11 +455,7 @@ mod tests { // row 0 (super-diag): [*, 2, 5] // row 1 (diagonal): [1, 4, 7] // row 2 (sub-diag): [3, 6, *] - let band = array![ - [0.0f64, 2.0, 5.0], - [1.0, 4.0, 7.0], - [3.0, 6.0, 0.0] - ]; + let band = array![[0.0f64, 2.0, 5.0], [1.0, 4.0, 7.0], [3.0, 6.0, 0.0]]; let x = array![1.0f64, 2.0, 3.0]; let y0 = array![0.0f64, 0.0, 0.0]; let y = band.blas_gbmv(3, 1, 1, 
1.0, &x, 0.0, &y0); @@ -551,10 +473,7 @@ mod tests { // k=1, upper band storage (2 rows x 3 cols): // row 0 (super-diag): [*, 1, 1] // row 1 (diagonal): [2, 3, 4] - let band = array![ - [0.0f64, 1.0, 1.0], - [2.0, 3.0, 4.0] - ]; + let band = array![[0.0f64, 1.0, 1.0], [2.0, 3.0, 4.0]]; let x = array![1.0f64, 2.0, 3.0]; let y0 = array![0.0f64, 0.0, 0.0]; let y = band.blas_sbmv(Uplo::Upper, 1, 1.0, &x, 0.0, &y0); diff --git a/src/hpc/blas_level3.rs b/src/hpc/blas_level3.rs index 0481dd77..55e1775c 100644 --- a/src/hpc/blas_level3.rs +++ b/src/hpc/blas_level3.rs @@ -3,9 +3,9 @@ //! Provides gemm, syrk (symmetric rank-k update), trsm (triangular solve), //! symm (symmetric matrix multiply). -use crate::imp_prelude::*; -use crate::backend::BlasFloat; use super::blas_level2::Uplo; +use crate::backend::BlasFloat; +use crate::imp_prelude::*; /// Side specification for operations like symm and trsm. #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -33,64 +33,27 @@ pub trait BlasLevel3 { /// General matrix multiply: result = alpha * self * B + beta * C_init /// /// If C_init is None, assumes zero initialization. - fn blas_gemm( - &self, - alpha: A, - b: &Self, - beta: A, - ) -> Array; + fn blas_gemm(&self, alpha: A, b: &Self, beta: A) -> Array; /// General matrix multiply with explicit C: C = alpha * self * B + beta * C - fn blas_gemm_into( - &self, - alpha: A, - b: &Self, - beta: A, - c: &mut Array, - ); + fn blas_gemm_into(&self, alpha: A, b: &Self, beta: A, c: &mut Array); /// Symmetric rank-k update: C = alpha * A * A^T + beta * C_init - fn blas_syrk( - &self, - uplo: Uplo, - alpha: A, - beta: A, - c_init: Option<&Self>, - ) -> Array; + fn blas_syrk(&self, uplo: Uplo, alpha: A, beta: A, c_init: Option<&Self>) -> Array; /// Symmetric matrix multiply: C = alpha * A * B + beta * C_init /// /// A is the symmetric matrix (specified by `side`). - fn blas_symm( - &self, - side: Side, - uplo: Uplo, - alpha: A, - b: &Self, - beta: A, - c_init: Option<&Self>, - ) -> Array; + fn blas_symm(&self, side: Side, uplo: Uplo, alpha: A, b: &Self, beta: A, c_init: Option<&Self>) -> Array; /// Triangular matrix-matrix multiply: B = alpha * op(A) * B (Left) /// or B = alpha * B * op(A) (Right). /// /// `a` is the triangular matrix. Only the triangle specified by `uplo` is read. 
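`blas_trsv` / `blas_trsm` solve triangular systems in place of a general inverse. The textbook forward substitution they reduce to for `Uplo::Lower` with a non-unit diagonal, sketched on a dense row-major slice:

```rust
// Solve L * x = b by forward substitution (lower triangular, non-unit diagonal).
fn trsv_lower_ref(l: &[f64], n: usize, b: &[f64]) -> Vec<f64> {
    assert_eq!(l.len(), n * n);
    assert_eq!(b.len(), n);
    let mut x = vec![0.0; n];
    for i in 0..n {
        let mut s = b[i];
        for j in 0..i {
            s -= l[i * n + j] * x[j];
        }
        x[i] = s / l[i * n + i]; // with Diag::Unit this divide would be skipped
    }
    x
}
```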
- fn blas_trmm( - &self, - side: Side, - uplo: Uplo, - alpha: A, - a_tri: &Self, - ) -> Array; + fn blas_trmm(&self, side: Side, uplo: Uplo, alpha: A, a_tri: &Self) -> Array; /// Triangular solve (matrix): solve A * X = alpha * B for X - fn blas_trsm( - &self, - side: Side, - uplo: Uplo, - alpha: A, - b: &Self, - ) -> Array; + fn blas_trsm(&self, side: Side, uplo: Uplo, alpha: A, b: &Self) -> Array; } impl BlasLevel3 for ArrayBase @@ -98,21 +61,14 @@ where A: BlasFloat + num_traits::Float + core::ops::AddAssign, S: Data, { - fn blas_gemm( - &self, - alpha: A, - b: &Self, - beta: A, - ) -> Array { + fn blas_gemm(&self, alpha: A, b: &Self, beta: A) -> Array { let (m, k) = (self.nrows(), self.ncols()); let (k2, n) = (b.nrows(), b.ncols()); assert_eq!(k, k2, "Inner dimensions must match for GEMM"); let mut c = Array::zeros((m, n)); - if let (Some(a_s), Some(b_s), Some(c_s)) = - (self.as_slice(), b.as_slice(), c.as_slice_mut()) - { + if let (Some(a_s), Some(b_s), Some(c_s)) = (self.as_slice(), b.as_slice(), c.as_slice_mut()) { A::backend_gemm(m, n, k, alpha, a_s, k, b_s, n, beta, c_s, n); } else { // Fallback for non-contiguous @@ -129,22 +85,14 @@ where c } - fn blas_gemm_into( - &self, - alpha: A, - b: &Self, - beta: A, - c: &mut Array, - ) { + fn blas_gemm_into(&self, alpha: A, b: &Self, beta: A, c: &mut Array) { let (m, k) = (self.nrows(), self.ncols()); let (k2, n) = (b.nrows(), b.ncols()); assert_eq!(k, k2, "Inner dimensions must match for GEMM"); assert_eq!(c.nrows(), m); assert_eq!(c.ncols(), n); - if let (Some(a_s), Some(b_s), Some(c_s)) = - (self.as_slice(), b.as_slice(), c.as_slice_mut()) - { + if let (Some(a_s), Some(b_s), Some(c_s)) = (self.as_slice(), b.as_slice(), c.as_slice_mut()) { A::backend_gemm(m, n, k, alpha, a_s, k, b_s, n, beta, c_s, n); } else { for i in 0..m { @@ -159,13 +107,7 @@ where } } - fn blas_syrk( - &self, - uplo: Uplo, - alpha: A, - beta: A, - c_init: Option<&Self>, - ) -> Array { + fn blas_syrk(&self, uplo: Uplo, alpha: A, beta: A, c_init: Option<&Self>) -> Array { let (m, k) = (self.nrows(), self.ncols()); let mut c = match c_init { Some(ci) => ci.to_owned(), @@ -188,15 +130,7 @@ where c } - fn blas_symm( - &self, - side: Side, - uplo: Uplo, - alpha: A, - b: &Self, - beta: A, - c_init: Option<&Self>, - ) -> Array { + fn blas_symm(&self, side: Side, uplo: Uplo, alpha: A, b: &Self, beta: A, c_init: Option<&Self>) -> Array { let (m, n) = (b.nrows(), b.ncols()); let mut c = match c_init { Some(ci) => ci.to_owned(), @@ -216,10 +150,18 @@ where for p in 0..m { let a_val = match uplo { Uplo::Upper => { - if p >= i { sym[[i, p]] } else { sym[[p, i]] } + if p >= i { + sym[[i, p]] + } else { + sym[[p, i]] + } } Uplo::Lower => { - if p <= i { sym[[i, p]] } else { sym[[p, i]] } + if p <= i { + sym[[i, p]] + } else { + sym[[p, i]] + } } }; sum = sum + a_val * b[[p, j]]; @@ -230,10 +172,18 @@ where for p in 0..n { let a_val = match uplo { Uplo::Upper => { - if p >= j { sym[[j, p]] } else { sym[[p, j]] } + if p >= j { + sym[[j, p]] + } else { + sym[[p, j]] + } } Uplo::Lower => { - if p <= j { sym[[j, p]] } else { sym[[p, j]] } + if p <= j { + sym[[j, p]] + } else { + sym[[p, j]] + } } }; sum = sum + b[[i, p]] * a_val; @@ -246,13 +196,7 @@ where c } - fn blas_trmm( - &self, - side: Side, - uplo: Uplo, - alpha: A, - a_tri: &Self, - ) -> Array { + fn blas_trmm(&self, side: Side, uplo: Uplo, alpha: A, a_tri: &Self) -> Array { let (m, n) = (self.nrows(), self.ncols()); let mut result = Array::zeros((m, n)); let b = self; @@ -314,13 +258,7 @@ where result } - fn blas_trsm( - 
&self, - side: Side, - uplo: Uplo, - alpha: A, - b: &Self, - ) -> Array { + fn blas_trsm(&self, side: Side, uplo: Uplo, alpha: A, b: &Self) -> Array { let (m, n) = (b.nrows(), b.ncols()); let mut x = b.mapv(|v| alpha * v); let a = self; diff --git a/src/hpc/bnn.rs b/src/hpc/bnn.rs index ac6932a8..50ccbaa6 100644 --- a/src/hpc/bnn.rs +++ b/src/hpc/bnn.rs @@ -33,8 +33,7 @@ pub struct BnnDotResult { #[inline] pub fn bnn_dot(activation: &Fingerprint<256>, weight: &Fingerprint<256>) -> BnnDotResult { let total_bits = Fingerprint::<256>::BITS as u32; - let xor_popcount = - super::bitwise::hamming_distance_raw(activation.as_bytes(), weight.as_bytes()) as u32; + let xor_popcount = super::bitwise::hamming_distance_raw(activation.as_bytes(), weight.as_bytes()) as u32; let match_count = total_bits - xor_popcount; let score = (2.0 * match_count as f32 / total_bits as f32) - 1.0; BnnDotResult { @@ -49,10 +48,9 @@ pub fn bnn_dot_3ch(activation: &GraphHV, weight: &GraphHV) -> BnnDotResult { let total_bits = (Fingerprint::<256>::BITS * 3) as u32; let mut xor_total = 0u32; for ch in 0..3 { - xor_total += super::bitwise::hamming_distance_raw( - activation.channels[ch].as_bytes(), - weight.channels[ch].as_bytes(), - ) as u32; + xor_total += + super::bitwise::hamming_distance_raw(activation.channels[ch].as_bytes(), weight.channels[ch].as_bytes()) + as u32; } let match_count = total_bits - xor_total; let score = (2.0 * match_count as f32 / total_bits as f32) - 1.0; @@ -101,13 +99,7 @@ impl BnnNeuron { } /// Forward pass: compute binary activation from input. - pub fn forward( - &mut self, - input: &Fingerprint<256>, - learn: bool, - learning_rate: f64, - rng: &mut SplitMix64, - ) -> f32 { + pub fn forward(&mut self, input: &Fingerprint<256>, learn: bool, learning_rate: f64, rng: &mut SplitMix64) -> f32 { let dot = bnn_dot(input, &self.state.channels[1]); let pre_activation = dot.score + self.bias; @@ -184,11 +176,7 @@ impl BnnLayer { /// Forward pass: compute all neurons' activations. pub fn forward( - &mut self, - input: &Fingerprint<256>, - learn: bool, - learning_rate: f64, - rng: &mut SplitMix64, + &mut self, input: &Fingerprint<256>, learn: bool, learning_rate: f64, rng: &mut SplitMix64, ) -> Vec { self.neurons .iter_mut() @@ -235,10 +223,7 @@ impl BnnLayer { /// Winner-take-all using CAM index: O(log N) instead of O(N). pub fn winner_cam( - &self, - input: &Fingerprint<256>, - cam: &CamIndex, - shortlist_size: usize, + &self, input: &Fingerprint<256>, cam: &CamIndex, shortlist_size: usize, ) -> Option<(usize, BnnDotResult)> { let query_hv = GraphHV { channels: [input.clone(), input.clone(), input.clone()], @@ -268,9 +253,7 @@ impl BnnLayer { /// Batch XNOR+popcount: compute binary dot products for multiple candidates. pub fn bnn_batch_dot( - query: &Fingerprint<256>, - weights: &[Fingerprint<256>], - top_k: usize, + query: &Fingerprint<256>, weights: &[Fingerprint<256>], top_k: usize, ) -> Vec<(usize, BnnDotResult)> { let mut results: Vec<(usize, BnnDotResult)> = weights .iter() @@ -304,11 +287,7 @@ impl BnnNetwork { /// Forward pass through all layers. pub fn forward( - &mut self, - input: &Fingerprint<256>, - learn: bool, - learning_rate: f64, - rng: &mut SplitMix64, + &mut self, input: &Fingerprint<256>, learn: bool, learning_rate: f64, rng: &mut SplitMix64, ) -> (usize, BnnDotResult) { let mut current_input = input.clone(); @@ -352,10 +331,7 @@ pub struct BnnCascadeResult { /// Cascade-accelerated BNN batch search using K0/K1/K2 pipeline. 
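Every search routine in `bnn.rs` converts a Hamming distance into the bipolar score used by `bnn_dot`. The formula in isolation, matching `score = 2 * match_count / total_bits - 1` in the hunks above:

```rust
// XNOR-popcount dot product rescaled to [-1, +1]: identical patterns score +1.0,
// complements -1.0, and unrelated random patterns hover near 0.0 (as the tests assert).
fn bnn_score(xor_popcount: u32, total_bits: u32) -> f32 {
    let matches = total_bits - xor_popcount;
    (2.0 * matches as f32 / total_bits as f32) - 1.0
}

// Example: 16384-bit fingerprints at Hamming distance 4096 give
// bnn_score(4096, 16384) = 2 * 12288 / 16384 - 1 = 0.5.
```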
pub fn bnn_cascade_search( - query: &Fingerprint<256>, - weights: &[Fingerprint<256>], - top_k: usize, - gate: &SliceGate, + query: &Fingerprint<256>, weights: &[Fingerprint<256>], top_k: usize, gate: &SliceGate, ) -> BnnCascadeResult { if weights.is_empty() { return BnnCascadeResult { @@ -370,8 +346,7 @@ pub fn bnn_cascade_search( db_words.extend_from_slice(&w.words); } - let (kernel_matches, stats) = - kernel_pipeline(&query.words, &db_words, n_candidates, SKU_16K_WORDS, gate); + let (kernel_matches, stats) = kernel_pipeline(&query.words, &db_words, n_candidates, SKU_16K_WORDS, gate); let total_bits = Fingerprint::<256>::BITS as u32; let mut results: Vec<(usize, BnnDotResult)> = kernel_matches @@ -414,10 +389,7 @@ pub struct BnnEnergyResult { /// Like `bnn_cascade_search` but also returns EnergyConflict decomposition. pub fn bnn_cascade_search_with_energy( - query: &Fingerprint<256>, - weights: &[Fingerprint<256>], - top_k: usize, - gate: &SliceGate, + query: &Fingerprint<256>, weights: &[Fingerprint<256>], top_k: usize, gate: &SliceGate, ) -> (Vec, PipelineStats) { if weights.is_empty() { return (Vec::new(), PipelineStats::default()); @@ -429,8 +401,7 @@ pub fn bnn_cascade_search_with_energy( db_words.extend_from_slice(&w.words); } - let (kernel_matches, stats) = - kernel_pipeline(&query.words, &db_words, n_candidates, SKU_16K_WORDS, gate); + let (kernel_matches, stats) = kernel_pipeline(&query.words, &db_words, n_candidates, SKU_16K_WORDS, gate); let total_bits = Fingerprint::<256>::BITS as u32; let mut results: Vec = kernel_matches @@ -466,10 +437,7 @@ pub fn bnn_cascade_search_with_energy( /// HDR-cascade-accelerated BNN search. pub fn bnn_hdr_search( - query: &Fingerprint<256>, - weights: &[Fingerprint<256>], - threshold: u64, - top_k: usize, + query: &Fingerprint<256>, weights: &[Fingerprint<256>], threshold: u64, top_k: usize, ) -> Vec<(usize, BnnDotResult)> { use super::cascade::Cascade; @@ -517,11 +485,7 @@ pub fn bnn_hdr_search( // ─── Binary convolution ────────────────────────────── /// 1D binary convolution: slide a kernel over a sequence of Fingerprints. -pub fn bnn_conv1d( - input: &[Fingerprint<256>], - kernel: &Fingerprint<256>, - stride: usize, -) -> Vec { +pub fn bnn_conv1d(input: &[Fingerprint<256>], kernel: &Fingerprint<256>, stride: usize) -> Vec { let stride = stride.max(1); (0..input.len()) .step_by(stride) @@ -540,10 +504,7 @@ pub fn bnn_conv1d_3ch(input: &[GraphHV], kernel: &GraphHV, stride: usize) -> Vec /// Cascade-accelerated 1D convolution. 
pub fn bnn_conv1d_cascade( - input: &[Fingerprint<256>], - kernel: &Fingerprint<256>, - stride: usize, - gate: &SliceGate, + input: &[Fingerprint<256>], kernel: &Fingerprint<256>, stride: usize, gate: &SliceGate, ) -> BnnCascadeResult { let stride = stride.max(1); let positions: Vec> = (0..input.len()) @@ -596,11 +557,7 @@ mod tests { let a = random_fp(&mut rng); let b = random_fp(&mut rng); let result = bnn_dot(&a, &b); - assert!( - result.score.abs() < 0.05, - "Expected ~0.0 score for random, got {:.4}", - result.score - ); + assert!(result.score.abs() < 0.05, "Expected ~0.0 score for random, got {:.4}", result.score); } #[test] @@ -632,11 +589,7 @@ mod tests { neuron.forward(&input, true, 0.1, &mut rng); } - assert_ne!( - *neuron.plastic(), - initial_plastic, - "Plastic channel should change after learning" - ); + assert_ne!(*neuron.plastic(), initial_plastic, "Plastic channel should change after learning"); } #[test] @@ -861,11 +814,7 @@ mod tests { let result = layer.winner_cam(&target, &cam, 20); assert!(result.is_some()); let (_, dot) = result.unwrap(); - assert!( - dot.score > 0.5, - "Expected high score for planted match, got {:.4}", - dot.score - ); + assert!(dot.score > 0.5, "Expected high score for planted match, got {:.4}", dot.score); } // ─── Convolution tests ────────────────────────────── @@ -878,11 +827,7 @@ mod tests { let results = bnn_conv1d(&sequence, &kernel, 1); assert_eq!(results.len(), 10); - assert!( - (results[5].score - 1.0).abs() < f32::EPSILON, - "Expected 1.0 at pos 5, got {:.4}", - results[5].score - ); + assert!((results[5].score - 1.0).abs() < f32::EPSILON, "Expected 1.0 at pos 5, got {:.4}", results[5].score); } #[test] @@ -934,9 +879,6 @@ mod tests { weights[25] = query.clone(); let results = bnn_hdr_search(&query, &weights, 16384, 10); - assert!( - results.iter().any(|(i, _)| *i == 25), - "HDR search should find exact match at 25" - ); + assert!(results.iter().any(|(i, _)| *i == 25), "HDR search should find exact match at 25"); } } diff --git a/src/hpc/bnn_causal_trajectory.rs b/src/hpc/bnn_causal_trajectory.rs index 1490b45a..70b5715f 100644 --- a/src/hpc/bnn_causal_trajectory.rs +++ b/src/hpc/bnn_causal_trajectory.rs @@ -1,4 +1,6 @@ -#![allow(clippy::assign_op_pattern, clippy::too_many_arguments, clippy::manual_range_contains, clippy::needless_range_loop)] +#![allow( + clippy::assign_op_pattern, clippy::too_many_arguments, clippy::manual_range_contains, clippy::needless_range_loop +)] //! Causal Trajectory Hydration via BNN Instrumentation. //! @@ -35,9 +37,9 @@ //! - Pearl 2009: do-calculus (BPReLU forward ≈ interventional, backward ≈ observational) //! - Czégel et al. 2021: error thresholds for staged assembly via Hold state +use super::bnn_cross_plane::CollapseGate; use super::fingerprint::Fingerprint; use super::kernels::{score_sigma, EnergyConflict, SigmaGate, SignificanceLevel}; -use super::bnn_cross_plane::CollapseGate; use super::bnn_cross_plane::{CrossPlaneVote, HaloType, InferenceMode}; @@ -273,22 +275,15 @@ impl EwmCorrection { } /// Compute with a specific σ-gate (e.g., for SKU-64K). 
- pub fn compute_with_gate( - prev: &ResonatorSnapshot, - curr: &ResonatorSnapshot, - gate: &SigmaGate, - ) -> Self { + pub fn compute_with_gate(prev: &ResonatorSnapshot, curr: &ResonatorSnapshot, gate: &SigmaGate) -> Self { let s_correction = per_word_popcount(&(&curr.s_est ^ &prev.s_est)); let p_correction = per_word_popcount(&(&curr.p_est ^ &prev.p_est)); let o_correction = per_word_popcount(&(&curr.o_est ^ &prev.o_est)); // Classify per-plane aggregate correction → EwmTier. // Lower aggregate correction = resonator converging = higher tier. - let totals = [ - s_correction.iter().sum::(), - p_correction.iter().sum::(), - o_correction.iter().sum::(), - ]; + let totals = + [s_correction.iter().sum::(), p_correction.iter().sum::(), o_correction.iter().sum::()]; let plane_tiers = totals.map(|total| { let ec = EnergyConflict { conflict: total, @@ -390,44 +385,21 @@ impl CausalSaliency { for word_idx in 0..n_words { // S-plane classify_word_trend( - first.s_correction[word_idx], - last.s_correction[word_idx], - corrections, - word_idx, + first.s_correction[word_idx], last.s_correction[word_idx], corrections, word_idx, 0, // plane index - &mut crystallizing, - &mut dissolving, - &mut contested, - &mut cryst_count, - &mut diss_count, + &mut crystallizing, &mut dissolving, &mut contested, &mut cryst_count, &mut diss_count, &mut cont_count, ); // P-plane classify_word_trend( - first.p_correction[word_idx], - last.p_correction[word_idx], - corrections, - word_idx, - 1, - &mut crystallizing, - &mut dissolving, - &mut contested, - &mut cryst_count, - &mut diss_count, + first.p_correction[word_idx], last.p_correction[word_idx], corrections, word_idx, 1, + &mut crystallizing, &mut dissolving, &mut contested, &mut cryst_count, &mut diss_count, &mut cont_count, ); // O-plane classify_word_trend( - first.o_correction[word_idx], - last.o_correction[word_idx], - corrections, - word_idx, - 2, - &mut crystallizing, - &mut dissolving, - &mut contested, - &mut cryst_count, - &mut diss_count, + first.o_correction[word_idx], last.o_correction[word_idx], corrections, word_idx, 2, + &mut crystallizing, &mut dissolving, &mut contested, &mut cryst_count, &mut diss_count, &mut cont_count, ); } @@ -583,11 +555,7 @@ impl CausalArrow { /// We use popcount of the XOR as unsigned magnitude, then apply the BPReLU /// based on whether the estimate moved TOWARD the codebook (forward) or /// AWAY from it (backward), as indicated by the convergence delta. 
-fn plane_asymmetry( - prev_est: &Fingerprint<256>, - curr_est: &Fingerprint<256>, - bprelu: &BPReLU, -) -> (f32, f32) { +fn plane_asymmetry(prev_est: &Fingerprint<256>, curr_est: &Fingerprint<256>, bprelu: &BPReLU) -> (f32, f32) { let diff = prev_est ^ curr_est; let changed_bits = diff.popcount() as f32; let total_bits = Fingerprint::<256>::BITS as f32; @@ -682,48 +650,12 @@ impl CausalChain { // Generate all cause→effect links let pairs: [(bool, DominantPlane, u32, bool, DominantPlane); 6] = [ - ( - s_stabilized, - DominantPlane::S, - early.s_activity, - p_responding, - DominantPlane::P, - ), - ( - s_stabilized, - DominantPlane::S, - early.s_activity, - o_responding, - DominantPlane::O, - ), - ( - p_stabilized, - DominantPlane::P, - early.p_activity, - s_responding, - DominantPlane::S, - ), - ( - p_stabilized, - DominantPlane::P, - early.p_activity, - o_responding, - DominantPlane::O, - ), - ( - o_stabilized, - DominantPlane::O, - early.o_activity, - s_responding, - DominantPlane::S, - ), - ( - o_stabilized, - DominantPlane::O, - early.o_activity, - p_responding, - DominantPlane::P, - ), + (s_stabilized, DominantPlane::S, early.s_activity, p_responding, DominantPlane::P), + (s_stabilized, DominantPlane::S, early.s_activity, o_responding, DominantPlane::O), + (p_stabilized, DominantPlane::P, early.p_activity, s_responding, DominantPlane::S), + (p_stabilized, DominantPlane::P, early.p_activity, o_responding, DominantPlane::O), + (o_stabilized, DominantPlane::O, early.o_activity, s_responding, DominantPlane::S), + (o_stabilized, DominantPlane::O, early.o_activity, p_responding, DominantPlane::P), ]; for &(cause_stable, cause_plane, cause_activity, effect_resp, effect_plane) in &pairs { @@ -780,10 +712,7 @@ pub struct HaloTransition { /// /// Compares the cross-plane vote at iteration t-1 with iteration t /// and identifies entries that moved between halo types. -pub fn detect_halo_transitions( - prev: &ResonatorSnapshot, - curr: &ResonatorSnapshot, -) -> Vec { +pub fn detect_halo_transitions(prev: &ResonatorSnapshot, curr: &ResonatorSnapshot) -> Vec { let prev_vote = prev.cross_plane_vote(); let curr_vote = curr.cross_plane_vote(); let mut transitions = Vec::new(); @@ -874,10 +803,7 @@ impl NarsTruth { if w_total < 1e-9 { return NarsTruth::new(0.5, 0.0); } - NarsTruth::new( - (w1 * self.f + w2 * other.f) / w_total, - w_total / (w_total + 1.0), - ) + NarsTruth::new((w1 * self.f + w2 * other.f) / w_total, w_total / (w_total + 1.0)) } /// NARS deduction: = . @@ -1079,7 +1005,7 @@ impl CausalTrajectory { } // Advance shift window every 4 iterations for smoothing. - if (snapshot.iter + 1)% 4 == 0 { + if (snapshot.iter + 1) % 4 == 0 { self.shift_detector.advance_window(); } @@ -1189,10 +1115,7 @@ impl CausalTrajectory { relation: CausalRelation::Contradicts, source_plane: *plane, target_plane: None, - truth: NarsTruth::new( - saliency.contested_count[plane_idx] as f32 / 256.0, - 0.7, - ), + truth: NarsTruth::new(saliency.contested_count[plane_idx] as f32 / 256.0, 0.7), iter: self.snapshots.last().map_or(0, |s| s.iter), inference_mode: None, sigma: None, @@ -1326,12 +1249,7 @@ impl StripeHistogram { /// Total population across all stripes. #[inline] pub fn total(&self) -> u32 { - self.below_1s - + self.s1_to_s15 - + self.s15_to_s2 - + self.s2_to_s25 - + self.s25_to_s3 - + self.above_3s + self.below_1s + self.s1_to_s15 + self.s15_to_s2 + self.s2_to_s25 + self.s25_to_s3 + self.above_3s } /// Classify a σ-value into the appropriate stripe and increment. 
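`NarsTruth::revise` above shows the combined frequency and the `w_total / (w_total + 1.0)` confidence, but not how the per-premise weights are formed. A sketch assuming the usual NARS evidence weight `w = c / (1 - c)` (an assumption; the weight computation sits above the quoted hunk), which reproduces the `test_nars_revision` expectations:

```rust
// Revision of two (frequency, confidence) truth values; confidences assumed < 1.
fn nars_revise(f1: f32, c1: f32, f2: f32, c2: f32) -> (f32, f32) {
    let w1 = c1 / (1.0 - c1);
    let w2 = c2 / (1.0 - c2);
    let w = w1 + w2;
    if w < 1e-9 {
        return (0.5, 0.0); // same degenerate fallback as the original
    }
    ((w1 * f1 + w2 * f2) / w, w / (w + 1.0))
}

// nars_revise(0.8, 0.5, 0.6, 0.5) ≈ (0.7, 0.667): mean frequency, higher confidence.
```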
@@ -1354,14 +1272,7 @@ impl StripeHistogram { /// Convert to array of 6 bin counts [below_1s, ..., above_3s]. pub fn as_array(&self) -> [u32; 6] { - [ - self.below_1s, - self.s1_to_s15, - self.s15_to_s2, - self.s2_to_s25, - self.s25_to_s3, - self.above_3s, - ] + [self.below_1s, self.s1_to_s15, self.s15_to_s2, self.s2_to_s25, self.s25_to_s3, self.above_3s] } /// Center of mass in σ-space: weighted average of bin centers. @@ -1549,7 +1460,9 @@ mod tests { use super::*; struct SplitMix64(u64); impl SplitMix64 { - fn new(seed: u64) -> Self { Self(seed) } + fn new(seed: u64) -> Self { + Self(seed) + } fn next_u64(&mut self) -> u64 { self.0 = self.0.wrapping_add(0x9E3779B97F4A7C15); let mut z = self.0; @@ -1596,18 +1509,9 @@ mod tests { let t2 = NarsTruth::new(0.6, 0.5); let revised = t1.revise(t2); // Revision of equal-weight evidence should give mean frequency - assert!( - (revised.f - 0.7).abs() < 0.01, - "Revised frequency should be ~0.7, got {}", - revised.f - ); + assert!((revised.f - 0.7).abs() < 0.01, "Revised frequency should be ~0.7, got {}", revised.f); // Confidence should increase - assert!( - revised.c > t1.c, - "Revised confidence {} should exceed input {}", - revised.c, - t1.c - ); + assert!(revised.c > t1.c, "Revised confidence {} should exceed input {}", revised.c, t1.c); } #[test] @@ -1615,11 +1519,7 @@ mod tests { let ab = NarsTruth::new(0.9, 0.8); let bc = NarsTruth::new(0.9, 0.8); let ac = ab.deduction(bc); - assert!( - ac.f > 0.7, - "Deduction frequency should be high, got {}", - ac.f - ); + assert!(ac.f > 0.7, "Deduction frequency should be high, got {}", ac.f); assert!(ac.c > 0.0, "Deduction confidence should be > 0"); } @@ -1630,10 +1530,7 @@ mod tests { let mut rng = make_rng(); let snap = make_snapshot(&mut rng, 0, 100); let diff = RifDiff::compute(&snap, &snap); - assert_eq!( - diff.s_activity, 0, - "Identical snapshots should have 0 S activity" - ); + assert_eq!(diff.s_activity, 0, "Identical snapshots should have 0 S activity"); assert_eq!(diff.total_activity(), 0, "Total activity should be 0"); } @@ -1643,10 +1540,7 @@ mod tests { let snap1 = make_snapshot(&mut rng, 0, 100); let snap2 = make_snapshot(&mut rng, 2, 100); let diff = RifDiff::compute(&snap1, &snap2); - assert!( - diff.total_activity() > 0, - "Different snapshots should have activity" - ); + assert!(diff.total_activity() > 0, "Different snapshots should have activity"); } // --- EWM Correction tests --- @@ -1667,10 +1561,7 @@ mod tests { let snap1 = make_snapshot(&mut rng, 0, 100); let snap2 = make_snapshot(&mut rng, 1, 100); let corr = EwmCorrection::compute(&snap1, &snap2); - assert!( - corr.s_total() > 0, - "Different snapshots should have nonzero correction" - ); + assert!(corr.s_total() > 0, "Different snapshots should have nonzero correction"); } // --- Causal Arrow tests --- @@ -1734,11 +1625,7 @@ mod tests { let chain = CausalChain::from_rif_diffs(&[diff1, diff2]); assert!(chain.depth() > 0, "Should detect S→P causal link"); - assert_eq!( - chain.root_cause(), - Some(DominantPlane::S), - "Root cause should be S-plane" - ); + assert_eq!(chain.root_cause(), Some(DominantPlane::S), "Root cause should be S-plane"); assert_eq!(chain.links[0].effect_plane, DominantPlane::P); } @@ -1768,10 +1655,7 @@ mod tests { corr2.iter = 2; let saliency = CausalSaliency::from_ewm_window(&[corr1, corr2]); - assert!( - saliency.crystallizing_count[0] > 0, - "Should detect crystallizing S-plane words" - ); + assert!(saliency.crystallizing_count[0] > 0, "Should detect crystallizing S-plane words"); } // --- Halo 
Transition tests --- @@ -1893,14 +1777,8 @@ mod tests { snap.delta_s = 5; snap.delta_p = 3; snap.delta_o = 7; - assert!( - snap.converged(10), - "Should be converged when all deltas < threshold" - ); - assert!( - !snap.converged(5), - "Should not converge when delta_o >= threshold" - ); + assert!(snap.converged(10), "Should be converged when all deltas < threshold"); + assert!(!snap.converged(5), "Should not converge when delta_o >= threshold"); } #[test] @@ -1941,10 +1819,7 @@ mod tests { #[test] fn test_ewm_tier_from_significance_discovery() { - assert_eq!( - EwmTier::from(SignificanceLevel::Discovery), - EwmTier::Crystallized - ); + assert_eq!(EwmTier::from(SignificanceLevel::Discovery), EwmTier::Crystallized); } #[test] @@ -1954,18 +1829,12 @@ mod tests { #[test] fn test_ewm_tier_from_significance_evidence() { - assert_eq!( - EwmTier::from(SignificanceLevel::Evidence), - EwmTier::Transitional - ); + assert_eq!(EwmTier::from(SignificanceLevel::Evidence), EwmTier::Transitional); } #[test] fn test_ewm_tier_from_significance_hint() { - assert_eq!( - EwmTier::from(SignificanceLevel::Hint), - EwmTier::Transitional - ); + assert_eq!(EwmTier::from(SignificanceLevel::Hint), EwmTier::Transitional); } #[test] diff --git a/src/hpc/bnn_cross_plane.rs b/src/hpc/bnn_cross_plane.rs index a8c366d0..70ffdcd0 100644 --- a/src/hpc/bnn_cross_plane.rs +++ b/src/hpc/bnn_cross_plane.rs @@ -1,4 +1,6 @@ -#![allow(clippy::assign_op_pattern, clippy::too_many_arguments, clippy::manual_range_contains, clippy::needless_range_loop)] +#![allow( + clippy::assign_op_pattern, clippy::too_many_arguments, clippy::manual_range_contains, clippy::needless_range_loop +)] //! Cross-Plane Partial Binding Algebra for 3D SPO Inference. //! @@ -425,9 +427,7 @@ impl TypedQuery { /// Number of known slots (1 for analogy, 2 for forward/backward/abduction). pub fn known_count(&self) -> usize { - self.subject.is_some() as usize - + self.predicate.is_some() as usize - + self.object.is_some() as usize + self.subject.is_some() as usize + self.predicate.is_some() as usize + self.object.is_some() as usize } } @@ -474,11 +474,7 @@ impl PartialBinding { count += 1; } } - let freq = if count > 0 { - sum_sim / count as f32 - } else { - 0.0 - }; + let freq = if count > 0 { sum_sim / count as f32 } else { 0.0 }; (freq, conf) } } @@ -606,10 +602,7 @@ impl LatticeClimber { /// noise floor, and the resulting per-plane significance levels determine /// the B_3 halo type. pub fn ingest_with_sigma( - &mut self, - candidates: &[(usize, [u32; 3])], - sigma_gate: &SigmaGate, - min_level: SignificanceLevel, + &mut self, candidates: &[(usize, [u32; 3])], sigma_gate: &SigmaGate, min_level: SignificanceLevel, ) { for &(entry_index, plane_distances) in candidates { let binding = classify_with_sigma(entry_index, plane_distances, sigma_gate, min_level); @@ -630,10 +623,7 @@ impl LatticeClimber { /// /// Returns newly promoted bindings. 
pub fn try_compose( - &mut self, - codebook_s: &[Fingerprint<256>], - codebook_p: &[Fingerprint<256>], - codebook_o: &[Fingerprint<256>], + &mut self, codebook_s: &[Fingerprint<256>], codebook_p: &[Fingerprint<256>], codebook_o: &[Fingerprint<256>], threshold: u32, ) -> Vec { let mut promoted = Vec::new(); @@ -645,9 +635,7 @@ impl LatticeClimber { for pair in &pairs { for fv in &fvs { - let composed = try_compose_pair_and_free( - pair, fv, codebook_s, codebook_p, codebook_o, threshold, - ); + let composed = try_compose_pair_and_free(pair, fv, codebook_s, codebook_p, codebook_o, threshold); if let Some(full) = composed { promoted.push(full); } @@ -670,8 +658,8 @@ impl LatticeClimber { pub fn gate_decision(&self) -> CollapseGate { if !self.full_triples.is_empty() { // Check average confidence of full triples - let avg_conf: f32 = self.full_triples.iter().map(|t| t.confidence).sum::() - / self.full_triples.len() as f32; + let avg_conf: f32 = + self.full_triples.iter().map(|t| t.confidence).sum::() / self.full_triples.len() as f32; if avg_conf > 1.5 { return CollapseGate::Flow; } @@ -718,12 +706,7 @@ pub struct SpoTriple { impl SpoTriple { /// Apply a mutation operator, replacing slot(s) with new fingerprint(s). - pub fn mutate( - &self, - op: MutationOp, - replacement: &Fingerprint<256>, - rng: &mut SplitMix64, - ) -> Self { + pub fn mutate(&self, op: MutationOp, replacement: &Fingerprint<256>, rng: &mut SplitMix64) -> Self { let random_fp = random_fingerprint(rng); match op { MutationOp::MutateS => SpoTriple { @@ -767,11 +750,7 @@ impl SpoTriple { /// XOR-encode into 3D crystal (S^P, P^O, S^O). pub fn encode(&self) -> [Fingerprint<256>; 3] { - [ - &self.subject ^ &self.predicate, - &self.predicate ^ &self.object, - &self.subject ^ &self.object, - ] + [&self.subject ^ &self.predicate, &self.predicate ^ &self.object, &self.subject ^ &self.object] } } @@ -951,10 +930,7 @@ impl PlaneSignificance { /// Each plane's distance is scored against its noise floor, and the resulting /// per-plane significance levels determine the B_3 halo type. pub fn classify_with_sigma( - entry_index: usize, - plane_distances: [u32; 3], - sigma_gate: &SigmaGate, - min_level: SignificanceLevel, + entry_index: usize, plane_distances: [u32; 3], sigma_gate: &SigmaGate, min_level: SignificanceLevel, ) -> PartialBinding { // Score each plane independently let levels: [SignificanceLevel; 3] = plane_distances.map(|dist| { @@ -967,11 +943,7 @@ pub fn classify_with_sigma( score_sigma(&ec, sigma_gate).level }); - let halo = HaloType::from_membership( - levels[0] >= min_level, - levels[1] >= min_level, - levels[2] >= min_level, - ); + let halo = HaloType::from_membership(levels[0] >= min_level, levels[1] >= min_level, levels[2] >= min_level); // Confidence from sigma-significance: sum of per-plane sigma values let sigmas: [f32; 3] = plane_distances.map(|dist| { @@ -1015,9 +987,7 @@ fn popcount_mask(mask: &[u64], n_entries: usize) -> usize { /// Find best matching codebook entry by Hamming distance. fn find_best_match( - estimate: &Fingerprint<256>, - codebook: &[Fingerprint<256>], - mode: InferenceMode, + estimate: &Fingerprint<256>, codebook: &[Fingerprint<256>], mode: InferenceMode, ) -> Option { let mut best_idx = 0; let mut best_dist = u32::MAX; @@ -1052,12 +1022,8 @@ fn random_fingerprint(rng: &mut SplitMix64) -> Fingerprint<256> { /// and if the resulting triple has Hamming distance below threshold on /// the newly filled plane. 
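`classify_with_sigma` maps three per-plane membership flags onto a `HaloType`. The eight-way mapping implied by `from_membership` and the tests below, sketched as a standalone enum; `Halo` is an illustrative name, not the crate's type:

```rust
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Halo {
    Core, // all three planes significant
    SP,
    PO,
    SO,   // exactly two planes
    S,
    P,
    O,    // exactly one plane
    Noise, // none
}

fn from_membership(s: bool, p: bool, o: bool) -> Halo {
    match (s, p, o) {
        (true, true, true) => Halo::Core,
        (true, true, false) => Halo::SP,
        (false, true, true) => Halo::PO,
        (true, false, true) => Halo::SO,
        (true, false, false) => Halo::S,
        (false, true, false) => Halo::P,
        (false, false, true) => Halo::O,
        (false, false, false) => Halo::Noise,
    }
}
```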
fn try_compose_pair_and_free( - pair: &PartialBinding, - fv: &PartialBinding, - _codebook_s: &[Fingerprint<256>], - _codebook_p: &[Fingerprint<256>], - _codebook_o: &[Fingerprint<256>], - threshold: u32, + pair: &PartialBinding, fv: &PartialBinding, _codebook_s: &[Fingerprint<256>], _codebook_p: &[Fingerprint<256>], + _codebook_o: &[Fingerprint<256>], threshold: u32, ) -> Option { // SP + O -> Core if pair.halo_type == HaloType::SP && fv.halo_type == HaloType::O { @@ -1143,10 +1109,7 @@ mod tests { assert_eq!(HaloType::from_membership(true, false, false), HaloType::S); assert_eq!(HaloType::from_membership(false, true, false), HaloType::P); assert_eq!(HaloType::from_membership(false, false, true), HaloType::O); - assert_eq!( - HaloType::from_membership(false, false, false), - HaloType::Noise - ); + assert_eq!(HaloType::from_membership(false, false, false), HaloType::Noise); } #[test] @@ -1175,25 +1138,12 @@ mod tests { // Pairwise disjoint: AND of any two should be 0 let masks = [ - vote.core[i], - vote.sp[i], - vote.so[i], - vote.po[i], - vote.s_only[i], - vote.p_only[i], - vote.o_only[i], + vote.core[i], vote.sp[i], vote.so[i], vote.po[i], vote.s_only[i], vote.p_only[i], vote.o_only[i], vote.noise[i], ]; for a in 0..8 { for b in (a + 1)..8 { - assert_eq!( - masks[a] & masks[b], - 0, - "word {}: masks {} and {} overlap", - i, - a, - b - ); + assert_eq!(masks[a] & masks[b], 0, "word {}: masks {} and {} overlap", i, a, b); } } } @@ -1278,10 +1228,7 @@ mod tests { fn test_halo_inference_mode() { assert_eq!(HaloType::SP.inference_mode(), Some(InferenceMode::Forward)); assert_eq!(HaloType::PO.inference_mode(), Some(InferenceMode::Backward)); - assert_eq!( - HaloType::SO.inference_mode(), - Some(InferenceMode::Abduction) - ); + assert_eq!(HaloType::SO.inference_mode(), Some(InferenceMode::Abduction)); assert_eq!(HaloType::S.inference_mode(), Some(InferenceMode::Analogy)); assert_eq!(HaloType::Core.inference_mode(), None); assert_eq!(HaloType::Noise.inference_mode(), None); @@ -1439,33 +1386,21 @@ mod tests { halo_type: HaloType::S, confidence: 0.8, plane_distances: [1000, u32::MAX, u32::MAX], - plane_sigma: [ - SignificanceLevel::Discovery, - SignificanceLevel::Noise, - SignificanceLevel::Noise, - ], + plane_sigma: [SignificanceLevel::Discovery, SignificanceLevel::Noise, SignificanceLevel::Noise], }, PartialBinding { entry_index: 1, halo_type: HaloType::SP, confidence: 1.5, plane_distances: [1000, 2000, u32::MAX], - plane_sigma: [ - SignificanceLevel::Discovery, - SignificanceLevel::Discovery, - SignificanceLevel::Noise, - ], + plane_sigma: [SignificanceLevel::Discovery, SignificanceLevel::Discovery, SignificanceLevel::Noise], }, PartialBinding { entry_index: 2, halo_type: HaloType::Core, confidence: 2.5, plane_distances: [1000, 2000, 3000], - plane_sigma: [ - SignificanceLevel::Discovery, - SignificanceLevel::Discovery, - SignificanceLevel::Discovery, - ], + plane_sigma: [SignificanceLevel::Discovery, SignificanceLevel::Discovery, SignificanceLevel::Discovery], }, ]; @@ -1490,11 +1425,7 @@ mod tests { halo_type: HaloType::S, confidence: 0.5, plane_distances: [1000, u32::MAX, u32::MAX], - plane_sigma: [ - SignificanceLevel::Discovery, - SignificanceLevel::Noise, - SignificanceLevel::Noise, - ], + plane_sigma: [SignificanceLevel::Discovery, SignificanceLevel::Noise, SignificanceLevel::Noise], }); assert_eq!(climber.gate_decision(), CollapseGate::Hold); @@ -1504,11 +1435,7 @@ mod tests { halo_type: HaloType::Core, confidence: 2.5, plane_distances: [500, 600, 700], - plane_sigma: [ - 
SignificanceLevel::Discovery, - SignificanceLevel::Discovery, - SignificanceLevel::Discovery, - ], + plane_sigma: [SignificanceLevel::Discovery, SignificanceLevel::Discovery, SignificanceLevel::Discovery], }); assert_eq!(climber.gate_decision(), CollapseGate::Flow); } @@ -1524,11 +1451,7 @@ mod tests { 2000, // P-plane: 2000 / 16384 ~ 12.2% distance -> 87.8% similarity u32::MAX, ], - plane_sigma: [ - SignificanceLevel::Discovery, - SignificanceLevel::Discovery, - SignificanceLevel::Noise, - ], + plane_sigma: [SignificanceLevel::Discovery, SignificanceLevel::Discovery, SignificanceLevel::Noise], }; let (freq, conf) = binding.nars_truth(); @@ -1596,11 +1519,7 @@ mod tests { halo_type: HaloType::SP, confidence: 1.5, plane_distances: [1000, 1200, u32::MAX], - plane_sigma: [ - SignificanceLevel::Discovery, - SignificanceLevel::Discovery, - SignificanceLevel::Noise, - ], + plane_sigma: [SignificanceLevel::Discovery, SignificanceLevel::Discovery, SignificanceLevel::Noise], }); // O free var at entry 1 @@ -1609,11 +1528,7 @@ mod tests { halo_type: HaloType::O, confidence: 0.7, plane_distances: [u32::MAX, u32::MAX, 800], - plane_sigma: [ - SignificanceLevel::Noise, - SignificanceLevel::Noise, - SignificanceLevel::Discovery, - ], + plane_sigma: [SignificanceLevel::Noise, SignificanceLevel::Noise, SignificanceLevel::Discovery], }); // Create dummy codebooks diff --git a/src/hpc/byte_scan.rs b/src/hpc/byte_scan.rs index 4876a2cb..4f692cc6 100644 --- a/src/hpc/byte_scan.rs +++ b/src/hpc/byte_scan.rs @@ -256,17 +256,26 @@ pub struct NbtSchemaEntry { impl NbtSchemaEntry { /// Create a schema entry for a named compound tag. pub fn compound(name: &str) -> Self { - Self { tag_id: NbtTagId::Compound, name: name.as_bytes().to_vec() } + Self { + tag_id: NbtTagId::Compound, + name: name.as_bytes().to_vec(), + } } /// Create a schema entry for a named list tag. pub fn list(name: &str) -> Self { - Self { tag_id: NbtTagId::List, name: name.as_bytes().to_vec() } + Self { + tag_id: NbtTagId::List, + name: name.as_bytes().to_vec(), + } } /// Create a schema entry for any tag type with given name. pub fn new(tag_id: NbtTagId, name: &str) -> Self { - Self { tag_id, name: name.as_bytes().to_vec() } + Self { + tag_id, + name: name.as_bytes().to_vec(), + } } } @@ -353,11 +362,11 @@ pub fn nbt_schema_scan(data: &[u8], schema: &[NbtSchemaEntry]) -> Vec Vec> { - buffers.iter().map(|buf| nbt_schema_scan(buf, schema)).collect() +pub fn nbt_schema_scan_batch(buffers: &[&[u8]], schema: &[NbtSchemaEntry]) -> Vec> { + buffers + .iter() + .map(|buf| nbt_schema_scan(buf, schema)) + .collect() } // --------------------------------------------------------------------------- @@ -385,11 +394,7 @@ mod tests { // Use a buffer that exercises both SIMD and scalar tail. 
let buf: Vec = (0..200).map(|i| (i % 7) as u8).collect(); for needle in 0..7u8 { - assert_eq!( - byte_find_all(&buf, needle), - naive_byte_find_all(&buf, needle), - "mismatch for needle {needle}" - ); + assert_eq!(byte_find_all(&buf, needle), naive_byte_find_all(&buf, needle), "mismatch for needle {needle}"); } } @@ -397,11 +402,7 @@ mod tests { fn test_byte_count_matches_naive() { let buf: Vec = (0..200).map(|i| (i % 7) as u8).collect(); for needle in 0..7u8 { - assert_eq!( - byte_count(&buf, needle), - naive_byte_count(&buf, needle), - "mismatch for needle {needle}" - ); + assert_eq!(byte_count(&buf, needle), naive_byte_count(&buf, needle), "mismatch for needle {needle}"); } } @@ -520,10 +521,7 @@ mod tests { data.extend_from_slice(b"BlockEntities"); data.extend_from_slice(&[0; 5]); - let schema = vec![ - NbtSchemaEntry::compound("Entities"), - NbtSchemaEntry::list("BlockEntities"), - ]; + let schema = vec![NbtSchemaEntry::compound("Entities"), NbtSchemaEntry::list("BlockEntities")]; let matches = nbt_schema_scan(&data, &schema); assert_eq!(matches.len(), 2); assert_eq!(matches[0].tag_offset, 0); diff --git a/src/hpc/cam_index.rs b/src/hpc/cam_index.rs index 1f7fa05c..78dd02bf 100644 --- a/src/hpc/cam_index.rs +++ b/src/hpc/cam_index.rs @@ -57,21 +57,13 @@ impl GraphHV { /// Create a zero-initialized GraphHV. pub fn zero() -> Self { Self { - channels: [ - Fingerprint::<256>::zero(), - Fingerprint::<256>::zero(), - Fingerprint::<256>::zero(), - ], + channels: [Fingerprint::<256>::zero(), Fingerprint::<256>::zero(), Fingerprint::<256>::zero()], } } /// Create a random GraphHV from a PRNG. pub fn random(rng: &mut SplitMix64) -> Self { - let mut channels = [ - Fingerprint::<256>::zero(), - Fingerprint::<256>::zero(), - Fingerprint::<256>::zero(), - ]; + let mut channels = [Fingerprint::<256>::zero(), Fingerprint::<256>::zero(), Fingerprint::<256>::zero()]; for ch in &mut channels { for w in ch.words.iter_mut() { *w = rng.next_u64(); @@ -90,11 +82,7 @@ impl GraphHV { impl std::fmt::Debug for GraphHV { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "GraphHV[S:{:?}, P:{:?}, O:{:?}]", - self.channels[0], self.channels[1], self.channels[2] - ) + write!(f, "GraphHV[S:{:?}, P:{:?}, O:{:?}]", self.channels[0], self.channels[1], self.channels[2]) } } @@ -140,11 +128,7 @@ impl LshProjector { fn new(rng: &mut SplitMix64, sample_size: usize) -> Self { let mut masks = Vec::with_capacity(64); for _ in 0..64 { - let mut ch_masks = [ - Fingerprint::<256>::zero(), - Fingerprint::<256>::zero(), - Fingerprint::<256>::zero(), - ]; + let mut ch_masks = [Fingerprint::<256>::zero(), Fingerprint::<256>::zero(), Fingerprint::<256>::zero()]; for _ in 0..sample_size { let ch = (rng.next_u64() % 3) as usize; let word = (rng.next_u64() % 256) as usize; @@ -409,20 +393,14 @@ mod tests { for ch in 0..3 { for w in 0..256 { // Flip ~5% of bits: AND 4 randoms = ~6.25% kill rate - let kill = flip_rng.next_u64() - & flip_rng.next_u64() - & flip_rng.next_u64() - & flip_rng.next_u64(); + let kill = flip_rng.next_u64() & flip_rng.next_u64() & flip_rng.next_u64() & flip_rng.next_u64(); noisy.channels[ch].words[w] ^= kill; } } let results = cam.query(&noisy, 10); let found = results.iter().any(|h| h.index == original_idx); - assert!( - found, - "Similar prototype not found (may be LSH collision miss — non-deterministic)" - ); + assert!(found, "Similar prototype not found (may be LSH collision miss — non-deterministic)"); } #[test] @@ -458,10 +436,7 @@ mod tests { let hv2 = GraphHV::random(&mut 
rng); let h3 = cam.projectors[0].hash(&hv2); - assert_ne!( - h1, h3, - "Random vectors should produce different hashes (probabilistic)" - ); + assert_ne!(h1, h3, "Random vectors should produce different hashes (probabilistic)"); } #[test] diff --git a/src/hpc/cam_pq.rs b/src/hpc/cam_pq.rs index bf989b49..3684e17b 100644 --- a/src/hpc/cam_pq.rs +++ b/src/hpc/cam_pq.rs @@ -39,12 +39,12 @@ pub const NUM_CENTROIDS: usize = 256; /// Semantic names for the 6 CAM bytes. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum CamByte { - Heel = 0, // Coarse category - Branch = 1, // Archetype selection - TwigA = 2, // Shape parameter A - TwigB = 3, // Shape parameter B - Leaf = 4, // Fine detail - Gamma = 5, // Euler tension/energy + Heel = 0, // Coarse category + Branch = 1, // Archetype selection + TwigA = 2, // Shape parameter A + TwigB = 3, // Shape parameter B + Leaf = 4, // Fine detail + Gamma = 5, // Euler tension/energy } /// 6-byte CAM fingerprint. @@ -227,14 +227,22 @@ impl DistanceTables { for s in 0..NUM_SUBSPACES { // Gather 16 centroid indices for subspace s let idx_arr = [ - cams[base][s] as i32, cams[base + 1][s] as i32, - cams[base + 2][s] as i32, cams[base + 3][s] as i32, - cams[base + 4][s] as i32, cams[base + 5][s] as i32, - cams[base + 6][s] as i32, cams[base + 7][s] as i32, - cams[base + 8][s] as i32, cams[base + 9][s] as i32, - cams[base + 10][s] as i32, cams[base + 11][s] as i32, - cams[base + 12][s] as i32, cams[base + 13][s] as i32, - cams[base + 14][s] as i32, cams[base + 15][s] as i32, + cams[base][s] as i32, + cams[base + 1][s] as i32, + cams[base + 2][s] as i32, + cams[base + 3][s] as i32, + cams[base + 4][s] as i32, + cams[base + 5][s] as i32, + cams[base + 6][s] as i32, + cams[base + 7][s] as i32, + cams[base + 8][s] as i32, + cams[base + 9][s] as i32, + cams[base + 10][s] as i32, + cams[base + 11][s] as i32, + cams[base + 12][s] as i32, + cams[base + 13][s] as i32, + cams[base + 14][s] as i32, + cams[base + 15][s] as i32, ]; let indices = I32x16::from_array(idx_arr); @@ -277,27 +285,27 @@ impl PackedDatabase { let stroke1: Vec = fingerprints.iter().map(|f| f[0]).collect(); // Stroke 2: HEEL + BRANCH interleaved (2 bytes per candidate) - let stroke2: Vec = fingerprints.iter() - .flat_map(|f| [f[0], f[1]]) - .collect(); + let stroke2: Vec = fingerprints.iter().flat_map(|f| [f[0], f[1]]).collect(); // Stroke 3: full CAM (6 bytes per candidate) - let stroke3: Vec = fingerprints.iter() + let stroke3: Vec = fingerprints + .iter() .flat_map(|f| f.iter().copied()) .collect(); - PackedDatabase { stroke1, stroke2, stroke3, num_candidates: n } + PackedDatabase { + stroke1, + stroke2, + stroke3, + num_candidates: n, + } } /// Cascade query: Stroke 1 → Stroke 2 → Stroke 3. /// /// 99% rejection before full ADC. Scans 1MB instead of 6MB for 1M vectors. 
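// A minimal sketch of the progressive narrowing that cascade_query (below)
// performs: a cheap first pass rejects most candidates so that only survivors
// pay for the expensive score. Thresholds, data, and the function name here are
// illustrative; the real strokes read 1, 2, and 6 bytes per candidate.
fn cascade_filter_sketch(coarse: &[f32], fine: &[f32], t1: f32, t2: f32) -> Vec<usize> {
    debug_assert_eq!(coarse.len(), fine.len());
    // Stroke-1 analogue: cheap scan over every candidate.
    let survivors: Vec<usize> = (0..coarse.len()).filter(|&i| coarse[i] < t1).collect();
    // Stroke-2/3 analogue: expensive score only for the survivors.
    survivors.into_iter().filter(|&i| fine[i] < t2).collect()
}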
pub fn cascade_query( - &self, - dist_tables: &DistanceTables, - heel_threshold: f32, - branch_threshold: f32, - top_k: usize, + &self, dist_tables: &DistanceTables, heel_threshold: f32, branch_threshold: f32, top_k: usize, ) -> Vec<(usize, f32)> { // Stroke 1: scan HEEL bytes (1 byte/candidate) let mut survivors: Vec = Vec::new(); @@ -313,25 +321,28 @@ impl PackedDatabase { for &i in &survivors { let base = i * 2; let dist = dist_tables.tables[0][self.stroke2[base] as usize] - + dist_tables.tables[1][self.stroke2[base + 1] as usize]; + + dist_tables.tables[1][self.stroke2[base + 1] as usize]; if dist < branch_threshold { refined.push(i); } } // Stroke 3: full ADC on refined candidates (6 bytes/candidate) - let mut hits: Vec<(usize, f32)> = refined.iter().map(|&i| { - let base = i * 6; - let cam: CamFingerprint = [ - self.stroke3[base], - self.stroke3[base + 1], - self.stroke3[base + 2], - self.stroke3[base + 3], - self.stroke3[base + 4], - self.stroke3[base + 5], - ]; - (i, dist_tables.distance(&cam)) - }).collect(); + let mut hits: Vec<(usize, f32)> = refined + .iter() + .map(|&i| { + let base = i * 6; + let cam: CamFingerprint = [ + self.stroke3[base], + self.stroke3[base + 1], + self.stroke3[base + 2], + self.stroke3[base + 3], + self.stroke3[base + 4], + self.stroke3[base + 5], + ]; + (i, dist_tables.distance(&cam)) + }) + .collect(); hits.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal)); hits.truncate(top_k); @@ -359,11 +370,7 @@ impl PackedDatabase { /// Train codebooks via k-means on subvectors (standard FAISS PQ). /// /// Minimizes reconstruction error: ||x - decode(encode(x))||². -pub fn train_geometric( - vectors: &[Vec], - total_dim: usize, - iterations: usize, -) -> CamCodebook { +pub fn train_geometric(vectors: &[Vec], total_dim: usize, iterations: usize) -> CamCodebook { assert!(!vectors.is_empty(), "need at least one training vector"); assert!(total_dim >= NUM_SUBSPACES, "dimension must be >= 6"); let subspace_dim = total_dim / NUM_SUBSPACES; @@ -372,19 +379,23 @@ pub fn train_geometric( for s in 0..NUM_SUBSPACES { // Extract subvectors for this subspace - let subs: Vec> = vectors.iter() + let subs: Vec> = vectors + .iter() .map(|v| v[s * subspace_dim..(s + 1) * subspace_dim].to_vec()) .collect(); // k-means clustering let centroids = kmeans(&subs, NUM_CENTROIDS.min(subs.len()), subspace_dim, iterations); - codebooks_vec.push(SubspaceCodebook { centroids, subspace_dim }); + codebooks_vec.push(SubspaceCodebook { + centroids, + subspace_dim, + }); } CamCodebook { - codebooks: codebooks_vec.try_into().unwrap_or_else(|v: Vec| { - panic!("expected {} codebooks, got {}", NUM_SUBSPACES, v.len()) - }), + codebooks: codebooks_vec + .try_into() + .unwrap_or_else(|v: Vec| panic!("expected {} codebooks, got {}", NUM_SUBSPACES, v.len())), total_dim, subspace_dim, } @@ -394,12 +405,7 @@ pub fn train_geometric( /// /// Codebooks balance reconstruction error AND semantic separation. /// `labels[i]` is a set of semantic tags for vector `i`. 
-pub fn train_semantic( - vectors: &[Vec], - labels: &[Vec], - total_dim: usize, - alpha: f32, -) -> CamCodebook { +pub fn train_semantic(vectors: &[Vec], labels: &[Vec], total_dim: usize, alpha: f32) -> CamCodebook { assert_eq!(vectors.len(), labels.len(), "vectors and labels must match"); // Phase 1: geometric initialization @@ -430,8 +436,8 @@ pub fn train_semantic( let cj = cam_j[s] as usize; let dim = codebook.subspace_dim; for d in 0..dim { - let delta = grad * (codebook.codebooks[s].centroids[cj][d] - - codebook.codebooks[s].centroids[ci][d]); + let delta = grad + * (codebook.codebooks[s].centroids[cj][d] - codebook.codebooks[s].centroids[ci][d]); codebook.codebooks[s].centroids[ci][d] += delta * 0.01; codebook.codebooks[s].centroids[cj][d] -= delta * 0.01; } @@ -445,11 +451,7 @@ pub fn train_semantic( } /// Hybrid training: geometric init + semantic fine-tune (convenience wrapper). -pub fn train_hybrid( - vectors: &[Vec], - labels: &[Vec], - total_dim: usize, -) -> CamCodebook { +pub fn train_hybrid(vectors: &[Vec], labels: &[Vec], total_dim: usize) -> CamCodebook { train_semantic(vectors, labels, total_dim, 0.1) } @@ -469,13 +471,7 @@ pub fn train_hybrid( /// drop trailing elements via the scalar `zip` fallback. #[inline(always)] pub fn squared_l2(a: &[f32], b: &[f32]) -> f32 { - assert_eq!( - a.len(), - b.len(), - "squared_l2: input length mismatch ({} vs {})", - a.len(), - b.len(), - ); + assert_eq!(a.len(), b.len(), "squared_l2: input length mismatch ({} vs {})", a.len(), b.len(),); let n = a.len(); // Fast path: exactly 16 elements = one F32x16 lane (most common in CAM-PQ). @@ -514,7 +510,8 @@ pub fn squared_l2(a: &[f32], b: &[f32]) -> f32 { /// L1 distance between two CAM fingerprints. fn cam_l1_distance(a: &CamFingerprint, b: &CamFingerprint) -> u32 { - a.iter().zip(b.iter()) + a.iter() + .zip(b.iter()) .map(|(&x, &y)| (x as i32 - y as i32).unsigned_abs()) .sum() } @@ -526,7 +523,11 @@ fn jaccard_similarity(a: &[String], b: &[String]) -> f32 { } let intersection = a.iter().filter(|x| b.contains(x)).count(); let union = a.len() + b.len() - intersection; - if union == 0 { 1.0 } else { intersection as f32 / union as f32 } + if union == 0 { + 1.0 + } else { + intersection as f32 / union as f32 + } } /// Simple k-means clustering (Lloyd's algorithm with farthest-first seeding). 
@@ -558,7 +559,9 @@ pub fn kmeans(data: &[Vec], k: usize, dim: usize, iterations: usize) -> Vec } } // Pick farthest point - let best = min_dists.iter().enumerate() + let best = min_dists + .iter() + .enumerate() .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(Ordering::Equal)) .map(|(i, _)| i) .unwrap_or(0); @@ -781,9 +784,7 @@ mod tests { #[test] fn test_train_hybrid() { let vecs = make_test_vectors(100, 24); - let labels: Vec> = (0..100) - .map(|i| vec![format!("cat_{}", i % 5)]) - .collect(); + let labels: Vec> = (0..100).map(|i| vec![format!("cat_{}", i % 5)]).collect(); let codebook = train_hybrid(&vecs, &labels, 24); assert_eq!(codebook.total_dim, 24); @@ -811,8 +812,7 @@ mod tests { // Centroids should be near (0,0) and (10,10) let c0 = ¢roids[0]; let c1 = ¢roids[1]; - let near_origin = (c0[0].abs() < 1.0 && c0[1].abs() < 1.0) - || (c1[0].abs() < 1.0 && c1[1].abs() < 1.0); + let near_origin = (c0[0].abs() < 1.0 && c0[1].abs() < 1.0) || (c1[0].abs() < 1.0 && c1[1].abs() < 1.0); let near_ten = ((c0[0] - 10.0).abs() < 1.0 && (c0[1] - 10.0).abs() < 1.0) || ((c1[0] - 10.0).abs() < 1.0 && (c1[1] - 10.0).abs() < 1.0); assert!(near_origin, "one centroid should be near origin"); diff --git a/src/hpc/cascade.rs b/src/hpc/cascade.rs index cc5f6535..c41a6034 100644 --- a/src/hpc/cascade.rs +++ b/src/hpc/cascade.rs @@ -5,8 +5,8 @@ //! //! Extracted from rustynum-core/hdr.rs — the cascade algorithm and types. -use super::bitwise; use super::bf16_truth::BF16Weights; +use super::bitwise; /// A ranked hit from the HDR cascade search. #[derive(Debug, Clone)] @@ -70,15 +70,27 @@ impl PartialEq for PreciseMode { match (self, other) { (Self::Off, Self::Off) => true, (Self::Vnni, Self::Vnni) => true, - (Self::F32 { scale: s1, zero_point: z1 }, Self::F32 { scale: s2, zero_point: z2 }) => { - s1.to_bits() == s2.to_bits() && z1 == z2 - } - (Self::BF16 { scale: s1, zero_point: z1 }, Self::BF16 { scale: s2, zero_point: z2 }) => { - s1.to_bits() == s2.to_bits() && z1 == z2 - } - (Self::DeltaXor { delta_weight: w1 }, Self::DeltaXor { delta_weight: w2 }) => { - w1.to_bits() == w2.to_bits() - } + ( + Self::F32 { + scale: s1, + zero_point: z1, + }, + Self::F32 { + scale: s2, + zero_point: z2, + }, + ) => s1.to_bits() == s2.to_bits() && z1 == z2, + ( + Self::BF16 { + scale: s1, + zero_point: z1, + }, + Self::BF16 { + scale: s2, + zero_point: z2, + }, + ) => s1.to_bits() == s2.to_bits() && z1 == z2, + (Self::DeltaXor { delta_weight: w1 }, Self::DeltaXor { delta_weight: w2 }) => w1.to_bits() == w2.to_bits(), (Self::BF16Hamming { weights: w1 }, Self::BF16Hamming { weights: w2 }) => w1 == w2, _ => false, } @@ -98,16 +110,28 @@ pub struct Cascade { impl Cascade { /// Current distribution mean (Welford online estimate). - pub fn mu(&self) -> f64 { self.mu } + pub fn mu(&self) -> f64 { + self.mu + } /// Current distribution standard deviation (Welford online estimate). - pub fn sigma(&self) -> f64 { self.sigma } + pub fn sigma(&self) -> f64 { + self.sigma + } /// Number of observations processed. 
- pub fn observations(&self) -> usize { self.observations } + pub fn observations(&self) -> usize { + self.observations + } pub fn from_threshold(threshold: u64, vec_bytes: usize) -> Self { - Self { threshold, vec_bytes, mu: 0.0, sigma: 0.0, observations: 0 } + Self { + threshold, + vec_bytes, + mu: 0.0, + sigma: 0.0, + observations: 0, + } } pub fn calibrate(distances: &[u32], vec_bytes: usize) -> Self { @@ -116,10 +140,23 @@ impl Cascade { } let n = distances.len() as f64; let mu = distances.iter().map(|&d| d as f64).sum::() / n; - let var = distances.iter().map(|&d| { let diff = d as f64 - mu; diff * diff }).sum::() / n; + let var = distances + .iter() + .map(|&d| { + let diff = d as f64 - mu; + diff * diff + }) + .sum::() + / n; let sigma = var.sqrt(); let threshold = (mu + 3.0 * sigma) as u64; - Self { threshold, vec_bytes, mu, sigma, observations: distances.len() } + Self { + threshold, + vec_bytes, + mu, + sigma, + observations: distances.len(), + } } pub fn expose(&self, distance: u32) -> Band { @@ -178,13 +215,7 @@ impl Cascade { } /// Run the full 3-stroke cascade query. - pub fn query( - &self, - query: &[u8], - database: &[u8], - vec_bytes: usize, - num_vectors: usize, - ) -> Vec { + pub fn query(&self, query: &[u8], database: &[u8], vec_bytes: usize, num_vectors: usize) -> Vec { assert_eq!(query.len(), vec_bytes); assert_eq!(database.len(), vec_bytes * num_vectors); @@ -232,7 +263,14 @@ impl Cascade { } else { let var: f64 = { let mu: f64 = warmup_dists.iter().map(|&d| d as f64).sum::() / warmup_n as f64; - warmup_dists.iter().map(|&d| { let diff = d as f64 - mu; diff * diff }).sum::() / warmup_n as f64 + warmup_dists + .iter() + .map(|&d| { + let diff = d as f64 - mu; + diff * diff + }) + .sum::() + / warmup_n as f64 }; var.sqrt() }; @@ -280,11 +318,7 @@ impl Cascade { /// Because CLAM already provides geometrically tight candidates, Stroke 1 /// is partially redundant -- we skip directly to full Hamming verification. pub fn query_candidates( - &self, - query: &[u8], - database: &[u8], - vec_bytes: usize, - candidate_indices: &[(usize, u64)], + &self, query: &[u8], database: &[u8], vec_bytes: usize, candidate_indices: &[(usize, u64)], ) -> Vec { let threshold = self.threshold; let mut results = Vec::with_capacity(candidate_indices.len()); @@ -316,12 +350,7 @@ impl Cascade { /// Run the full 3-stroke cascade query with precision scoring (Stroke 3). pub fn query_precise( - &self, - query: &[u8], - database: &[u8], - vec_bytes: usize, - num_vectors: usize, - precise_mode: PreciseMode, + &self, query: &[u8], database: &[u8], vec_bytes: usize, num_vectors: usize, precise_mode: PreciseMode, ) -> Vec { let mut results = self.query(query, database, vec_bytes, num_vectors); @@ -380,11 +409,7 @@ fn bf16_hamming_scalar(a: &[u8], b: &[u8], weights: &BF16Weights) -> u64 { /// /// Sorts by precise distance descending (most similar first). 
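// A standalone version of the threshold arithmetic used by calibrate() above:
// the Band cut-off is the sample mean plus three standard deviations of the
// observed Hamming distances. Plain slices only; no Cascade state involved.
fn three_sigma_threshold_sketch(distances: &[u32]) -> u64 {
    assert!(!distances.is_empty(), "need at least one calibration distance");
    let n = distances.len() as f64;
    let mu = distances.iter().map(|&d| d as f64).sum::<f64>() / n;
    let var = distances
        .iter()
        .map(|&d| {
            let diff = d as f64 - mu;
            diff * diff
        })
        .sum::<f64>()
        / n;
    (mu + 3.0 * var.sqrt()) as u64
}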
fn apply_precision_tier( - query: &[u8], - database: &[u8], - vec_bytes: usize, - finalists: &mut [RankedHit], - precise_mode: PreciseMode, + query: &[u8], database: &[u8], vec_bytes: usize, finalists: &mut [RankedHit], precise_mode: PreciseMode, ) { match precise_mode { PreciseMode::Off => return, @@ -471,8 +496,7 @@ fn apply_precision_tier( } PreciseMode::BF16Hamming { weights } => { - let max_per_dim = - weights.sign as u64 + 8 * weights.exponent as u64 + 7 * weights.mantissa as u64; + let max_per_dim = weights.sign as u64 + 8 * weights.exponent as u64 + 7 * weights.mantissa as u64; let n_dims = vec_bytes / 2; let max_total = max_per_dim * n_dims as u64; @@ -544,7 +568,15 @@ impl PackedDatabase { } } - Self { stroke1, stroke2, stroke3, num_vectors, s1_bytes, s2_bytes, s3_bytes } + Self { + stroke1, + stroke2, + stroke3, + num_vectors, + s1_bytes, + s2_bytes, + s3_bytes, + } } /// Run cascade query on packed layout. @@ -701,8 +733,14 @@ mod tests { database[2 * vec_bytes..3 * vec_bytes].copy_from_slice(&query); let cascade = Cascade::from_threshold(vec_bytes as u64 * 4, vec_bytes); let results = cascade.query_precise( - &query, &database, vec_bytes, 5, - PreciseMode::F32 { scale: 1.0 / 128.0, zero_point: 128 }, + &query, + &database, + vec_bytes, + 5, + PreciseMode::F32 { + scale: 1.0 / 128.0, + zero_point: 128, + }, ); let exact = results.iter().find(|r| r.index == 2).unwrap(); assert!(!exact.precise.is_nan()); @@ -716,10 +754,7 @@ mod tests { database[1 * vec_bytes..2 * vec_bytes].copy_from_slice(&query); let cascade = Cascade::from_threshold(vec_bytes as u64 * 4, vec_bytes); let weights = BF16Weights::new(256, 16, 1); - let results = cascade.query_precise( - &query, &database, vec_bytes, 5, - PreciseMode::BF16Hamming { weights }, - ); + let results = cascade.query_precise(&query, &database, vec_bytes, 5, PreciseMode::BF16Hamming { weights }); let exact = results.iter().find(|r| r.index == 1).unwrap(); assert!((exact.precise - 1.0).abs() < 1e-6, "exact match should have precise=1.0"); } @@ -750,7 +785,7 @@ mod tests { let vec_bytes = 64; let query = vec![0xFFu8; vec_bytes]; let database = vec![0x00u8; vec_bytes * 5]; // all zeros, max hamming from query - // Hamming(0xFF, 0x00) = 8 bits per byte * 64 bytes = 512 + // Hamming(0xFF, 0x00) = 8 bits per byte * 64 bytes = 512 let cascade = Cascade::from_threshold(100, vec_bytes); // tight threshold let candidates = vec![(0, 512), (1, 512)]; diff --git a/src/hpc/causal_diff.rs b/src/hpc/causal_diff.rs index 28073c09..e2478021 100644 --- a/src/hpc/causal_diff.rs +++ b/src/hpc/causal_diff.rs @@ -13,7 +13,7 @@ //! ``` use super::bgz17_bridge::Base17; -use super::gguf_indexer::{CompressedTensor, LayerType, read_bgz7_file}; +use super::gguf_indexer::{read_bgz7_file, CompressedTensor, LayerType}; use super::nars::NarsTruth; use std::collections::HashMap; @@ -24,29 +24,56 @@ use std::collections::HashMap; /// Which projection within an attention head. #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum Projection { - Q, K, V, O, - Gate, // MoE router gate - FfnGate, // dense FFN gate - FfnUp, // dense FFN up - FfnDown, // dense FFN down + Q, + K, + V, + O, + Gate, // MoE router gate + FfnGate, // dense FFN gate + FfnUp, // dense FFN up + FfnDown, // dense FFN down Embedding, Other, } /// Classify a tensor name into its projection type. 
pub fn classify_projection(name: &str) -> Projection { - if name.contains("q_proj") || name.contains("attn_q") { return Projection::Q; } - if name.contains("k_proj") || name.contains("attn_k") { return Projection::K; } - if name.contains("v_proj") || name.contains("attn_v") { return Projection::V; } - if name.contains("o_proj") || name.contains("attn_output") { return Projection::O; } - if name.contains("gate_inp") || name.contains("ffn_gate_inp") { return Projection::Gate; } - if name.contains("gate") && name.contains("exp") { return Projection::FfnGate; } - if name.contains("up") && (name.contains("exp") || name.contains("ffn")) { return Projection::FfnUp; } - if name.contains("down") && (name.contains("exp") || name.contains("ffn")) { return Projection::FfnDown; } - if name.contains("gate") { return Projection::FfnGate; } - if name.contains("up_proj") { return Projection::FfnUp; } - if name.contains("down_proj") { return Projection::FfnDown; } - if name.contains("embed") || name.contains("embd") { return Projection::Embedding; } + if name.contains("q_proj") || name.contains("attn_q") { + return Projection::Q; + } + if name.contains("k_proj") || name.contains("attn_k") { + return Projection::K; + } + if name.contains("v_proj") || name.contains("attn_v") { + return Projection::V; + } + if name.contains("o_proj") || name.contains("attn_output") { + return Projection::O; + } + if name.contains("gate_inp") || name.contains("ffn_gate_inp") { + return Projection::Gate; + } + if name.contains("gate") && name.contains("exp") { + return Projection::FfnGate; + } + if name.contains("up") && (name.contains("exp") || name.contains("ffn")) { + return Projection::FfnUp; + } + if name.contains("down") && (name.contains("exp") || name.contains("ffn")) { + return Projection::FfnDown; + } + if name.contains("gate") { + return Projection::FfnGate; + } + if name.contains("up_proj") { + return Projection::FfnUp; + } + if name.contains("down_proj") { + return Projection::FfnDown; + } + if name.contains("embed") || name.contains("embd") { + return Projection::Embedding; + } Projection::Other } @@ -141,50 +168,82 @@ impl CausalEdge64 { let freq = ((edge.truth.frequency * 1023.0).round() as u64) & 0x3FF; let conf = ((edge.truth.confidence * 1023.0).round() as u64) & 0x3FF; - Self( - (block << 58) - | (proj << 54) - | (verb << 52) - | (row << 36) - | (l1 << 20) - | (freq << 10) - | conf, - ) + Self((block << 58) | (proj << 54) | (verb << 52) | (row << 36) | (l1 << 20) | (freq << 10) | conf) } /// Unpack block number. - #[inline] pub fn block(self) -> u32 { ((self.0 >> 58) & 0x3F) as u32 } + #[inline] + pub fn block(self) -> u32 { + ((self.0 >> 58) & 0x3F) as u32 + } /// Unpack projection type. - #[inline] pub fn projection(self) -> Projection { u4_to_projection(((self.0 >> 54) & 0xF) as u8) } + #[inline] + pub fn projection(self) -> Projection { + u4_to_projection(((self.0 >> 54) & 0xF) as u8) + } /// Unpack verb. - #[inline] pub fn verb(self) -> Verb { - match (self.0 >> 52) & 0x3 { 0 => Verb::Becomes, 1 => Verb::Supports, _ => Verb::Contradicts } + #[inline] + pub fn verb(self) -> Verb { + match (self.0 >> 52) & 0x3 { + 0 => Verb::Becomes, + 1 => Verb::Supports, + _ => Verb::Contradicts, + } } /// Unpack row index. - #[inline] pub fn row_idx(self) -> u32 { ((self.0 >> 36) & 0xFFFF) as u32 } + #[inline] + pub fn row_idx(self) -> u32 { + ((self.0 >> 36) & 0xFFFF) as u32 + } /// Unpack L1 distance. 
- #[inline] pub fn l1_distance(self) -> u32 { ((self.0 >> 20) & 0xFFFF) as u32 } + #[inline] + pub fn l1_distance(self) -> u32 { + ((self.0 >> 20) & 0xFFFF) as u32 + } /// Unpack NARS frequency. - #[inline] pub fn frequency(self) -> f32 { ((self.0 >> 10) & 0x3FF) as f32 / 1023.0 } + #[inline] + pub fn frequency(self) -> f32 { + ((self.0 >> 10) & 0x3FF) as f32 / 1023.0 + } /// Unpack NARS confidence. - #[inline] pub fn confidence(self) -> f32 { (self.0 & 0x3FF) as f32 / 1023.0 } + #[inline] + pub fn confidence(self) -> f32 { + (self.0 & 0x3FF) as f32 / 1023.0 + } /// Reconstruct NarsTruth. - #[inline] pub fn truth(self) -> NarsTruth { NarsTruth::new(self.frequency(), self.confidence()) } + #[inline] + pub fn truth(self) -> NarsTruth { + NarsTruth::new(self.frequency(), self.confidence()) + } } fn projection_to_u4(p: &Projection) -> u8 { match p { - Projection::Q => 0, Projection::K => 1, Projection::V => 2, Projection::O => 3, - Projection::Gate => 4, Projection::FfnGate => 5, Projection::FfnUp => 6, - Projection::FfnDown => 7, Projection::Embedding => 8, Projection::Other => 9, + Projection::Q => 0, + Projection::K => 1, + Projection::V => 2, + Projection::O => 3, + Projection::Gate => 4, + Projection::FfnGate => 5, + Projection::FfnUp => 6, + Projection::FfnDown => 7, + Projection::Embedding => 8, + Projection::Other => 9, } } fn u4_to_projection(v: u8) -> Projection { match v { - 0 => Projection::Q, 1 => Projection::K, 2 => Projection::V, 3 => Projection::O, - 4 => Projection::Gate, 5 => Projection::FfnGate, 6 => Projection::FfnUp, - 7 => Projection::FfnDown, 8 => Projection::Embedding, _ => Projection::Other, + 0 => Projection::Q, + 1 => Projection::K, + 2 => Projection::V, + 3 => Projection::O, + 4 => Projection::Gate, + 5 => Projection::FfnGate, + 6 => Projection::FfnUp, + 7 => Projection::FfnDown, + 8 => Projection::Embedding, + _ => Projection::Other, } } @@ -261,16 +320,16 @@ pub fn scaffold_to_palette64(edges: &[WeightEdge]) -> ([u64; 64], Vec<(u32, Proj /// BECOMES) emerge from intersection/negation in `scaffold_to_palette3d_layers`. 
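// The accessors above unpack a fixed 64-bit layout:
//   block:6 | proj:4 | verb:2 | row:16 | l1:16 | freq:10 | conf:10
// A standalone round trip with the same shifts and masks (names illustrative;
// the real pack/unpack live on CausalEdge64). Frequency and confidence only
// survive to ~1/1023 resolution because they are quantized to 10 bits.
fn pack_edge_sketch(block: u64, proj: u64, verb: u64, row: u64, l1: u64, freq: f32, conf: f32) -> u64 {
    let f = ((freq * 1023.0).round() as u64) & 0x3FF;
    let c = ((conf * 1023.0).round() as u64) & 0x3FF;
    ((block & 0x3F) << 58)
        | ((proj & 0xF) << 54)
        | ((verb & 0x3) << 52)
        | ((row & 0xFFFF) << 36)
        | ((l1 & 0xFFFF) << 20)
        | (f << 10)
        | c
}

fn unpack_edge_sketch(e: u64) -> (u64, u64, u64, u64, u64, f32, f32) {
    (
        (e >> 58) & 0x3F,                    // block
        (e >> 54) & 0xF,                     // projection code
        (e >> 52) & 0x3,                     // verb code
        (e >> 36) & 0xFFFF,                  // row index
        (e >> 20) & 0xFFFF,                  // L1 distance
        ((e >> 10) & 0x3FF) as f32 / 1023.0, // NARS frequency
        (e & 0x3FF) as f32 / 1023.0,         // NARS confidence
    )
}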
pub fn projection_to_predicate(proj: &Projection) -> usize { match proj { - Projection::Q => 0, // CAUSES (Subject in SPO) - Projection::O => 1, // ENABLES (Object in SPO) - Projection::K => 2, // SUPPORTS (Predicate in SPO — stable = supporting) - Projection::V => 4, // REFINES - Projection::Gate => 3, // CONTRADICTS - Projection::FfnUp => 5, // ABSTRACTS - Projection::FfnDown => 5, // ABSTRACTS (same layer) - Projection::FfnGate => 6, // GROUNDS + Projection::Q => 0, // CAUSES (Subject in SPO) + Projection::O => 1, // ENABLES (Object in SPO) + Projection::K => 2, // SUPPORTS (Predicate in SPO — stable = supporting) + Projection::V => 4, // REFINES + Projection::Gate => 3, // CONTRADICTS + Projection::FfnUp => 5, // ABSTRACTS + Projection::FfnDown => 5, // ABSTRACTS (same layer) + Projection::FfnGate => 6, // GROUNDS Projection::Embedding => 7, // BECOMES - Projection::Other => 7, // BECOMES + Projection::Other => 7, // BECOMES } } @@ -290,16 +349,15 @@ pub fn projection_to_predicate(proj: &Projection) -> usize { /// → .infer(block) → which scaffold blocks fire /// → HHTL cascade → 256×256 fine-grain distances /// ``` -pub fn scaffold_to_heel_planes( - edges: &[WeightEdge], - shift_threshold: f64, -) -> [u64; 8] { +pub fn scaffold_to_heel_planes(edges: &[WeightEdge], shift_threshold: f64) -> [u64; 8] { // Count shifts per (block, projection) let mut block_proj_shifts: HashMap<(u32, u8), (usize, usize)> = HashMap::new(); for edge in edges { let block = edge.block.unwrap_or(0); - if block >= 64 { continue; } // p64 only has 64 rows + if block >= 64 { + continue; + } // p64 only has 64 rows let proj = projection_to_u4(&edge.projection); let entry = block_proj_shifts.entry((block, proj)).or_insert((0, 0)); entry.0 += 1; // shifted @@ -313,7 +371,9 @@ pub fn scaffold_to_heel_planes( let mut planes = [0u64; 8]; for (&(block, proj), &(shifted, _total)) in &block_proj_shifts { - if block >= 64 { continue; } + if block >= 64 { + continue; + } // Map projection to predicate layer let predicate = projection_to_predicate(&u4_to_projection(proj)); if predicate < 8 { @@ -348,10 +408,7 @@ pub fn scaffold_to_heel_planes( /// /// Returns 8 `[u64; 64]` palettes (one per predicate layer) ready for Palette3D. 
pub fn scaffold_to_palette3d_layers( - edges_v1: &[WeightEdge], - edges_v2: &[WeightEdge], - edges_v1v2: &[WeightEdge], - edges_9b: &[WeightEdge], + edges_v1: &[WeightEdge], edges_v2: &[WeightEdge], edges_v1v2: &[WeightEdge], edges_9b: &[WeightEdge], ) -> [[u64; 64]; 8] { // scaffold_to_heel_planes maps projections → predicate layers: // plane[0] = CAUSES (from Q projections) @@ -374,8 +431,8 @@ pub fn scaffold_to_palette3d_layers( // // Layer 0 CAUSES: base→v1 Q+O topology (what distillation changed) let causes = heels_v1[0] | heels_v1[1]; // Q shifted OR O shifted - // - // Layer 1 ENABLES: base→v2 Q+O topology (what second iteration changed) + // + // Layer 1 ENABLES: base→v2 Q+O topology (what second iteration changed) let enables = heels_v2[0] | heels_v2[1]; // // Layer 4 REFINES: v1→v2 convergence (which heads stabilized) @@ -384,8 +441,8 @@ pub fn scaffold_to_palette3d_layers( // Invert: bits NOT set in v1v2 means the head converged let still_moving = heels_v1v2[0] | heels_v1v2[1]; let refines = causes & !still_moving; // caused in v1, converged by v2 - // - // Layer 5 ABSTRACTS: 9B diff topology (scale-invariant = abstract) + // + // Layer 5 ABSTRACTS: 9B diff topology (scale-invariant = abstract) let abstracts_9b = heels_9b[0] | heels_9b[1]; // Q+O from 9B // ── 4 DEDUCED layers (from intersection/negation) ── @@ -404,14 +461,14 @@ pub fn scaffold_to_palette3d_layers( let becomes = enables & !causes; let heel_bits = [ - causes, // 0 CAUSES - enables, // 1 ENABLES - supports, // 2 SUPPORTS - contradicts, // 3 CONTRADICTS - refines, // 4 REFINES - abstracts_9b,// 5 ABSTRACTS - grounds, // 6 GROUNDS - becomes, // 7 BECOMES + causes, // 0 CAUSES + enables, // 1 ENABLES + supports, // 2 SUPPORTS + contradicts, // 3 CONTRADICTS + refines, // 4 REFINES + abstracts_9b, // 5 ABSTRACTS + grounds, // 6 GROUNDS + becomes, // 7 BECOMES ]; // Expand each HEEL to 64 rows via golden rotation @@ -443,24 +500,21 @@ pub fn scaffold_to_palette3d_layers( /// GROUNDS = SUPPORTS (= SUPPORTS with 2 diffs) /// BECOMES = ABSTRACTS \ CAUSES (9B-only, not in 27B) /// ``` -pub fn scaffold_to_palette3d_from_2_diffs( - edges_27b: &[WeightEdge], - edges_9b: &[WeightEdge], -) -> [[u64; 64]; 8] { +pub fn scaffold_to_palette3d_from_2_diffs(edges_27b: &[WeightEdge], edges_9b: &[WeightEdge]) -> [[u64; 64]; 8] { let heels_27b = scaffold_to_heel_planes(edges_27b, 0.3); let heels_9b = scaffold_to_heel_planes(edges_9b, 0.3); // MEASURED - let causes = heels_27b[0] | heels_27b[1]; // Q|O from base→v2 - let abstracts = heels_9b[0] | heels_9b[1]; // Q|O from 9B + let causes = heels_27b[0] | heels_27b[1]; // Q|O from base→v2 + let abstracts = heels_9b[0] | heels_9b[1]; // Q|O from 9B // DEDUCED - let enables = causes; // single distillation - let supports = causes & abstracts; // both scales agree - let contradicts = causes & !abstracts; // 27B-only (capacity-dependent) - let refines = 0u64; // no v1→v2 - let grounds = supports; // = supports with 2 diffs - let becomes = abstracts & !causes; // 9B-only novel heads + let enables = causes; // single distillation + let supports = causes & abstracts; // both scales agree + let contradicts = causes & !abstracts; // 27B-only (capacity-dependent) + let refines = 0u64; // no v1→v2 + let grounds = supports; // = supports with 2 diffs + let becomes = abstracts & !causes; // 9B-only novel heads let heel_bits = [ causes, // 0 CAUSES @@ -534,11 +588,7 @@ pub struct QualityMap { /// /// The NARS truth per head is the cross-validated frequency from all diffs /// where the head appeared, 
giving confidence in the classification. -pub fn score_head_quality( - edges_v1: &[WeightEdge], - edges_v2: &[WeightEdge], - edges_9b: &[WeightEdge], -) -> QualityMap { +pub fn score_head_quality(edges_v1: &[WeightEdge], edges_v2: &[WeightEdge], edges_9b: &[WeightEdge]) -> QualityMap { // Collect which (block, proj) pairs appear in each diff let heads_v1 = head_set(edges_v1); let heads_v2 = head_set(edges_v2); @@ -589,7 +639,13 @@ pub fn score_head_quality( heads.insert(key.clone(), (quality, truth)); } - QualityMap { heads, good, bad, uncertain, reverted } + QualityMap { + heads, + good, + bad, + uncertain, + reverted, + } } /// Extract (block, proj) → (frequency, confidence) from edges via cluster_by_head. @@ -619,10 +675,7 @@ fn head_set(edges: &[WeightEdge]) -> HashMap<(u32, String), (f32, f32)> { /// - BAD heads with high conf → suppress (LoRA rank → 0) /// - UNCERTAIN heads → let NARS feedback decide over iterations pub fn scaffold_to_palette3d_quality_filtered( - edges_v1: &[WeightEdge], - edges_v2: &[WeightEdge], - edges_v1v2: &[WeightEdge], - edges_9b: &[WeightEdge], + edges_v1: &[WeightEdge], edges_v2: &[WeightEdge], edges_v1v2: &[WeightEdge], edges_9b: &[WeightEdge], ) -> ([[u64; 64]; 8], QualityMap) { let quality = score_head_quality(edges_v1, edges_v2, edges_9b); @@ -635,7 +688,9 @@ pub fn scaffold_to_palette3d_quality_filtered( let mut good_mask = 0u64; let mut bad_mask = 0u64; for ((block, _proj), (q, _truth)) in &quality.heads { - if *block >= 64 { continue; } + if *block >= 64 { + continue; + } match q { HeadQuality::Good => good_mask |= 1u64 << block, HeadQuality::Bad => bad_mask |= 1u64 << block, @@ -663,14 +718,14 @@ pub fn scaffold_to_palette3d_quality_filtered( // Informational layers (REFINES, ABSTRACTS): GOOD + UNCERTAIN // Tension layers (CONTRADICTS, BECOMES): unfiltered (they ARE the signal) let heel_bits = [ - causes & good_mask, // 0 CAUSES: only good - enables & good_mask, // 1 ENABLES: only good - supports & good_mask, // 2 SUPPORTS: only good - contradicts, // 3 CONTRADICTS: unfiltered - refines & (good_mask | uncertain_mask), // 4 REFINES: good + uncertain - abstracts_9b & (good_mask | uncertain_mask), // 5 ABSTRACTS: good + uncertain - grounds & good_mask, // 6 GROUNDS: only good - becomes, // 7 BECOMES: unfiltered + causes & good_mask, // 0 CAUSES: only good + enables & good_mask, // 1 ENABLES: only good + supports & good_mask, // 2 SUPPORTS: only good + contradicts, // 3 CONTRADICTS: unfiltered + refines & (good_mask | uncertain_mask), // 4 REFINES: good + uncertain + abstracts_9b & (good_mask | uncertain_mask), // 5 ABSTRACTS: good + uncertain + grounds & good_mask, // 6 GROUNDS: only good + becomes, // 7 BECOMES: unfiltered ]; let mut layers = [[0u64; 64]; 8]; @@ -738,12 +793,15 @@ impl NarsHeadBelief { HeadQuality::Bad => LoraAction::Suppress, _ => LoraAction::Explore, }; - beliefs.insert(key.clone(), HeadBelief { - prior: *quality, - truth: *truth, - rounds: 0, - action, - }); + beliefs.insert( + key.clone(), + HeadBelief { + prior: *quality, + truth: *truth, + rounds: 0, + action, + }, + ); } Self { beliefs } } @@ -828,10 +886,7 @@ pub struct VolatilityMap { /// Cross-validates: a head is volatile only if it shifted in multiple diffs. /// Single-diff noise is suppressed by NARS revision across all 4 evidence sources. 
pub fn build_volatility_map( - edges_v1: &[WeightEdge], - edges_v2: &[WeightEdge], - edges_v1v2: &[WeightEdge], - edges_9b: &[WeightEdge], + edges_v1: &[WeightEdge], edges_v2: &[WeightEdge], edges_v1v2: &[WeightEdge], edges_9b: &[WeightEdge], stats: &[(&str, &DiffStats)], ) -> VolatilityMap { // NARS revision across all diffs per projection @@ -840,9 +895,7 @@ pub fn build_volatility_map( // Per-head volatility: count how many diffs this (block, proj) appears in let mut head_evidence: HashMap<(u32, String), Vec> = HashMap::new(); - for (label, diff_edges) in [ - ("v1", edges_v1), ("v2", edges_v2), ("v1v2", edges_v1v2), ("9b", edges_9b), - ] { + for (label, diff_edges) in [("v1", edges_v1), ("v2", edges_v2), ("v1v2", edges_v1v2), ("9b", edges_9b)] { let clusters = cluster_by_head(diff_edges); for ((block, proj), (count, total, mean_l1)) in &clusters { let f = if *total > 0 { *count as f32 / *total as f32 } else { 0.0 }; @@ -869,7 +922,8 @@ pub fn build_volatility_map( let scaffold_v2 = find_reasoning_scaffold(edges_v2, 0.3); let scaffold_9b = find_reasoning_scaffold(edges_9b, 0.3); - let scale_invariant: Vec = scaffold_v1.iter() + let scale_invariant: Vec = scaffold_v1 + .iter() .filter(|b| scaffold_9b.contains(b)) .cloned() .collect(); @@ -895,11 +949,7 @@ pub fn build_volatility_map( /// For structural restoration, 0.85-0.95 works well — the Q8_0 weights /// still carry approximate information, the palette just sharpens contrast. #[inline] -pub fn apply_palette_overlay( - scores: &mut [f32], - palette_row: u64, - decay: f32, -) { +pub fn apply_palette_overlay(scores: &mut [f32], palette_row: u64, decay: f32) { // Map score positions to 64 palette bins let n = scores.len(); let bin_size = (n + 63) / 64; @@ -922,19 +972,23 @@ pub fn apply_palette_overlay( #[derive(Clone, Copy, Debug, PartialEq, Eq)] #[repr(u8)] pub enum PaletteStyle { - Analytical = 0, // tight intersection, contradiction kills - Creative = 1, // wide union, contradiction ignored - Focused = 2, // single causal chain + Analytical = 0, // tight intersection, contradiction kills + Creative = 1, // wide union, contradiction ignored + Focused = 2, // single causal chain Integrative = 3, // majority vote, contradiction as tension - Divergent = 4, // contradiction inverts (fuel) - Meta = 5, // observes the observation + Divergent = 4, // contradiction inverts (fuel) + Meta = 5, // observes the observation } impl PaletteStyle { fn from_u8(v: u8) -> Self { match v { - 0 => Self::Analytical, 1 => Self::Creative, 2 => Self::Focused, - 3 => Self::Integrative, 4 => Self::Divergent, 5 => Self::Meta, + 0 => Self::Analytical, + 1 => Self::Creative, + 2 => Self::Focused, + 3 => Self::Integrative, + 4 => Self::Divergent, + 5 => Self::Meta, _ => Self::Analytical, } } @@ -945,18 +999,15 @@ impl PaletteStyle { /// Format: "PAL8" magic + style(u8) + 8 × 64 × u64 LE = 4101 bytes. /// This is the Cognitive Highway payload: ndarray extracts → PAL8 → lance-graph /// deserializes into `Blumenstrauss::new(planes, semiring)`. 
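// The byte budget stated above checks out: 4-byte magic + 1 style byte
// + 8 layers x 64 rows x 8 bytes per little-endian u64 = 4101 bytes total.
// (Constant name is illustrative, not part of the crate.)
const PAL8_PAYLOAD_BYTES: usize = 4 + 1 + 8 * 64 * 8;
const _: () = assert!(PAL8_PAYLOAD_BYTES == 4101);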
-pub fn serialize_palette3d_layers( - layers: &[[u64; 64]; 8], - style: PaletteStyle, - path: &str, -) -> Result<(), String> { +pub fn serialize_palette3d_layers(layers: &[[u64; 64]; 8], style: PaletteStyle, path: &str) -> Result<(), String> { use std::io::Write; let mut file = std::fs::File::create(path).map_err(|e| e.to_string())?; file.write_all(b"PAL8").map_err(|e| e.to_string())?; file.write_all(&[style as u8]).map_err(|e| e.to_string())?; for layer in layers { for &row in layer { - file.write_all(&row.to_le_bytes()).map_err(|e| e.to_string())?; + file.write_all(&row.to_le_bytes()) + .map_err(|e| e.to_string())?; } } Ok(()) @@ -972,7 +1023,8 @@ pub fn deserialize_palette3d_layers(path: &str) -> Result<([[u64; 64]; 8], Palet return Err(format!("bad magic: {:?}", magic)); } let mut style_byte = [0u8; 1]; - file.read_exact(&mut style_byte).map_err(|e| e.to_string())?; + file.read_exact(&mut style_byte) + .map_err(|e| e.to_string())?; let style = PaletteStyle::from_u8(style_byte[0]); let mut layers = [[0u64; 64]; 8]; for layer in &mut layers { @@ -1008,18 +1060,13 @@ pub struct DiffStats { /// /// Returns: (edges, stats) pub fn causal_diff( - base_path: &str, - distilled_path: &str, - l1_threshold: u32, + base_path: &str, distilled_path: &str, l1_threshold: u32, ) -> Result<(Vec, DiffStats), String> { let base_tensors = read_bgz7_file(base_path)?; let dist_tensors = read_bgz7_file(distilled_path)?; // Index distilled tensors by name - let dist_map: HashMap<&str, &CompressedTensor> = dist_tensors - .iter() - .map(|t| (t.name.as_str(), t)) - .collect(); + let dist_map: HashMap<&str, &CompressedTensor> = dist_tensors.iter().map(|t| (t.name.as_str(), t)).collect(); let mut edges = Vec::new(); let mut stats = DiffStats::default(); @@ -1036,8 +1083,7 @@ pub fn causal_diff( // Rows must match if base_t.rows.len() != dist_t.rows.len() { - eprintln!(" WARN: row count mismatch for {}: {} vs {}", - base_t.name, base_t.rows.len(), dist_t.rows.len()); + eprintln!(" WARN: row count mismatch for {}: {} vs {}", base_t.name, base_t.rows.len(), dist_t.rows.len()); continue; } @@ -1076,7 +1122,11 @@ pub fn causal_diff( } } - let mean_l1 = if n_rows > 0 { total_l1 as f64 / n_rows as f64 } else { 0.0 }; + let mean_l1 = if n_rows > 0 { + total_l1 as f64 / n_rows as f64 + } else { + 0.0 + }; let entry = stats.by_projection.entry(proj_key).or_insert((0, 0, 0.0)); entry.0 += shifted; entry.1 += n_rows; @@ -1091,26 +1141,36 @@ pub fn print_diff_summary(label: &str, stats: &DiffStats, edge_count: usize) { eprintln!(); eprintln!("━━━ {} ━━━", label); eprintln!(" Tensors matched: {}, unmatched: {}", stats.tensors_matched, stats.tensors_unmatched); - eprintln!(" Rows: {} compared, {} shifted ({:.1}%), {} stable", - stats.rows_compared, stats.rows_shifted, - if stats.rows_compared > 0 { stats.rows_shifted as f64 / stats.rows_compared as f64 * 100.0 } else { 0.0 }, - stats.rows_stable); + eprintln!( + " Rows: {} compared, {} shifted ({:.1}%), {} stable", + stats.rows_compared, + stats.rows_shifted, + if stats.rows_compared > 0 { + stats.rows_shifted as f64 / stats.rows_compared as f64 * 100.0 + } else { + 0.0 + }, + stats.rows_stable + ); eprintln!(" Edges emitted: {}", edge_count); eprintln!(); // Sort projections by shift percentage let mut projs: Vec<_> = stats.by_projection.iter().collect(); projs.sort_by(|a, b| { - let pct_a = if a.1.1 > 0 { a.1.0 as f64 / a.1.1 as f64 } else { 0.0 }; - let pct_b = if b.1.1 > 0 { b.1.0 as f64 / b.1.1 as f64 } else { 0.0 }; + let pct_a = if a.1 .1 > 0 { a.1 .0 as f64 / a.1 .1 as f64 
} else { 0.0 }; + let pct_b = if b.1 .1 > 0 { b.1 .0 as f64 / b.1 .1 as f64 } else { 0.0 }; pct_b.partial_cmp(&pct_a).unwrap() }); eprintln!(" Per projection:"); for (proj, (shifted, total, mean_l1)) in &projs { - let pct = if *total > 0 { *shifted as f64 / *total as f64 * 100.0 } else { 0.0 }; - eprintln!(" {:<12} {:>6}/{:<6} shifted ({:>5.1}%) mean_L1={:.1}", - proj, shifted, total, pct, mean_l1); + let pct = if *total > 0 { + *shifted as f64 / *total as f64 * 100.0 + } else { + 0.0 + }; + eprintln!(" {:<12} {:>6}/{:<6} shifted ({:>5.1}%) mean_L1={:.1}", proj, shifted, total, pct, mean_l1); } } @@ -1130,7 +1190,8 @@ pub fn cluster_by_head(edges: &[WeightEdge]) -> HashMap<(u32, String), (usize, u } } - clusters.into_iter() + clusters + .into_iter() .map(|(k, (count, max_row, total_l1))| { let mean_l1 = if count > 0 { total_l1 as f64 / count as f64 } else { 0.0 }; (k, (count, max_row, mean_l1)) @@ -1147,24 +1208,33 @@ pub fn find_reasoning_scaffold( let mut scaffold_blocks = Vec::new(); // Find all blocks - let blocks: std::collections::BTreeSet = edges.iter() - .filter_map(|e| e.block) - .collect(); + let blocks: std::collections::BTreeSet = edges.iter().filter_map(|e| e.block).collect(); for block in blocks { let q_shift = clusters.get(&(block, "Q".to_string())); let k_shift = clusters.get(&(block, "K".to_string())); let o_shift = clusters.get(&(block, "O".to_string())); - let q_pct = q_shift.map(|(c, t, _)| *c as f64 / *t as f64).unwrap_or(0.0); - let k_pct = k_shift.map(|(c, t, _)| *c as f64 / *t as f64).unwrap_or(0.0); - let o_pct = o_shift.map(|(c, t, _)| *c as f64 / *t as f64).unwrap_or(0.0); + let q_pct = q_shift + .map(|(c, t, _)| *c as f64 / *t as f64) + .unwrap_or(0.0); + let k_pct = k_shift + .map(|(c, t, _)| *c as f64 / *t as f64) + .unwrap_or(0.0); + let o_pct = o_shift + .map(|(c, t, _)| *c as f64 / *t as f64) + .unwrap_or(0.0); // Reasoning scaffold: Q+O shifted, K stable if q_pct > shift_threshold && o_pct > shift_threshold && k_pct < shift_threshold { scaffold_blocks.push(block); - eprintln!(" Block {:>2}: SCAFFOLD Q={:.0}% O={:.0}% K={:.0}%", - block, q_pct * 100.0, o_pct * 100.0, k_pct * 100.0); + eprintln!( + " Block {:>2}: SCAFFOLD Q={:.0}% O={:.0}% K={:.0}%", + block, + q_pct * 100.0, + o_pct * 100.0, + k_pct * 100.0 + ); } } @@ -1179,23 +1249,29 @@ pub fn find_reasoning_scaffold( /// /// For each projection type, integrates evidence from multiple model pairs: /// e.g., 27B_v1, 27B_v2, 9B → revised belief about reasoning scaffold. 
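// revise_across_diffs (below) folds each diff's evidence into a running belief
// via nars_revision, whose body is not shown in this diff. As a reference, the
// standard NARS revision rule pools beliefs by evidence weight w = c / (1 - c)
// (horizon k = 1), which is consistent with the c = total / (total + 1)
// confidence construction used below. Sketch only; the crate's exact rule may differ.
fn nars_revision_sketch(f1: f32, c1: f32, f2: f32, c2: f32) -> (f32, f32) {
    // Assumes confidences are capped below 1.0 (the caller below caps at 0.99).
    let w1 = c1 / (1.0 - c1);
    let w2 = c2 / (1.0 - c2);
    let w = w1 + w2;
    let f = if w > 0.0 { (w1 * f1 + w2 * f2) / w } else { 0.5 };
    let c = w / (w + 1.0);
    (f, c)
}
// Two weak, agreeing observations revise into a stronger belief:
// nars_revision_sketch(0.8, 0.5, 0.8, 0.5) == (0.8, ~0.667).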
-pub fn revise_across_diffs( - diff_results: &[(&str, &DiffStats)], -) -> HashMap { +pub fn revise_across_diffs(diff_results: &[(&str, &DiffStats)]) -> HashMap { let mut revised: HashMap = HashMap::new(); for (label, stats) in diff_results { for (proj, (shifted, total, _mean_l1)) in &stats.by_projection { - let f = if *total > 0 { *shifted as f32 / *total as f32 } else { 0.0 }; + let f = if *total > 0 { + *shifted as f32 / *total as f32 + } else { + 0.0 + }; let c = (1.0 - 1.0 / (1.0 + *total as f32)).min(0.99); let evidence = NarsTruth::new(f, c); - let entry = revised.entry(proj.clone()).or_insert(NarsTruth::new(0.5, 0.0)); + let entry = revised + .entry(proj.clone()) + .or_insert(NarsTruth::new(0.5, 0.0)); // NARS revision: integrate new evidence *entry = nars_revision(*entry, evidence); - eprintln!(" [{}] {}: f={:.3} c={:.3} → revised f={:.3} c={:.3}", - label, proj, f, c, entry.frequency, entry.confidence); + eprintln!( + " [{}] {}: f={:.3} c={:.3} → revised f={:.3} c={:.3}", + label, proj, f, c, entry.frequency, entry.confidence + ); } } @@ -1274,8 +1350,7 @@ pub fn extract_gate_topology(bgz7_path: &str) -> Result, }); } - eprintln!(" Gate: {} → {} experts in block {}", - t.name, t.rows.len(), block); + eprintln!(" Gate: {} → {} experts in block {}", t.name, t.rows.len(), block); } Ok(fingerprints) @@ -1285,10 +1360,7 @@ pub fn extract_gate_topology(bgz7_path: &str) -> Result, /// /// `redundancy_threshold`: L1 below which two experts are "structurally interchangeable". /// Suggested: 500 (conservative), 1000 (aggressive). -pub fn cluster_experts( - fingerprints: &[ExpertFingerprint], - redundancy_threshold: u32, -) -> Vec { +pub fn cluster_experts(fingerprints: &[ExpertFingerprint], redundancy_threshold: u32) -> Vec { // Group by block let mut by_block: HashMap> = HashMap::new(); for fp in fingerprints { @@ -1317,13 +1389,19 @@ pub fn cluster_experts( } } - let mean_l1 = if total_pairs > 0 { total_l1 as f64 / total_pairs as f64 } else { 0.0 }; + let mean_l1 = if total_pairs > 0 { + total_l1 as f64 / total_pairs as f64 + } else { + 0.0 + }; // Simple connected-component grouping let mut visited = vec![false; n]; let mut groups = Vec::new(); for start in 0..n { - if visited[start] { continue; } + if visited[start] { + continue; + } let mut group = vec![start]; visited[start] = true; let mut stack = vec![start]; @@ -1341,10 +1419,20 @@ pub fn cluster_experts( } } - eprintln!(" Block {:>2}: {} experts, mean_L1={:.0}, redundant_pairs={}/{} ({:.0}%), groups={}", - block, n, mean_l1, redundant, total_pairs, - if total_pairs > 0 { redundant as f64 / total_pairs as f64 * 100.0 } else { 0.0 }, - groups.len()); + eprintln!( + " Block {:>2}: {} experts, mean_L1={:.0}, redundant_pairs={}/{} ({:.0}%), groups={}", + block, + n, + mean_l1, + redundant, + total_pairs, + if total_pairs > 0 { + redundant as f64 / total_pairs as f64 * 100.0 + } else { + 0.0 + }, + groups.len() + ); clusters.push(ExpertCluster { block: *block, @@ -1365,23 +1453,27 @@ pub fn cluster_experts( /// For each scaffold block (where Q+O shifted), check if the gate /// in that block has high expert redundancy. High redundancy + scaffold /// = the reasoning change works THROUGH the router, not the experts. 
-pub fn cross_reference_gate_scaffold( - clusters: &[ExpertCluster], - scaffold_blocks: &[u32], -) -> Vec<(u32, bool, f64)> { +pub fn cross_reference_gate_scaffold(clusters: &[ExpertCluster], scaffold_blocks: &[u32]) -> Vec<(u32, bool, f64)> { let mut results = Vec::new(); for block in scaffold_blocks { if let Some(cluster) = clusters.iter().find(|c| c.block == *block) { let redundancy_pct = if cluster.total_pairs > 0 { cluster.redundant_pairs as f64 / cluster.total_pairs as f64 - } else { 0.0 }; + } else { + 0.0 + }; let is_routing_dominated = redundancy_pct > 0.5; results.push((*block, is_routing_dominated, redundancy_pct)); - eprintln!(" Block {:>2}: scaffold={} routing_dominated={} redundancy={:.0}%", - block, true, is_routing_dominated, redundancy_pct * 100.0); + eprintln!( + " Block {:>2}: scaffold={} routing_dominated={} redundancy={:.0}%", + block, + true, + is_routing_dominated, + redundancy_pct * 100.0 + ); } else { // No gate in this block (dense layer, not MoE) results.push((*block, false, 0.0)); @@ -1400,10 +1492,7 @@ pub fn cross_reference_gate_scaffold( /// Uses XOR-analog on Base17: effect.l1(candidate) finds the nearest causal antecedent. /// Science: Pearl (2009), Plate (2003), Squires & Uhler (2023). pub fn reverse_trace( - effect: &super::bgz17_bridge::Base17, - candidates: &[super::bgz17_bridge::Base17], - max_depth: usize, - threshold: u32, + effect: &super::bgz17_bridge::Base17, candidates: &[super::bgz17_bridge::Base17], max_depth: usize, threshold: u32, ) -> Vec<(usize, u32, NarsTruth)> { let mut chain = Vec::new(); let mut current = effect.clone(); @@ -1419,7 +1508,9 @@ pub fn reverse_trace( best_idx = i; } } - if best_dist > threshold || best_dist == u32::MAX { break; } + if best_dist > threshold || best_dist == u32::MAX { + break; + } let frequency = 1.0 - (best_dist as f32 / max_l1); let confidence = (1.0 - 1.0 / (1.0 + chain.len() as f32 + 1.0)).min(0.99); @@ -1483,16 +1574,25 @@ mod tests { #[test] fn test_causal_edge64_all_projections() { for (proj, val) in [ - (Projection::Q, 0), (Projection::K, 1), (Projection::V, 2), (Projection::O, 3), - (Projection::Gate, 4), (Projection::FfnGate, 5), (Projection::FfnUp, 6), - (Projection::FfnDown, 7), (Projection::Embedding, 8), (Projection::Other, 9), + (Projection::Q, 0), + (Projection::K, 1), + (Projection::V, 2), + (Projection::O, 3), + (Projection::Gate, 4), + (Projection::FfnGate, 5), + (Projection::FfnUp, 6), + (Projection::FfnDown, 7), + (Projection::Embedding, 8), + (Projection::Other, 9), ] { let edge = WeightEdge { tensor_name: String::new(), - row_idx: 0, block: Some(0), + row_idx: 0, + block: Some(0), projection: proj.clone(), layer_type: LayerType::Attention, - verb: Verb::Becomes, l1_distance: 0, + verb: Verb::Becomes, + l1_distance: 0, truth: NarsTruth::new(0.5, 0.5), }; let packed = CausalEdge64::pack(&edge); @@ -1502,16 +1602,18 @@ mod tests { #[test] fn test_scaffold_to_palette64() { - let edges: Vec = (0..10).map(|i| WeightEdge { - tensor_name: format!("blk.{}.attn_q.weight", i), - row_idx: i as u32, - block: Some(i as u32), - projection: Projection::Q, - layer_type: LayerType::Attention, - verb: Verb::Becomes, - l1_distance: 500 + i as u32 * 100, - truth: NarsTruth::new(0.8, 0.9), - }).collect(); + let edges: Vec = (0..10) + .map(|i| WeightEdge { + tensor_name: format!("blk.{}.attn_q.weight", i), + row_idx: i as u32, + block: Some(i as u32), + projection: Projection::Q, + layer_type: LayerType::Attention, + verb: Verb::Becomes, + l1_distance: 500 + i as u32 * 100, + truth: NarsTruth::new(0.8, 0.9), 
+ }) + .collect(); let (rows, labels) = scaffold_to_palette64(&edges); assert_eq!(labels.len(), 10); @@ -1545,20 +1647,24 @@ mod tests { for b in 0..8u32 { edges.push(WeightEdge { tensor_name: format!("layers.{}.self_attn.q_proj.weight", b), - row_idx: 0, block: Some(b), + row_idx: 0, + block: Some(b), projection: Projection::Q, layer_type: LayerType::Attention, - verb: Verb::Becomes, l1_distance: 500, + verb: Verb::Becomes, + l1_distance: 500, truth: NarsTruth::new(0.8, 0.9), }); } for b in 2..6u32 { edges.push(WeightEdge { tensor_name: format!("layers.{}.self_attn.o_proj.weight", b), - row_idx: 0, block: Some(b), + row_idx: 0, + block: Some(b), projection: Projection::O, layer_type: LayerType::Attention, - verb: Verb::Becomes, l1_distance: 400, + verb: Verb::Becomes, + l1_distance: 400, truth: NarsTruth::new(0.7, 0.85), }); } @@ -1579,12 +1685,19 @@ mod tests { #[test] fn test_scaffold_to_palette3d_layers() { let make_edges = |blocks: &[u32], proj: Projection| -> Vec { - blocks.iter().map(|&b| WeightEdge { - tensor_name: format!("layers.{}.self_attn.q_proj.weight", b), - row_idx: 0, block: Some(b), projection: proj.clone(), - layer_type: LayerType::Attention, verb: Verb::Becomes, - l1_distance: 500, truth: NarsTruth::new(0.8, 0.9), - }).collect() + blocks + .iter() + .map(|&b| WeightEdge { + tensor_name: format!("layers.{}.self_attn.q_proj.weight", b), + row_idx: 0, + block: Some(b), + projection: proj.clone(), + layer_type: LayerType::Attention, + verb: Verb::Becomes, + l1_distance: 500, + truth: NarsTruth::new(0.8, 0.9), + }) + .collect() }; // v1: Q blocks 0-3, O blocks 1-2 @@ -1688,12 +1801,19 @@ mod tests { #[test] fn test_score_head_quality() { let make = |blocks: &[u32], proj: Projection| -> Vec { - blocks.iter().map(|&b| WeightEdge { - tensor_name: format!("layers.{}.self_attn.q_proj.weight", b), - row_idx: 0, block: Some(b), projection: proj.clone(), - layer_type: LayerType::Attention, verb: Verb::Becomes, - l1_distance: 500, truth: NarsTruth::new(0.8, 0.9), - }).collect() + blocks + .iter() + .map(|&b| WeightEdge { + tensor_name: format!("layers.{}.self_attn.q_proj.weight", b), + row_idx: 0, + block: Some(b), + projection: proj.clone(), + layer_type: LayerType::Attention, + verb: Verb::Becomes, + l1_distance: 500, + truth: NarsTruth::new(0.8, 0.9), + }) + .collect() }; // v1: blocks 0-5 Q shifted @@ -1729,11 +1849,19 @@ mod tests { #[test] fn test_nars_head_belief_update() { let make = |blocks: &[u32], proj: Projection| -> Vec { - blocks.iter().map(|&b| WeightEdge { - tensor_name: String::new(), row_idx: 0, block: Some(b), - projection: proj.clone(), layer_type: LayerType::Attention, - verb: Verb::Becomes, l1_distance: 500, truth: NarsTruth::new(0.8, 0.9), - }).collect() + blocks + .iter() + .map(|&b| WeightEdge { + tensor_name: String::new(), + row_idx: 0, + block: Some(b), + projection: proj.clone(), + layer_type: LayerType::Attention, + verb: Verb::Becomes, + l1_distance: 500, + truth: NarsTruth::new(0.8, 0.9), + }) + .collect() }; let edges_v1 = make(&[0, 1], Projection::Q); @@ -1825,21 +1953,23 @@ mod tests { // ── Phase 1: Index all 5 models ── let models: Vec<(&str, &str, &str)> = vec![ - ("unsloth/Qwen3.5-27B-GGUF", - "Qwen3.5-27B-Q8_0.gguf", - "/tmp/qwen35_27b_base.bgz7"), - ("Jackrong/Qwen3.5-27B-Claude-4.6-Opus-Reasoning-Distilled-GGUF", - "Qwen3.5-27B.Q8_0.gguf", - "/tmp/qwen35_27b_distilled_v1.bgz7"), - ("Jackrong/Qwen3.5-27B-Claude-4.6-Opus-Reasoning-Distilled-v2-GGUF", - "Qwen3.5-27B.Q8_0.gguf", - "/tmp/qwen35_27b_distilled_v2.bgz7"), - 
("unsloth/Qwen3.5-9B-GGUF", - "Qwen3.5-9B-Q8_0.gguf", - "/tmp/qwen35_9b_base.bgz7"), - ("Jackrong/Qwen3.5-9B-Claude-4.6-Opus-Reasoning-Distilled-GGUF", - "Qwen3.5-9B.Q8_0.gguf", - "/tmp/qwen35_9b_distilled.bgz7"), + ("unsloth/Qwen3.5-27B-GGUF", "Qwen3.5-27B-Q8_0.gguf", "/tmp/qwen35_27b_base.bgz7"), + ( + "Jackrong/Qwen3.5-27B-Claude-4.6-Opus-Reasoning-Distilled-GGUF", + "Qwen3.5-27B.Q8_0.gguf", + "/tmp/qwen35_27b_distilled_v1.bgz7", + ), + ( + "Jackrong/Qwen3.5-27B-Claude-4.6-Opus-Reasoning-Distilled-v2-GGUF", + "Qwen3.5-27B.Q8_0.gguf", + "/tmp/qwen35_27b_distilled_v2.bgz7", + ), + ("unsloth/Qwen3.5-9B-GGUF", "Qwen3.5-9B-Q8_0.gguf", "/tmp/qwen35_9b_base.bgz7"), + ( + "Jackrong/Qwen3.5-9B-Claude-4.6-Opus-Reasoning-Distilled-GGUF", + "Qwen3.5-9B.Q8_0.gguf", + "/tmp/qwen35_9b_distilled.bgz7", + ), ]; for (repo, filename, out_path) in &models { @@ -1857,7 +1987,8 @@ mod tests { .output() .map(|o| String::from_utf8_lossy(&o.stdout).to_string()) .unwrap_or_default(); - let size: u64 = size_str.lines() + let size: u64 = size_str + .lines() .filter(|l| l.to_lowercase().starts_with("content-length:")) .last() .and_then(|l| l.split(':').nth(1)) @@ -1870,31 +2001,31 @@ mod tests { // Q8_0 uses f32 path (needs dequantization) let stats = super::super::gguf_indexer::stream_index_gguf( - &mut reader, &mut writer, + &mut reader, + &mut writer, Some(&|name, lt, orig, comp| { let ratio = if comp > 0 { orig as f64 / comp as f64 } else { 0.0 }; eprintln!(" {:50} {:>8} → {:>6} ({:.0}×)", name, orig, comp, ratio); }), - ).expect("indexing failed"); + ) + .expect("indexing failed"); drop(writer); - eprintln!(" {} → {:.2} MB ({} tensors)", + eprintln!( + " {} → {:.2} MB ({} tensors)", out_path, std::fs::metadata(out_path).map(|m| m.len()).unwrap_or(0) as f64 / 1e6, - stats.tensors_indexed); + stats.tensors_indexed + ); } // ── Phase 2: Diff pairs ── let pairs: Vec<(&str, &str, &str)> = vec![ - ("/tmp/qwen35_27b_base.bgz7", "/tmp/qwen35_27b_distilled_v1.bgz7", - "27B base→v1"), - ("/tmp/qwen35_27b_base.bgz7", "/tmp/qwen35_27b_distilled_v2.bgz7", - "27B base→v2"), - ("/tmp/qwen35_27b_distilled_v1.bgz7", "/tmp/qwen35_27b_distilled_v2.bgz7", - "27B v1→v2"), - ("/tmp/qwen35_9b_base.bgz7", "/tmp/qwen35_9b_distilled.bgz7", - "9B base→distilled"), + ("/tmp/qwen35_27b_base.bgz7", "/tmp/qwen35_27b_distilled_v1.bgz7", "27B base→v1"), + ("/tmp/qwen35_27b_base.bgz7", "/tmp/qwen35_27b_distilled_v2.bgz7", "27B base→v2"), + ("/tmp/qwen35_27b_distilled_v1.bgz7", "/tmp/qwen35_27b_distilled_v2.bgz7", "27B v1→v2"), + ("/tmp/qwen35_9b_base.bgz7", "/tmp/qwen35_9b_distilled.bgz7", "9B base→distilled"), ]; let mut all_stats: Vec<(&str, DiffStats)> = Vec::new(); @@ -1915,16 +2046,18 @@ mod tests { eprintln!(); eprintln!("━━━ NARS Revision: integrated evidence ━━━"); - let refs: Vec<(&str, &DiffStats)> = all_stats.iter() - .map(|(l, s)| (*l, s)) - .collect(); + let refs: Vec<(&str, &DiffStats)> = all_stats.iter().map(|(l, s)| (*l, s)).collect(); let revised = revise_across_diffs(&refs); eprintln!(); for (proj, truth) in &revised { - eprintln!(" {:<12} → f={:.3} c={:.3} ({})", - proj, truth.frequency, truth.confidence, - if truth.frequency > 0.5 { "shifted" } else { "stable" }); + eprintln!( + " {:<12} → f={:.3} c={:.3} ({})", + proj, + truth.frequency, + truth.confidence, + if truth.frequency > 0.5 { "shifted" } else { "stable" } + ); } } @@ -1970,12 +2103,23 @@ mod tests { eprintln!("━━━ Maverick Gate Topology ━━━"); let total_redundant: usize = clusters.iter().map(|c| c.redundant_pairs).sum(); let total_pairs: usize = 
clusters.iter().map(|c| c.total_pairs).sum(); - eprintln!(" Overall redundancy: {}/{} pairs ({:.0}%)", - total_redundant, total_pairs, - if total_pairs > 0 { total_redundant as f64 / total_pairs as f64 * 100.0 } else { 0.0 }); + eprintln!( + " Overall redundancy: {}/{} pairs ({:.0}%)", + total_redundant, + total_pairs, + if total_pairs > 0 { + total_redundant as f64 / total_pairs as f64 * 100.0 + } else { + 0.0 + } + ); // NARS truth for expert redundancy - let f = if total_pairs > 0 { total_redundant as f32 / total_pairs as f32 } else { 0.0 }; + let f = if total_pairs > 0 { + total_redundant as f32 / total_pairs as f32 + } else { + 0.0 + }; let c = (1.0 - 1.0 / (1.0 + total_pairs as f32)).min(0.99); eprintln!(" NARS truth: f={:.3} c={:.3}", f, c); eprintln!(" Interpretation: {:.0}% of expert pairs are structurally interchangeable", f * 100.0); @@ -2024,11 +2168,24 @@ mod tests { let routing_dominated: usize = results.iter().filter(|(_, rd, _)| *rd).count(); eprintln!(); eprintln!(" Scaffold blocks: {}", scaffold_blocks.len()); - eprintln!(" Routing-dominated: {}/{} ({:.0}%)", - routing_dominated, results.len(), - if !results.is_empty() { routing_dominated as f64 / results.len() as f64 * 100.0 } else { 0.0 }); - eprintln!(" → {} = reasoning changes work THROUGH the router", - if routing_dominated > results.len() / 2 { "YES" } else { "PARTIAL" }); + eprintln!( + " Routing-dominated: {}/{} ({:.0}%)", + routing_dominated, + results.len(), + if !results.is_empty() { + routing_dominated as f64 / results.len() as f64 * 100.0 + } else { + 0.0 + } + ); + eprintln!( + " → {} = reasoning changes work THROUGH the router", + if routing_dominated > results.len() / 2 { + "YES" + } else { + "PARTIAL" + } + ); } // ════════════════════════════════════════════════════════════════════ @@ -2043,11 +2200,31 @@ mod tests { } const MODELS: [ModelSpec; 5] = [ - ModelSpec { repo: "Qwen/Qwen3.5-27B", shards: 11, prefix: "qwen35_27b_base" }, - ModelSpec { repo: "Jackrong/Qwen3.5-27B-Claude-4.6-Opus-Reasoning-Distilled", shards: 11, prefix: "qwen35_27b_v1" }, - ModelSpec { repo: "Jackrong/Qwen3.5-27B-Claude-4.6-Opus-Reasoning-Distilled-v2", shards: 11, prefix: "qwen35_27b_v2" }, - ModelSpec { repo: "Qwen/Qwen3.5-9B", shards: 4, prefix: "qwen35_9b_base" }, - ModelSpec { repo: "Jackrong/Qwen3.5-9B-Claude-4.6-Opus-Reasoning-Distilled", shards: 4, prefix: "qwen35_9b_dist" }, + ModelSpec { + repo: "Qwen/Qwen3.5-27B", + shards: 11, + prefix: "qwen35_27b_base", + }, + ModelSpec { + repo: "Jackrong/Qwen3.5-27B-Claude-4.6-Opus-Reasoning-Distilled", + shards: 11, + prefix: "qwen35_27b_v1", + }, + ModelSpec { + repo: "Jackrong/Qwen3.5-27B-Claude-4.6-Opus-Reasoning-Distilled-v2", + shards: 11, + prefix: "qwen35_27b_v2", + }, + ModelSpec { + repo: "Qwen/Qwen3.5-9B", + shards: 4, + prefix: "qwen35_9b_base", + }, + ModelSpec { + repo: "Jackrong/Qwen3.5-9B-Claude-4.6-Opus-Reasoning-Distilled", + shards: 4, + prefix: "qwen35_9b_dist", + }, ]; /// Generate safetensors shard filenames for a model. @@ -2066,8 +2243,8 @@ mod tests { /// Index a single model (all shards) via safetensors BF16. 
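// A minimal standalone sketch of the expert-redundancy truth computed in the
// Maverick gate topology test above, assuming the same mapping shown in that hunk:
// frequency = redundant_pairs / total_pairs, confidence = n/(n+1) capped at 0.99.
// `redundancy_truth` is an illustrative helper, not a function in this crate.
fn redundancy_truth(redundant_pairs: usize, total_pairs: usize) -> (f32, f32) {
    let f = if total_pairs > 0 {
        redundant_pairs as f32 / total_pairs as f32
    } else {
        0.0
    };
    let c = (1.0 - 1.0 / (1.0 + total_pairs as f32)).min(0.99);
    (f, c)
}

#[test]
fn redundancy_truth_example() {
    // 30 of 40 expert pairs interchangeable -> f = 0.75, c = 40/41 ≈ 0.976
    let (f, c) = redundancy_truth(30, 40);
    assert!((f - 0.75).abs() < 1e-6);
    assert!((c - 40.0 / 41.0).abs() < 1e-6);
}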
fn index_model_safetensors(model: &ModelSpec) { - use super::super::safetensors::stream_index_safetensors_bf16; use super::super::http_reader::HttpRangeReader; + use super::super::safetensors::stream_index_safetensors_bf16; use std::io::BufWriter; let filenames = shard_filenames(model.shards); @@ -2079,17 +2256,8 @@ mod tests { continue; } - let url = format!( - "https://huggingface.co/{}/resolve/main/{}", - model.repo, filename - ); - eprintln!( - "[{}] shard {}/{}: {}", - model.prefix, - shard_idx + 1, - model.shards, - filename - ); + let url = format!("https://huggingface.co/{}/resolve/main/{}", model.repo, filename); + eprintln!("[{}] shard {}/{}: {}", model.prefix, shard_idx + 1, model.shards, filename); // HEAD for content-length // Take the LAST content-length (after redirects) @@ -2135,10 +2303,7 @@ mod tests { /// Causal diff across matched shards of two models, aggregating edges + stats. fn causal_diff_sharded( - base_prefix: &str, - dist_prefix: &str, - n_shards: u32, - l1_threshold: u32, + base_prefix: &str, dist_prefix: &str, n_shards: u32, l1_threshold: u32, ) -> (Vec, DiffStats) { let base_paths = shard_bgz7_paths(base_prefix, n_shards); let dist_paths = shard_bgz7_paths(dist_prefix, n_shards); @@ -2186,8 +2351,7 @@ mod tests { entry.0 += shifted; entry.1 += total; if entry.1 > 0 { - entry.2 = (entry.2 * prev_total as f64 + mean_l1 * *total as f64) - / entry.1 as f64; + entry.2 = (entry.2 * prev_total as f64 + mean_l1 * *total as f64) / entry.1 as f64; } } @@ -2242,36 +2406,28 @@ mod tests { eprintln!(); eprintln!("════ Diff 1: 27B base → distilled v1 ════"); eprintln!(" What does Claude reasoning look like in weight space?"); - let (edges_1, stats_1) = causal_diff_sharded( - "qwen35_27b_base", "qwen35_27b_v1", 11, threshold, - ); + let (edges_1, stats_1) = causal_diff_sharded("qwen35_27b_base", "qwen35_27b_v1", 11, threshold); print_diff_summary("27B: base → v1", &stats_1, edges_1.len()); // ── Diff 2: base 27B → v2 ── eprintln!(); eprintln!("════ Diff 2: 27B base → distilled v2 ════"); eprintln!(" Did v2 refine the same heads or find new ones?"); - let (edges_2, stats_2) = causal_diff_sharded( - "qwen35_27b_base", "qwen35_27b_v2", 11, threshold, - ); + let (edges_2, stats_2) = causal_diff_sharded("qwen35_27b_base", "qwen35_27b_v2", 11, threshold); print_diff_summary("27B: base → v2", &stats_2, edges_2.len()); // ── Diff 3: v1 → v2 ── eprintln!(); eprintln!("════ Diff 3: 27B v1 → v2 (iteration delta) ════"); eprintln!(" Which heads converged vs overcorrected?"); - let (edges_3, stats_3) = causal_diff_sharded( - "qwen35_27b_v1", "qwen35_27b_v2", 11, threshold, - ); + let (edges_3, stats_3) = causal_diff_sharded("qwen35_27b_v1", "qwen35_27b_v2", 11, threshold); print_diff_summary("27B: v1 → v2", &stats_3, edges_3.len()); // ── Diff 4: 9B base → distilled ── eprintln!(); eprintln!("════ Diff 4: 9B base → distilled ════"); eprintln!(" Is the reasoning scaffold scale-invariant?"); - let (edges_4, stats_4) = causal_diff_sharded( - "qwen35_9b_base", "qwen35_9b_dist", 4, threshold, - ); + let (edges_4, stats_4) = causal_diff_sharded("qwen35_9b_base", "qwen35_9b_dist", 4, threshold); print_diff_summary("9B: base → distilled", &stats_4, edges_4.len()); // ── Phase 3: Reasoning scaffold detection ── @@ -2337,10 +2493,7 @@ mod tests { } else { "STABLE" }; - eprintln!( - " {:<12} → f={:.3} c={:.3} ({})", - proj, truth.frequency, truth.confidence, label - ); + eprintln!(" {:<12} → f={:.3} c={:.3} ({})", proj, truth.frequency, truth.confidence, label); } // ── Phase 5: Top shifted heads 
── @@ -2352,10 +2505,7 @@ mod tests { sorted.sort_by(|a, b| b.1 .2.partial_cmp(&a.1 .2).unwrap()); for ((block, proj), (count, max_row, mean_l1)) in sorted.iter().take(20) { - eprintln!( - " Block {:>2} {:>10}: {}/{} shifted, mean_L1={:.0}", - block, proj, count, max_row, mean_l1 - ); + eprintln!(" Block {:>2} {:>10}: {}/{} shifted, mean_L1={:.0}", block, proj, count, max_row, mean_l1); } // ── Phase 6: Write results ── @@ -2371,10 +2521,7 @@ mod tests { report.push_str("| ID | Repo | Shards | Path |\n"); report.push_str("|---|---|---|---|\n"); for m in &MODELS { - report.push_str(&format!( - "| {} | {} | {} | safetensors BF16 |\n", - m.prefix, m.repo, m.shards - )); + report.push_str(&format!("| {} | {} | {} | safetensors BF16 |\n", m.prefix, m.repo, m.shards)); } report.push_str("\n## Diff Summary\n\n"); @@ -2396,18 +2543,9 @@ mod tests { } report.push_str("\n## Reasoning Scaffold\n\n"); - report.push_str(&format!( - "- **Scale-invariant blocks (27B∩9B)**: {:?}\n", - scale_invariant - )); - report.push_str(&format!( - "- **Capacity-dependent (27B only)**: {:?}\n", - capacity_dependent - )); - report.push_str(&format!( - "- **Converged (v1∩v2)**: {:?}\n", - converged - )); + report.push_str(&format!("- **Scale-invariant blocks (27B∩9B)**: {:?}\n", scale_invariant)); + report.push_str(&format!("- **Capacity-dependent (27B only)**: {:?}\n", capacity_dependent)); + report.push_str(&format!("- **Converged (v1∩v2)**: {:?}\n", converged)); report.push_str("\n## NARS Revised Truth Per Projection\n\n"); report.push_str("| Projection | Frequency | Confidence | Interpretation |\n"); @@ -2420,20 +2558,14 @@ mod tests { } else { "STABLE" }; - report.push_str(&format!( - "| {} | {:.3} | {:.3} | {} |\n", - proj, truth.frequency, truth.confidence, label - )); + report.push_str(&format!("| {} | {:.3} | {:.3} | {} |\n", proj, truth.frequency, truth.confidence, label)); } report.push_str("\n## Top 20 Shifted Heads (base→v1)\n\n"); report.push_str("| Block | Projection | Shifted/Total | Mean L1 |\n"); report.push_str("|---|---|---|---|\n"); for ((block, proj), (count, max_row, mean_l1)) in sorted.iter().take(20) { - report.push_str(&format!( - "| {} | {} | {}/{} | {:.0} |\n", - block, proj, count, max_row, mean_l1 - )); + report.push_str(&format!("| {} | {} | {}/{} | {:.0} |\n", block, proj, count, max_row, mean_l1)); } // Write to knowledge base @@ -2454,9 +2586,9 @@ mod tests { let effect = Base17 { dims: [500; 17] }; let candidates = vec![ - Base17 { dims: [400; 17] }, // closest + Base17 { dims: [400; 17] }, // closest Base17 { dims: [200; 17] }, - Base17 { dims: [100; 17] }, // farthest + Base17 { dims: [100; 17] }, // farthest ]; let chain = reverse_trace(&effect, &candidates, 5, 100000); diff --git a/src/hpc/causality.rs b/src/hpc/causality.rs index 874aa08b..cf91153c 100644 --- a/src/hpc/causality.rs +++ b/src/hpc/causality.rs @@ -49,8 +49,8 @@ pub mod qualia_dim { /// The three causality-relevant dimensions: warmth, social, sacredness. pub const CAUSALITY_DIMS: [usize; 3] = [ - qualia_dim::WARMTH, // 4 - qualia_dim::SOCIAL, // 6 + qualia_dim::WARMTH, // 4 + qualia_dim::SOCIAL, // 6 qualia_dim::SACREDNESS, // 8 ]; @@ -169,7 +169,10 @@ impl NarsTruthValue { /// Total ignorance: frequency 0.5, confidence 0. pub fn ignorance() -> Self { - Self { frequency: 0.5, confidence: 0.0 } + Self { + frequency: 0.5, + confidence: 0.0, + } } /// Expectation: `c * (f - 0.5) + 0.5`. 
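// A minimal sketch of the expectation formula documented above, using an
// illustrative (frequency, confidence) pair rather than the crate's
// NarsTruthValue; under total ignorance it collapses to 0.5.
struct Truth {
    frequency: f64,
    confidence: f64,
}

impl Truth {
    fn ignorance() -> Self {
        Truth { frequency: 0.5, confidence: 0.0 }
    }
    // Expectation: c * (f - 0.5) + 0.5
    fn expectation(&self) -> f64 {
        self.confidence * (self.frequency - 0.5) + 0.5
    }
}

#[test]
fn expectation_example() {
    assert_eq!(Truth::ignorance().expectation(), 0.5);
    let t = Truth { frequency: 0.8, confidence: 0.9 };
    assert!((t.expectation() - 0.77).abs() < 1e-12);
}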
@@ -236,9 +239,7 @@ pub struct CausalityDecomposition { /// assert_eq!(dec.sacredness_dir, ndarray::hpc::causality::CausalityDirection::Forward); /// ``` pub fn causality_decompose( - a: &PackedQualia, - b: &PackedQualia, - superposition: Option<&SuperpositionState>, + a: &PackedQualia, b: &PackedQualia, superposition: Option<&SuperpositionState>, ) -> CausalityDecomposition { let warmth_dir = CausalityDirection::from_qualia(a, b, qualia_dim::WARMTH); let social_dir = CausalityDirection::from_qualia(a, b, qualia_dim::SOCIAL); @@ -250,17 +251,10 @@ pub fn causality_decompose( NarsTruthValue::from_awareness(sp.states[qualia_dim::SOCIAL]), NarsTruthValue::from_awareness(sp.states[qualia_dim::SACREDNESS]), ), - _ => ( - NarsTruthValue::ignorance(), - NarsTruthValue::ignorance(), - NarsTruthValue::ignorance(), - ), + _ => (NarsTruthValue::ignorance(), NarsTruthValue::ignorance(), NarsTruthValue::ignorance()), }; - let overall_strength = (warmth_tv.expectation() - + social_tv.expectation() - + sacredness_tv.expectation()) - / 3.0; + let overall_strength = (warmth_tv.expectation() + social_tv.expectation() + sacredness_tv.expectation()) / 3.0; CausalityDecomposition { warmth_dir, diff --git a/src/hpc/clam.rs b/src/hpc/clam.rs index b5adb2b0..4b66f4d5 100644 --- a/src/hpc/clam.rs +++ b/src/hpc/clam.rs @@ -179,23 +179,12 @@ impl ClamTree { } /// Build with explicit config and count. - pub fn build_with_config( - data: &[u8], - vec_len: usize, - count: usize, - config: &BuildConfig, - ) -> Self { + pub fn build_with_config(data: &[u8], vec_len: usize, count: usize, config: &BuildConfig) -> Self { Self::build_with_fn(data, vec_len, count, config, hamming_inline) } /// Build a CLAM tree with a custom distance function. - pub fn build_with_fn( - data: &[u8], - vec_len: usize, - count: usize, - config: &BuildConfig, - dist_fn: DistanceFn, - ) -> Self { + pub fn build_with_fn(data: &[u8], vec_len: usize, count: usize, config: &BuildConfig, dist_fn: DistanceFn) -> Self { assert_eq!(data.len(), vec_len * count); if count == 0 { @@ -212,9 +201,7 @@ impl ClamTree { let mut nodes = Vec::with_capacity(2 * count); let mut rng = SplitMix64::new(0xDEAD_BEEF_CAFE_BABE); - Self::partition( - data, vec_len, &mut indices, 0, count, 0, config, &mut nodes, &mut rng, dist_fn, - ); + Self::partition(data, vec_len, &mut indices, 0, count, 0, config, &mut nodes, &mut rng, dist_fn); let mut num_leaves = 0usize; let mut leaf_radius_sum = 0u64; @@ -242,16 +229,8 @@ impl ClamTree { /// Recursive partition (Algorithm 1 from CAKES). 
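// The flat layout ClamTree assumes throughout (hence the
// `assert_eq!(data.len(), vec_len * count)` above): `count` fixed-width binary
// vectors of `vec_len` bytes concatenated into one buffer, so point i lives at
// data[i * vec_len .. (i + 1) * vec_len]. A tiny illustrative accessor:
fn point(data: &[u8], vec_len: usize, i: usize) -> &[u8] {
    &data[i * vec_len..(i + 1) * vec_len]
}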
#[allow(clippy::too_many_arguments)] fn partition( - data: &[u8], - vec_len: usize, - indices: &mut [usize], - start: usize, - end: usize, - depth: usize, - config: &BuildConfig, - nodes: &mut Vec, - rng: &mut SplitMix64, - dist_fn: DistanceFn, + data: &[u8], vec_len: usize, indices: &mut [usize], start: usize, end: usize, depth: usize, + config: &BuildConfig, nodes: &mut Vec, rng: &mut SplitMix64, dist_fn: DistanceFn, ) -> usize { let n = end - start; let node_idx = nodes.len(); @@ -377,15 +356,12 @@ impl ClamTree { && split < n; if should_split { - let left_idx = Self::partition( - data, vec_len, indices, start, start + split, depth + 1, config, nodes, rng, - dist_fn, - ); + let left_idx = + Self::partition(data, vec_len, indices, start, start + split, depth + 1, config, nodes, rng, dist_fn); nodes[node_idx].left = Some(left_idx); - let right_idx = Self::partition( - data, vec_len, indices, start + split, end, depth + 1, config, nodes, rng, dist_fn, - ); + let right_idx = + Self::partition(data, vec_len, indices, start + split, end, depth + 1, config, nodes, rng, dist_fn); nodes[node_idx].right = Some(right_idx); } @@ -411,19 +387,13 @@ impl ClamTree { } pub fn cluster_points<'a>( - &'a self, - cluster: &Cluster, - data: &'a [u8], - vec_len: usize, + &'a self, cluster: &Cluster, data: &'a [u8], vec_len: usize, ) -> impl Iterator + 'a { let start = cluster.offset; let end = start + cluster.cardinality; - self.reordered[start..end].iter().map(move |&orig_idx| { - ( - orig_idx, - &data[orig_idx * vec_len..(orig_idx + 1) * vec_len], - ) - }) + self.reordered[start..end] + .iter() + .map(move |&orig_idx| (orig_idx, &data[orig_idx * vec_len..(orig_idx + 1) * vec_len])) } pub fn cluster_member_indices(&self, cluster: &Cluster) -> Vec { @@ -503,11 +473,7 @@ impl ClamTree { let went_right = (path_bits >> (15 - bit_pos as u32)) & 1 == 1; - let next = if went_right { - cluster.right - } else { - cluster.left - }; + let next = if went_right { cluster.right } else { cluster.left }; match next { Some(child_idx) => { @@ -522,12 +488,7 @@ impl ClamTree { } /// Compute the CRP distribution for a cluster. - pub fn cluster_crp( - &self, - cluster: &Cluster, - data: &[u8], - vec_len: usize, - ) -> ClusterDistribution { + pub fn cluster_crp(&self, cluster: &Cluster, data: &[u8], vec_len: usize) -> ClusterDistribution { let center = self.center_data(cluster, data, vec_len); let mut distances: Vec = self .cluster_points(cluster, data, vec_len) @@ -630,13 +591,7 @@ pub struct KnnResult { } /// ρ-nearest neighbor search using triangle inequality. -pub fn rho_nn( - tree: &ClamTree, - data: &[u8], - vec_len: usize, - query: &[u8], - rho: u64, -) -> RhoNnResult { +pub fn rho_nn(tree: &ClamTree, data: &[u8], vec_len: usize, query: &[u8], rho: u64) -> RhoNnResult { let mut hits = Vec::new(); let mut distance_calls = 0usize; let mut clusters_pruned = 0usize; @@ -703,13 +658,7 @@ pub fn rho_nn( } /// k-NN via Repeated ρ-NN search (CAKES Algorithm 4). 
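// A hedged sketch of the triangle-inequality prune behind rho_nn: every point p
// in a cluster satisfies d(q, p) >= d(q, center) - radius, so the whole cluster
// can be skipped once that lower bound exceeds rho. `d_query_center` and
// `cluster_radius` mirror the cluster geometry used above; this shows the idea,
// not the crate's exact traversal.
fn can_prune(d_query_center: u64, cluster_radius: u64, rho: u64) -> bool {
    d_query_center > cluster_radius + rho
}

// Conversely, if d(q, center) + radius <= rho, every point in the cluster is a
// hit and can be accepted without any further distance calls.
fn can_accept_all(d_query_center: u64, cluster_radius: u64, rho: u64) -> bool {
    d_query_center + cluster_radius <= rho
}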
-pub fn knn_repeated_rho( - tree: &ClamTree, - data: &[u8], - vec_len: usize, - query: &[u8], - k: usize, -) -> KnnResult { +pub fn knn_repeated_rho(tree: &ClamTree, data: &[u8], vec_len: usize, query: &[u8], k: usize) -> KnnResult { let root = tree.root(); if root.cardinality == 0 { return KnnResult { @@ -766,13 +715,7 @@ pub fn knn_repeated_rho( } } -fn estimate_local_lfd( - tree: &ClamTree, - data: &[u8], - vec_len: usize, - query: &[u8], - rho: u64, -) -> f64 { +fn estimate_local_lfd(tree: &ClamTree, data: &[u8], vec_len: usize, query: &[u8], rho: u64) -> f64 { let mut sum_inv_lfd = 0.0; let mut count = 0usize; @@ -808,13 +751,7 @@ fn estimate_local_lfd( } /// Depth-First Sieve k-NN search (CAKES Algorithm 6). -pub fn knn_dfs_sieve( - tree: &ClamTree, - data: &[u8], - vec_len: usize, - query: &[u8], - k: usize, -) -> KnnResult { +pub fn knn_dfs_sieve(tree: &ClamTree, data: &[u8], vec_len: usize, query: &[u8], k: usize) -> KnnResult { use std::cmp::Reverse; use std::collections::BinaryHeap; @@ -894,15 +831,7 @@ pub fn knn_dfs_sieve( pub fn knn_brute(data: &[u8], vec_len: usize, query: &[u8], k: usize) -> KnnResult { let n = data.len() / vec_len; let mut dists: Vec<(usize, u64)> = (0..n) - .map(|i| { - ( - i, - bitwise::hamming_distance_raw( - query, - &data[i * vec_len..(i + 1) * vec_len], - ), - ) - }) + .map(|i| (i, bitwise::hamming_distance_raw(query, &data[i * vec_len..(i + 1) * vec_len]))) .collect(); dists.sort_unstable_by_key(|&(_, d)| d); dists.truncate(k); @@ -1055,11 +984,7 @@ impl CompressedTree { comp[node_idx] = Some(ClusterCompression { mode, unitary_cost, - recursive_cost: if cluster.is_leaf() { - unitary_cost - } else { - min_cost - }, + recursive_cost: if cluster.is_leaf() { unitary_cost } else { min_cost }, min_cost, }); } @@ -1073,19 +998,10 @@ impl CompressedTree { ]; let mut encoding_centers = vec![0usize; count]; - Self::assign_encodings( - tree, - data, - vec_len, - 0, - &cluster_modes, - &mut encodings, - &mut encoding_centers, - ); + Self::assign_encodings(tree, data, vec_len, 0, &cluster_modes, &mut encodings, &mut encoding_centers); let uncompressed_bytes = count * vec_len; - let compressed_bytes: usize = - encodings.iter().map(|e| e.storage_cost()).sum::() + count * 2; + let compressed_bytes: usize = encodings.iter().map(|e| e.storage_cost()).sum::() + count * 2; let ratio = if compressed_bytes > 0 { uncompressed_bytes as f64 / compressed_bytes as f64 } else { @@ -1122,13 +1038,8 @@ impl CompressedTree { } fn assign_encodings( - tree: &ClamTree, - data: &[u8], - vec_len: usize, - node_idx: usize, - modes: &[CompressionMode], - encodings: &mut [XorDiffEncoding], - encoding_centers: &mut [usize], + tree: &ClamTree, data: &[u8], vec_len: usize, node_idx: usize, modes: &[CompressionMode], + encodings: &mut [XorDiffEncoding], encoding_centers: &mut [usize], ) { let cluster = &tree.nodes[node_idx]; let center = tree.center_data(cluster, data, vec_len); @@ -1143,15 +1054,7 @@ impl CompressedTree { Self::assign_encodings(tree, data, vec_len, left, modes, encodings, encoding_centers); } if let Some(right) = cluster.right { - Self::assign_encodings( - tree, - data, - vec_len, - right, - modes, - encodings, - encoding_centers, - ); + Self::assign_encodings(tree, data, vec_len, right, modes, encodings, encoding_centers); } } } @@ -1164,12 +1067,7 @@ impl CompressedTree { /// Compute Hamming distance from query to compressed point WITHOUT decompression. 
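// A hedged sketch of the compressive-distance idea behind hamming_to_compressed:
// start from d(query, center) and patch only the bytes where the point differs
// from its cluster center, so the cost is O(num_diffs) rather than O(vec_len).
// The `positions`/`values` fields are illustrative stand-ins, not the crate's
// actual XorDiffEncoding layout.
struct XorDiff {
    positions: Vec<usize>, // byte offsets where point != center
    values: Vec<u8>,       // the point's bytes at those offsets
}

fn hamming_from_query(query: &[u8], center: &[u8], d_query_center: u64, diff: &XorDiff) -> u64 {
    let mut d = d_query_center;
    for (&i, &v) in diff.positions.iter().zip(&diff.values) {
        d -= (query[i] ^ center[i]).count_ones() as u64; // drop the center byte's contribution
        d += (query[i] ^ v).count_ones() as u64;         // add the true point byte's contribution
    }
    d
}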
pub fn hamming_to_compressed( - &self, - query: &[u8], - point_idx: usize, - data: &[u8], - vec_len: usize, - dist_cache: &mut DistanceCache, + &self, query: &[u8], point_idx: usize, data: &[u8], vec_len: usize, dist_cache: &mut DistanceCache, dist_fn: DistanceFn, ) -> u64 { let center_idx = self.encoding_centers[point_idx]; @@ -1258,13 +1156,7 @@ impl ClamTree { /// Unlike the standalone `rho_nn` function, this method directly returns /// a simple `Vec<(usize, u64)>` suitable for piping into cascade verification /// via `clam_cascade_search`. - pub fn rho_nn_candidates( - &self, - data: &[u8], - vec_len: usize, - query: &[u8], - rho: u64, - ) -> Vec<(usize, u64)> { + pub fn rho_nn_candidates(&self, data: &[u8], vec_len: usize, query: &[u8], rho: u64) -> Vec<(usize, u64)> { if self.nodes.is_empty() { return Vec::new(); } @@ -1315,13 +1207,7 @@ impl ClamTree { /// /// Returns only non-Reject hits, limited to `top_k`, sorted by Hamming distance. pub fn clam_cascade_search( - tree: &ClamTree, - cascade: &Cascade, - data: &[u8], - vec_len: usize, - query: &[u8], - rho: u64, - top_k: usize, + tree: &ClamTree, cascade: &Cascade, data: &[u8], vec_len: usize, query: &[u8], rho: u64, top_k: usize, ) -> Vec { let candidates = tree.rho_nn_candidates(data, vec_len, query, rho); @@ -1369,12 +1255,7 @@ impl ClamTree { /// The cascade's stroke-1 (partial prefix scan) becomes partially redundant /// since CLAM provides geometrically tight candidates via triangle inequality. pub fn rho_nn_cascade( - &self, - data: &[u8], - vec_len: usize, - query: &[u8], - rho: u64, - cascade: &Cascade, + &self, data: &[u8], vec_len: usize, query: &[u8], rho: u64, cascade: &Cascade, ) -> ClamCascadeResult { // Phase 1a: CLAM rho-NN search let rho_result = rho_nn(self, data, vec_len, query, rho); @@ -1398,12 +1279,8 @@ impl ClamTree { // Phase 2: SPO Distance Harvest // ═══════════════════════════════════════════════════════════════════ -use super::causality::{ - causality_decompose, CausalityDecomposition, NarsTruthValue, -}; -use super::bf16_truth::{ - AwarenessState, PackedQualia, -}; +use super::bf16_truth::{AwarenessState, PackedQualia}; +use super::causality::{causality_decompose, CausalityDecomposition, NarsTruthValue}; use super::node::{Node, SPO, S__, _P_, __O}; use super::plane::Distance as PlaneDistance; @@ -1435,10 +1312,7 @@ pub struct CausalHit { /// 3. Uses `CausalityDecomposition` to extract directional relationships /// 4. Returns enriched results with causal metadata pub fn spo_distance_harvest( - hits: &[RankedHit], - query_node: &mut Node, - hit_nodes: &mut [Node], - query_qualia: &PackedQualia, + hits: &[RankedHit], query_node: &mut Node, hit_nodes: &mut [Node], query_qualia: &PackedQualia, hit_qualias: &[PackedQualia], ) -> Vec { let mut results = Vec::with_capacity(hits.len()); @@ -1485,7 +1359,9 @@ pub fn spo_distance_harvest( // Use the full SPO distance to derive awareness-based truth let d_spo = query_node.distance(hit_node, SPO); let truth = match d_spo { - PlaneDistance::Measured { disagreement, overlap, .. } => { + PlaneDistance::Measured { + disagreement, overlap, .. + } => { if overlap == 0 { NarsTruthValue::ignorance() } else { @@ -1537,11 +1413,7 @@ impl ClamTree { /// Wraps `CompressedTree::compress()` — each point is XOR-diff encoded /// relative to its nearest cluster center, yielding a compact representation /// suitable for compressive search. 
- pub fn compress_database( - &self, - data: &[u8], - vec_len: usize, - ) -> CompressedTree { + pub fn compress_database(&self, data: &[u8], vec_len: usize) -> CompressedTree { let count = data.len() / vec_len; CompressedTree::compress(self, data, vec_len, count) } @@ -1552,12 +1424,7 @@ impl ClamTree { /// distances from the encoding diffs without full decompression. /// Cost per point: O(num_diffs) instead of O(vec_len). pub fn query_compressed( - &self, - compressed: &CompressedTree, - data: &[u8], - vec_len: usize, - query: &[u8], - rho: u64, + &self, compressed: &CompressedTree, data: &[u8], vec_len: usize, query: &[u8], rho: u64, ) -> CompressedSearchResult { let count = data.len() / vec_len; let mut cache = DistanceCache::new(); @@ -1604,9 +1471,7 @@ impl ClamTree { if orig_idx >= count { continue; } - let d = compressed.hamming_to_compressed( - query, orig_idx, data, vec_len, &mut cache, dist_fn, - ); + let d = compressed.hamming_to_compressed(query, orig_idx, data, vec_len, &mut cache, dist_fn); distance_calls += 1; if d <= rho { hits.push((orig_idx, d)); @@ -1705,12 +1570,7 @@ impl ClamTree { /// /// Threshold is in [0, 1]. A threshold of 0.75 flags the top ~25% most /// anomalous points (those in high-LFD leaf clusters). - pub fn flag_anomalies( - &self, - data: &[u8], - vec_len: usize, - threshold: f64, - ) -> Vec { + pub fn flag_anomalies(&self, data: &[u8], vec_len: usize, threshold: f64) -> Vec { self.anomaly_scores(data, vec_len) .into_iter() .filter(|a| a.score >= threshold) @@ -1743,10 +1603,7 @@ mod tests { } fn make_clustered_data( - num_clusters: usize, - points_per_cluster: usize, - vec_len: usize, - noise_bytes: usize, + num_clusters: usize, points_per_cluster: usize, vec_len: usize, noise_bytes: usize, ) -> Vec { let count = num_clusters * points_per_cluster; let mut data = vec![0u8; count * vec_len]; @@ -1771,13 +1628,7 @@ mod tests { data } - fn linear_knn( - data: &[u8], - vec_len: usize, - count: usize, - query: &[u8], - k: usize, - ) -> Vec<(usize, u64)> { + fn linear_knn(data: &[u8], vec_len: usize, count: usize, query: &[u8], k: usize) -> Vec<(usize, u64)> { let mut dists: Vec<(usize, u64)> = (0..count) .map(|i| { let point = &data[i * vec_len..(i + 1) * vec_len]; @@ -1967,11 +1818,7 @@ mod tests { .filter(|&(_, d)| d <= rho) .collect(); - assert_eq!( - result.hits.len(), - ground_truth.len(), - "ρ-NN should have perfect recall" - ); + assert_eq!(result.hits.len(), ground_truth.len(), "ρ-NN should have perfect recall"); } #[test] @@ -2163,10 +2010,7 @@ mod tests { let dist_q_point_exact = hamming_inline(&query, &point); let dist_q_point_compressed = enc.hamming_from_query(&query, ¢er, dist_q_center); - assert_eq!( - dist_q_point_compressed, dist_q_point_exact, - "Compressive Hamming should match exact" - ); + assert_eq!(dist_q_point_compressed, dist_q_point_exact, "Compressive Hamming should match exact"); } #[test] @@ -2243,14 +2087,7 @@ mod tests { for i in 0..count { let exact = hamming_inline(query, &data[i * vec_len..(i + 1) * vec_len]); - let comp = compressed.hamming_to_compressed( - query, - i, - &data, - vec_len, - &mut cache, - hamming_inline, - ); + let comp = compressed.hamming_to_compressed(query, i, &data, vec_len, &mut cache, hamming_inline); assert_eq!(comp, exact, "Compressive distance mismatch at point {}", i); } } @@ -2284,10 +2121,7 @@ mod tests { let result = tree.rho_nn_cascade(&data, vec_len, query, rho, &cascade); // Should find the exact match at index 0 - assert!( - result.hits.iter().any(|r| r.index == 0 && r.hamming == 0), - 
"Should find exact match at index 0" - ); + assert!(result.hits.iter().any(|r| r.index == 0 && r.hamming == 0), "Should find exact match at index 0"); assert!(result.clam_candidates > 0); assert!(result.cascade_survivors > 0); assert!(result.cascade_survivors <= result.clam_candidates); @@ -2355,8 +2189,18 @@ mod tests { fn test_spo_distance_harvest_basic() { // Create some ranked hits let hits = vec![ - RankedHit { index: 0, hamming: 10, precise: f64::NAN, band: Band::Foveal }, - RankedHit { index: 1, hamming: 50, precise: f64::NAN, band: Band::Near }, + RankedHit { + index: 0, + hamming: 10, + precise: f64::NAN, + band: Band::Foveal, + }, + RankedHit { + index: 1, + hamming: 50, + precise: f64::NAN, + band: Band::Near, + }, ]; let mut query_node = Node::random(42); @@ -2366,13 +2210,7 @@ mod tests { hit_q0.resonance[4] = 10; // warmth forward let hit_qualias = vec![hit_q0, PackedQualia::zero()]; - let results = spo_distance_harvest( - &hits, - &mut query_node, - &mut hit_nodes, - &query_qualia, - &hit_qualias, - ); + let results = spo_distance_harvest(&hits, &mut query_node, &mut hit_nodes, &query_qualia, &hit_qualias); assert_eq!(results.len(), 2); // First hit should have S/P/O distances (random nodes have encounters) @@ -2390,22 +2228,19 @@ mod tests { #[test] fn test_spo_distance_harvest_out_of_bounds() { // Hit index beyond available nodes - let hits = vec![ - RankedHit { index: 99, hamming: 10, precise: f64::NAN, band: Band::Good }, - ]; + let hits = vec![RankedHit { + index: 99, + hamming: 10, + precise: f64::NAN, + band: Band::Good, + }]; let mut query_node = Node::random(42); let mut hit_nodes = vec![Node::random(100)]; // only 1 node let query_qualia = PackedQualia::zero(); let hit_qualias = vec![PackedQualia::zero()]; - let results = spo_distance_harvest( - &hits, - &mut query_node, - &mut hit_nodes, - &query_qualia, - &hit_qualias, - ); + let results = spo_distance_harvest(&hits, &mut query_node, &mut hit_nodes, &query_qualia, &hit_qualias); assert_eq!(results.len(), 1); // Out-of-bounds index should use defaults @@ -2416,22 +2251,19 @@ mod tests { #[test] fn test_spo_distance_harvest_truth_values() { - let hits = vec![ - RankedHit { index: 0, hamming: 5, precise: f64::NAN, band: Band::Foveal }, - ]; + let hits = vec![RankedHit { + index: 0, + hamming: 5, + precise: f64::NAN, + band: Band::Foveal, + }]; let mut query_node = Node::random(1); let mut hit_nodes = vec![Node::random(2)]; let query_qualia = PackedQualia::zero(); let hit_qualias = vec![PackedQualia::zero()]; - let results = spo_distance_harvest( - &hits, - &mut query_node, - &mut hit_nodes, - &query_qualia, - &hit_qualias, - ); + let results = spo_distance_harvest(&hits, &mut query_node, &mut hit_nodes, &query_qualia, &hit_qualias); assert_eq!(results.len(), 1); // Truth value should be derived from SPO distance @@ -2442,22 +2274,19 @@ mod tests { #[test] fn test_spo_distance_harvest_preserves_hit_info() { - let hits = vec![ - RankedHit { index: 0, hamming: 42, precise: f64::NAN, band: Band::Near }, - ]; + let hits = vec![RankedHit { + index: 0, + hamming: 42, + precise: f64::NAN, + band: Band::Near, + }]; let mut query_node = Node::random(10); let mut hit_nodes = vec![Node::random(20)]; let query_qualia = PackedQualia::zero(); let hit_qualias = vec![PackedQualia::zero()]; - let results = spo_distance_harvest( - &hits, - &mut query_node, - &mut hit_nodes, - &query_qualia, - &hit_qualias, - ); + let results = spo_distance_harvest(&hits, &mut query_node, &mut hit_nodes, &query_qualia, &hit_qualias); 
assert_eq!(results[0].index, 0); assert_eq!(results[0].hamming, 42); @@ -2512,10 +2341,7 @@ mod tests { let result = tree.query_compressed(&compressed, &data, vec_len, query, rho); // Should find exact match at index 0 - assert!( - result.hits.iter().any(|&(idx, d)| idx == 0 && d == 0), - "Should find exact match at index 0" - ); + assert!(result.hits.iter().any(|&(idx, d)| idx == 0 && d == 0), "Should find exact match at index 0"); assert!(result.distance_calls > 0); } @@ -2538,13 +2364,9 @@ mod tests { // Compressed search should find the same set of points let compressed_indices: std::collections::HashSet = compressed_result.hits.iter().map(|&(idx, _)| idx).collect(); - let exact_indices: std::collections::HashSet = - exact_result.hits.iter().map(|&(idx, _)| idx).collect(); + let exact_indices: std::collections::HashSet = exact_result.hits.iter().map(|&(idx, _)| idx).collect(); - assert_eq!( - compressed_indices, exact_indices, - "Compressed search should find same results as exact search" - ); + assert_eq!(compressed_indices, exact_indices, "Compressed search should find same results as exact search"); } #[test] @@ -2579,8 +2401,7 @@ mod tests { assert_eq!(scores.len(), count); for score in &scores { - assert!(score.score >= 0.0 && score.score <= 1.0, - "Score {} out of range [0,1]", score.score); + assert!(score.score >= 0.0 && score.score <= 1.0, "Score {} out of range [0,1]", score.score); assert!(score.index < count); } } @@ -2604,10 +2425,7 @@ mod tests { } else { AwarenessState::Noise }; - assert_eq!( - score.awareness, expected_awareness, - "Awareness mismatch for score={}", score.score - ); + assert_eq!(score.awareness, expected_awareness, "Awareness mismatch for score={}", score.score); } } @@ -2623,11 +2441,7 @@ mod tests { // All flagged should have score >= 0.75 for anomaly in &flagged { - assert!( - anomaly.score >= 0.75, - "Flagged anomaly has score {} < 0.75", - anomaly.score - ); + assert!(anomaly.score >= 0.75, "Flagged anomaly has score {} < 0.75", anomaly.score); } // Count manually from all_scores @@ -2665,7 +2479,7 @@ mod tests { let vec_len = 64; // Create 3 tight clusters plus some noisy outliers let mut data = make_clustered_data(3, 20, vec_len, 2); // tight clusters - // Add 5 random outliers + // Add 5 random outliers let mut rng = SplitMix64::new(999); for _ in 0..5 { let mut outlier = vec![0u8; vec_len]; @@ -2742,9 +2556,7 @@ mod tests { let tree = ClamTree::build(&data, vec_len, 3); let cascade = Cascade::from_threshold(vec_len as u64 * 4, vec_len); let query = &data[0..vec_len]; - let hits = clam_cascade_search( - &tree, &cascade, &data, vec_len, query, vec_len as u64 * 8, 10, - ); + let hits = clam_cascade_search(&tree, &cascade, &data, vec_len, query, vec_len as u64 * 8, 10); assert!(!hits.is_empty()); // No Reject band hits should survive for h in &hits { @@ -2760,9 +2572,7 @@ mod tests { let tree = ClamTree::build(&data, vec_len, 3); let cascade = Cascade::from_threshold(vec_len as u64 * 8, vec_len); let query = &data[0..vec_len]; - let hits = clam_cascade_search( - &tree, &cascade, &data, vec_len, query, u64::MAX, 5, - ); + let hits = clam_cascade_search(&tree, &cascade, &data, vec_len, query, u64::MAX, 5); assert!(hits.len() <= 5); } } diff --git a/src/hpc/clam_compress.rs b/src/hpc/clam_compress.rs index 20feed52..5fae8f48 100644 --- a/src/hpc/clam_compress.rs +++ b/src/hpc/clam_compress.rs @@ -254,11 +254,7 @@ impl CompressedTree { comp[node_idx] = Some(ClusterCompression { mode, unitary_cost, - recursive_cost: if cluster.is_leaf() { - unitary_cost 
- } else { - min_cost - }, + recursive_cost: if cluster.is_leaf() { unitary_cost } else { min_cost }, min_cost, }); } @@ -276,19 +272,13 @@ impl CompressedTree { let mut encoding_centers = vec![0usize; count]; Self::assign_encodings( - tree, - data, - vec_len, - 0, // root - &cluster_modes, - &mut encodings, - &mut encoding_centers, + tree, data, vec_len, 0, // root + &cluster_modes, &mut encodings, &mut encoding_centers, ); // Compute stats let uncompressed_bytes = count * vec_len; - let compressed_bytes: usize = - encodings.iter().map(|e| e.storage_cost()).sum::() + count * 2; // 2 bytes per point for center reference overhead + let compressed_bytes: usize = encodings.iter().map(|e| e.storage_cost()).sum::() + count * 2; // 2 bytes per point for center reference overhead let ratio = if compressed_bytes > 0 { uncompressed_bytes as f64 / compressed_bytes as f64 } else { @@ -326,13 +316,8 @@ impl CompressedTree { /// Recursively assign encodings to points. fn assign_encodings( - tree: &ClamTree, - data: &[u8], - vec_len: usize, - node_idx: usize, - modes: &[CompressionMode], - encodings: &mut [XorDiffEncoding], - encoding_centers: &mut [usize], + tree: &ClamTree, data: &[u8], vec_len: usize, node_idx: usize, modes: &[CompressionMode], + encodings: &mut [XorDiffEncoding], encoding_centers: &mut [usize], ) { let cluster = &tree.nodes[node_idx]; let center = tree.center_data(cluster, data, vec_len); @@ -346,26 +331,10 @@ impl CompressedTree { } else { // Recursive: delegate to children if let Some(left) = cluster.left { - Self::assign_encodings( - tree, - data, - vec_len, - left, - modes, - encodings, - encoding_centers, - ); + Self::assign_encodings(tree, data, vec_len, left, modes, encodings, encoding_centers); } if let Some(right) = cluster.right { - Self::assign_encodings( - tree, - data, - vec_len, - right, - modes, - encodings, - encoding_centers, - ); + Self::assign_encodings(tree, data, vec_len, right, modes, encodings, encoding_centers); } } } @@ -389,12 +358,7 @@ impl CompressedTree { /// /// Cost: O(num_diffs) per point instead of O(vec_len). pub fn hamming_to_compressed( - &self, - query: &[u8], - point_idx: usize, - data: &[u8], - vec_len: usize, - dist_cache: &mut DistanceCache, + &self, query: &[u8], point_idx: usize, data: &[u8], vec_len: usize, dist_cache: &mut DistanceCache, dist_fn: fn(&[u8], &[u8]) -> u64, ) -> u64 { let center_idx = self.encoding_centers[point_idx]; @@ -474,9 +438,9 @@ fn postorder_indices(tree: &ClamTree) -> Vec { #[cfg(test)] mod tests { - use super::*; - use super::super::clam::ClamTree; use super::super::bitwise; + use super::super::clam::ClamTree; + use super::*; /// Simple SplitMix64 RNG for deterministic test data generation. struct SplitMix64(u64); @@ -505,10 +469,7 @@ mod tests { /// Make clustered data: groups of similar vectors. 
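// The test helpers above lean on a SplitMix64 RNG whose body is elided in this
// hunk; for reference, the standard SplitMix64 step looks like this. The crate's
// own impl may differ — this is only the textbook reference version.
struct SplitMix64(u64);

impl SplitMix64 {
    fn new(seed: u64) -> Self {
        SplitMix64(seed)
    }
    fn next(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }
}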
fn make_clustered_data( - num_clusters: usize, - points_per_cluster: usize, - vec_len: usize, - noise_bytes: usize, + num_clusters: usize, points_per_cluster: usize, vec_len: usize, noise_bytes: usize, ) -> Vec { let count = num_clusters * points_per_cluster; let mut data = vec![0u8; count * vec_len]; @@ -576,10 +537,7 @@ mod tests { let dist_q_point_compressed = enc.hamming_from_query(&query, ¢er, dist_q_center); - assert_eq!( - dist_q_point_compressed, dist_q_point_exact, - "Compressive Hamming should match exact Hamming" - ); + assert_eq!(dist_q_point_compressed, dist_q_point_exact, "Compressive Hamming should match exact Hamming"); } #[test] @@ -594,9 +552,7 @@ mod tests { println!( "Random data compression: {:.2}x ({} -> {} bytes)", - compressed.stats.ratio, - compressed.stats.uncompressed_bytes, - compressed.stats.compressed_bytes + compressed.stats.ratio, compressed.stats.uncompressed_bytes, compressed.stats.compressed_bytes ); // Random data: compression ratio may be < 1 (expansion) @@ -606,11 +562,7 @@ mod tests { for i in 0..count { let decompressed = compressed.decompress_point(i, &data, vec_len); let original = &data[i * vec_len..(i + 1) * vec_len]; - assert_eq!( - &decompressed, original, - "Decompressed point {} should match original", - i - ); + assert_eq!(&decompressed, original, "Decompressed point {} should match original", i); } } @@ -630,9 +582,7 @@ mod tests { println!( "Clustered data compression: {:.2}x ({} -> {} bytes)", - compressed.stats.ratio, - compressed.stats.uncompressed_bytes, - compressed.stats.compressed_bytes + compressed.stats.ratio, compressed.stats.uncompressed_bytes, compressed.stats.compressed_bytes ); // Clustered data with low noise should compress well @@ -668,19 +618,9 @@ mod tests { // Compare compressive search distances to exact distances for i in 0..count { let exact = bitwise::hamming_distance_raw(query, &data[i * vec_len..(i + 1) * vec_len]); - let comp = compressed.hamming_to_compressed( - query, - i, - &data, - vec_len, - &mut cache, - bitwise::hamming_distance_raw, - ); - assert_eq!( - comp, exact, - "Compressive Hamming for point {} should match exact ({} vs {})", - i, comp, exact - ); + let comp = + compressed.hamming_to_compressed(query, i, &data, vec_len, &mut cache, bitwise::hamming_distance_raw); + assert_eq!(comp, exact, "Compressive Hamming for point {} should match exact ({} vs {})", i, comp, exact); } } diff --git a/src/hpc/clam_search.rs b/src/hpc/clam_search.rs index 81c96ac9..6316b055 100644 --- a/src/hpc/clam_search.rs +++ b/src/hpc/clam_search.rs @@ -178,13 +178,7 @@ pub fn rho_nn(tree: &ClamTree, data: &[u8], vec_len: usize, query: &[u8], rho: u /// /// This is NOT the fastest CAKES algorithm (Depth-First Sieve is), /// but it's the simplest and demonstrates the LFD-guided radius ratchet. -pub fn knn_repeated_rho( - tree: &ClamTree, - data: &[u8], - vec_len: usize, - query: &[u8], - k: usize, -) -> KnnResult { +pub fn knn_repeated_rho(tree: &ClamTree, data: &[u8], vec_len: usize, query: &[u8], k: usize) -> KnnResult { let root = tree.root(); if root.cardinality == 0 { return KnnResult { @@ -304,13 +298,7 @@ fn estimate_local_lfd(tree: &ClamTree, data: &[u8], vec_len: usize, query: &[u8] /// /// where d = LFD, N = metric entropy, |C_bar| = mean leaf cardinality. /// This is sublinear in n when LFD << embedding dimension. 
-pub fn knn_dfs_sieve( - tree: &ClamTree, - data: &[u8], - vec_len: usize, - query: &[u8], - k: usize, -) -> KnnResult { +pub fn knn_dfs_sieve(tree: &ClamTree, data: &[u8], vec_len: usize, query: &[u8], k: usize) -> KnnResult { let mut distance_calls = 0usize; let mut clusters_pruned = 0usize; @@ -400,9 +388,9 @@ pub fn knn_dfs_sieve( #[cfg(test)] mod tests { - use super::*; - use super::super::clam::ClamTree; use super::super::bitwise; + use super::super::clam::ClamTree; + use super::*; /// Simple SplitMix64 RNG for deterministic test data generation. struct SplitMix64(u64); @@ -430,13 +418,7 @@ mod tests { } /// Linear scan for ground truth. - fn linear_knn( - data: &[u8], - vec_len: usize, - count: usize, - query: &[u8], - k: usize, - ) -> Vec<(usize, u64)> { + fn linear_knn(data: &[u8], vec_len: usize, count: usize, query: &[u8], k: usize) -> Vec<(usize, u64)> { let mut dists: Vec<(usize, u64)> = (0..count) .map(|i| { let point = &data[i * vec_len..(i + 1) * vec_len]; @@ -495,11 +477,7 @@ mod tests { .collect(); // Hamming is a metric -> exact recall - assert_eq!( - result.hits.len(), - ground_truth.len(), - "rho-NN should have perfect recall for metric distances" - ); + assert_eq!(result.hits.len(), ground_truth.len(), "rho-NN should have perfect recall for metric distances"); } #[test] @@ -521,10 +499,7 @@ mod tests { // Check exact recall: our k-th hit should match linear scan's k-th hit distance let our_max_dist = result.hits.last().unwrap().1; let gt_max_dist = ground_truth.last().unwrap().1; - assert_eq!( - our_max_dist, gt_max_dist, - "k-NN should find exact same max distance as linear scan" - ); + assert_eq!(our_max_dist, gt_max_dist, "k-NN should find exact same max distance as linear scan"); println!( "Repeated rho-NN: {} distance calls, {} pruned (vs {} linear)", @@ -551,10 +526,7 @@ mod tests { // Verify exact recall let our_max_dist = result.hits.last().unwrap().1; let gt_max_dist = ground_truth.last().unwrap().1; - assert_eq!( - our_max_dist, gt_max_dist, - "DFS Sieve should find exact same max distance as linear scan" - ); + assert_eq!(our_max_dist, gt_max_dist, "DFS Sieve should find exact same max distance as linear scan"); println!( "DFS Sieve: {} distance calls, {} pruned (vs {} linear)", @@ -583,10 +555,7 @@ mod tests { ); // With random data, speedup may be modest, but should prune something - assert!( - result.clusters_pruned > 0, - "should prune at least some clusters" - ); + assert!(result.clusters_pruned > 0, "should prune at least some clusters"); } #[test] diff --git a/src/hpc/cogrecord.rs b/src/hpc/cogrecord.rs index 26034960..b3b02ecc 100644 --- a/src/hpc/cogrecord.rs +++ b/src/hpc/cogrecord.rs @@ -3,8 +3,8 @@ //! Each container is an `Array` of 16384 bytes, queryable via //! Hamming distance (VPOPCNTDQ) or int8 dot product. -use crate::imp_prelude::*; use super::bitwise::BitwiseOps; +use crate::imp_prelude::*; /// Size of each container in bytes (16384 = 131072 bits). pub const CONTAINER_BYTES: usize = 16384; @@ -72,13 +72,13 @@ impl Default for CogRecord { impl CogRecord { /// Create a new CogRecord from 4 containers. - pub fn new( - meta: Array, - cam: Array, - btree: Array, - embed: Array, - ) -> Self { - Self { meta, cam, btree, embed } + pub fn new(meta: Array, cam: Array, btree: Array, embed: Array) -> Self { + Self { + meta, + cam, + btree, + embed, + } } /// Create a zero-initialized CogRecord. @@ -172,19 +172,14 @@ impl CogRecord { /// Batch sweep across multiple CogRecords. /// /// Returns `SweepResult` for all candidates that pass the threshold. 
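// A hedged sketch of the per-container sweep that sweep_cogrecords relies on:
// compute a Hamming distance for each of the four containers and reject the
// candidate as soon as one exceeds its threshold. Plain byte slices stand in
// for the crate's Array containers, and `hamming` for the VPOPCNTDQ-backed
// distance; only the early-reject shape is illustrated here.
fn hamming(a: &[u8], b: &[u8]) -> u64 {
    a.iter().zip(b).map(|(x, y)| (x ^ y).count_ones() as u64).sum()
}

fn sweep_adaptive(query: &[Vec<u8>; 4], cand: &[Vec<u8>; 4], thresholds: &[u64; 4]) -> Option<[u64; 4]> {
    let mut distances = [0u64; 4];
    for i in 0..4 {
        let d = hamming(&query[i], &cand[i]);
        if d > thresholds[i] {
            return None; // early reject: this candidate cannot pass the sweep
        }
        distances[i] = d;
    }
    Some(distances)
}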
-pub fn sweep_cogrecords( - query: &CogRecord, - candidates: &[CogRecord], - thresholds: &[u64; 4], -) -> Vec { +pub fn sweep_cogrecords(query: &CogRecord, candidates: &[CogRecord], thresholds: &[u64; 4]) -> Vec { candidates .iter() .enumerate() .filter_map(|(i, candidate)| { - query.sweep_adaptive(candidate, thresholds).map(|distances| SweepResult { - index: i, - distances, - }) + query + .sweep_adaptive(candidate, thresholds) + .map(|distances| SweepResult { index: i, distances }) }) .collect() } diff --git a/src/hpc/compression_curves.rs b/src/hpc/compression_curves.rs index 8dbaf745..0c7bcb6c 100644 --- a/src/hpc/compression_curves.rs +++ b/src/hpc/compression_curves.rs @@ -19,7 +19,9 @@ // ============================================================================ fn prng(state: &mut u64) -> u64 { - *state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407); + *state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); *state } @@ -65,7 +67,11 @@ fn cosine_sim(a: &[f64], b: &[f64]) -> f64 { } fn l2_dist(a: &[f64], b: &[f64]) -> f64 { - a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum::().sqrt() + a.iter() + .zip(b) + .map(|(x, y)| (x - y).powi(2)) + .sum::() + .sqrt() } // ============================================================================ @@ -137,7 +143,11 @@ impl RandomProjection { let mut s = seed_state(seed); let scale = 1.0 / (target_dim as f64).sqrt(); let matrix = (0..target_dim) - .map(|_| (0..source_dim).map(|_| prng_normal(&mut s) * scale).collect()) + .map(|_| { + (0..source_dim) + .map(|_| prng_normal(&mut s) * scale) + .collect() + }) .collect(); RandomProjection { matrix } } @@ -161,8 +171,8 @@ impl RandomProjection { /// Simplified PQ: divide vector into M subvectors, each quantized to k_bits. 
struct ProductQuantizer { - m: usize, // number of sub-vectors - k_bits: usize, // bits per sub-quantizer (2^k_bits centroids) + m: usize, // number of sub-vectors + k_bits: usize, // bits per sub-quantizer (2^k_bits centroids) sub_dim: usize, codebooks: Vec>>, // M × 2^k_bits × sub_dim } @@ -219,7 +229,12 @@ impl ProductQuantizer { }) .collect(); - ProductQuantizer { m, k_bits, sub_dim, codebooks } + ProductQuantizer { + m, + k_bits, + sub_dim, + codebooks, + } } fn encode(&self, v: &[f64]) -> Vec { @@ -261,7 +276,11 @@ const PHI: f64 = std::f64::consts::GOLDEN_RATIO; const fn golden_shift(d: usize) -> usize { let raw = (d as f64 / (PHI * PHI)) as usize; - if raw % 2 == 0 { raw + 1 } else { raw } + if raw % 2 == 0 { + raw + 1 + } else { + raw + } } fn cyclic_shift_dyn(bits: &[u64], shift: usize) -> Vec { @@ -335,7 +354,9 @@ fn to_ranks(v: &[f64]) -> Vec { fn spearman(a: &[f64], b: &[f64]) -> f64 { let n = a.len(); - if n < 2 { return 0.0; } + if n < 2 { + return 0.0; + } let ra = to_ranks(a); let rb = to_ranks(b); let ma: f64 = ra.iter().sum::() / n as f64; @@ -350,7 +371,9 @@ fn spearman(a: &[f64], b: &[f64]) -> f64 { va += da * da; vb += db * db; } - if va < 1e-10 || vb < 1e-10 { return 0.0; } + if va < 1e-10 || vb < 1e-10 { + return 0.0; + } cov / (va.sqrt() * vb.sqrt()) } @@ -361,11 +384,11 @@ fn spearman(a: &[f64], b: &[f64]) -> f64 { struct CompressionResult { method: &'static str, bits_per_vector: usize, - source_bits: usize, // original f64 vector size in bits + source_bits: usize, // original f64 vector size in bits recall_at_10: f64, spearman_rho: f64, cluster_purity: f64, - compression_ratio: f64, // source_bits / bits_per_vector + compression_ratio: f64, // source_bits / bits_per_vector } impl CompressionResult { @@ -378,12 +401,7 @@ impl CompressionResult { } /// Run a full benchmark at a given dimension. 
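// Worked example of the compression_ratio bookkeeping above, assuming the source
// vector is `dim` f64 components (64 bits each) exactly as benchmark_at_dim
// computes source_bits: a 768-d float vector is 768 * 64 = 49_152 bits, so a
// 256-bit SimHash code compresses it by 49_152 / 256 = 192x.
fn compression_ratio(dim: usize, bits_per_vector: usize) -> f64 {
    (dim * 64) as f64 / bits_per_vector as f64
}

#[test]
fn compression_ratio_example() {
    assert_eq!(compression_ratio(768, 256), 192.0);
}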
-fn benchmark_at_dim( - dim: usize, - n_vectors: usize, - n_clusters: usize, - cluster_spread: f64, -) -> Vec { +fn benchmark_at_dim(dim: usize, n_vectors: usize, n_clusters: usize, cluster_spread: f64) -> Vec { let source_bits = dim * 64; // f64 per dimension // ── Generate clustered data ────────────────────────────────── @@ -414,14 +432,18 @@ fn benchmark_at_dim( let mut gt_ranked: Vec<(usize, f64)> = gt_dists.iter().copied().enumerate().collect(); gt_ranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); - let _gt_top10: std::collections::HashSet = - gt_ranked[..10.min(gt_ranked.len())].iter().map(|&(i, _)| i).collect(); + let _gt_top10: std::collections::HashSet = gt_ranked[..10.min(gt_ranked.len())] + .iter() + .map(|&(i, _)| i) + .collect(); let mut results = Vec::new(); // ── Method 1: SimHash at various bit widths ────────────────── for &n_bits in &[64, 128, 256, 512, 1024, 2048] { - if n_bits > dim * 4 { continue; } // skip if oversized + if n_bits > dim * 4 { + continue; + } // skip if oversized let projector = SimHashProjector::new(dim, n_bits, 12345); let hashes: Vec> = vectors.iter().map(|v| projector.hash(v)).collect(); @@ -469,11 +491,11 @@ fn benchmark_at_dim( // ── Method 3: Random Projection + Binarize ─────────────────── for &target_bits in &[64, 128, 256, 512, 1024] { - if target_bits > dim * 2 { continue; } + if target_bits > dim * 2 { + continue; + } let rp = RandomProjection::new(dim, target_bits, 54321); - let projected: Vec> = vectors.iter() - .map(|v| rp.project_and_binarize(v)) - .collect(); + let projected: Vec> = vectors.iter().map(|v| rp.project_and_binarize(v)).collect(); let dists: Vec = (1..n_vectors) .map(|i| hamming_bytes(&projected[qi], &projected[i]) as f64) @@ -496,8 +518,12 @@ fn benchmark_at_dim( // ── Method 4: Product Quantization ─────────────────────────── for &(m, k_bits) in &[(8, 4), (8, 8), (16, 4), (16, 8), (32, 4), (32, 8)] { - if m > dim { continue; } - if dim % m != 0 { continue; } + if m > dim { + continue; + } + if dim % m != 0 { + continue; + } let pq = ProductQuantizer::train(dim, m, k_bits, &vectors, 99999); let codes: Vec> = vectors.iter().map(|v| pq.encode(v)).collect(); @@ -524,7 +550,9 @@ fn benchmark_at_dim( // Convert f64 → sign bits → cyclic permutation bundle. // This is the correct approach for float→binary compression via SPO. 
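// A hedged sketch of the float → sign-bit step used in the SPO method below;
// the crate's sign_bits may differ (threshold, bit order), but the idea is:
// bit i of the plane is set iff component i is non-negative, packed into
// `n_words` u64 words.
fn sign_bits(v: &[f64], n_words: usize) -> Vec<u64> {
    let mut out = vec![0u64; n_words];
    for (i, &x) in v.iter().take(n_words * 64).enumerate() {
        if x >= 0.0 {
            out[i / 64] |= 1u64 << (i % 64);
        }
    }
    out
}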
for &plane_bits in &[128, 256, 512, 1024, 2048, 4096, 8192] { - if plane_bits > dim { continue; } + if plane_bits > dim { + continue; + } let planes: Vec<(Vec, Vec, Vec)> = vectors .iter() .map(|v| { @@ -654,55 +682,70 @@ fn dimension_sweep(dims: &[usize], n_per_cluster: usize, n_clusters: usize) -> V // For each method, compute separation let methods: Vec<(&str, Box]) -> (usize, Vec)>)> = vec![ - ("BinaryQuant", Box::new(|vecs: &[Vec]| { - let bq: Vec> = vecs.iter().map(|v| binary_quantize(v)).collect(); - let bits = vecs[0].len(); - let dists: Vec = (1..vecs.len()) - .map(|i| hamming_bytes(&bq[0], &bq[i]) as f64) - .collect(); - (bits, dists) - })), - ("SimHash-256", Box::new(move |vecs: &[Vec]| { - let proj = SimHashProjector::new(dim, 256, 12345); - let h: Vec> = vecs.iter().map(|v| proj.hash(v)).collect(); - let dists: Vec = (1..vecs.len()) - .map(|i| hamming_bytes(&h[0], &h[i]) as f64) - .collect(); - (256, dists) - })), - ("SimHash-1024", Box::new(move |vecs: &[Vec]| { - let bits = 1024.min(dim * 2); - let proj = SimHashProjector::new(dim, bits, 12345); - let h: Vec> = vecs.iter().map(|v| proj.hash(v)).collect(); - let dists: Vec = (1..vecs.len()) - .map(|i| hamming_bytes(&h[0], &h[i]) as f64) - .collect(); - (bits, dists) - })), - ("SPO-Sign", Box::new(move |vecs: &[Vec]| { - // SPO with sign-bit encoding: best float→binary→bundle path - let plane_bits = dim.min(1024); // clamp to dim - let third = plane_bits / 64; - if third == 0 { return (1, vec![0.0; vecs.len() - 1]); } - let planes: Vec<(Vec, Vec, Vec)> = vecs - .iter() - .map(|v| { - let s = sign_bits(&v[..plane_bits.min(dim)], third); - let p_start = (dim / 3).min(dim - 1); - let o_start = (2 * dim / 3).min(dim - 1); - let p = sign_bits(&v[p_start..], third); - let o = sign_bits(&v[o_start..], third); - (s, p, o) - }) - .collect(); - let bundles: Vec> = planes.iter() - .map(|(s, p, o)| spo_bundle(s, p, o, plane_bits)) - .collect(); - let dists: Vec = (1..vecs.len()) - .map(|i| hamming_u64(&bundles[0], &bundles[i]) as f64) - .collect(); - (plane_bits, dists) - })), + ( + "BinaryQuant", + Box::new(|vecs: &[Vec]| { + let bq: Vec> = vecs.iter().map(|v| binary_quantize(v)).collect(); + let bits = vecs[0].len(); + let dists: Vec = (1..vecs.len()) + .map(|i| hamming_bytes(&bq[0], &bq[i]) as f64) + .collect(); + (bits, dists) + }), + ), + ( + "SimHash-256", + Box::new(move |vecs: &[Vec]| { + let proj = SimHashProjector::new(dim, 256, 12345); + let h: Vec> = vecs.iter().map(|v| proj.hash(v)).collect(); + let dists: Vec = (1..vecs.len()) + .map(|i| hamming_bytes(&h[0], &h[i]) as f64) + .collect(); + (256, dists) + }), + ), + ( + "SimHash-1024", + Box::new(move |vecs: &[Vec]| { + let bits = 1024.min(dim * 2); + let proj = SimHashProjector::new(dim, bits, 12345); + let h: Vec> = vecs.iter().map(|v| proj.hash(v)).collect(); + let dists: Vec = (1..vecs.len()) + .map(|i| hamming_bytes(&h[0], &h[i]) as f64) + .collect(); + (bits, dists) + }), + ), + ( + "SPO-Sign", + Box::new(move |vecs: &[Vec]| { + // SPO with sign-bit encoding: best float→binary→bundle path + let plane_bits = dim.min(1024); // clamp to dim + let third = plane_bits / 64; + if third == 0 { + return (1, vec![0.0; vecs.len() - 1]); + } + let planes: Vec<(Vec, Vec, Vec)> = vecs + .iter() + .map(|v| { + let s = sign_bits(&v[..plane_bits.min(dim)], third); + let p_start = (dim / 3).min(dim - 1); + let o_start = (2 * dim / 3).min(dim - 1); + let p = sign_bits(&v[p_start..], third); + let o = sign_bits(&v[o_start..], third); + (s, p, o) + }) + .collect(); + let bundles: Vec> = 
planes + .iter() + .map(|(s, p, o)| spo_bundle(s, p, o, plane_bits)) + .collect(); + let dists: Vec = (1..vecs.len()) + .map(|i| hamming_u64(&bundles[0], &bundles[i]) as f64) + .collect(); + (plane_bits, dists) + }), + ), ]; for (name, method_fn) in &methods { @@ -722,9 +765,21 @@ fn dimension_sweep(dims: &[usize], n_per_cluster: usize, n_clusters: usize) -> V } } - let intra_mean = if intra.is_empty() { 1.0 } else { intra.iter().sum::() / intra.len() as f64 }; - let inter_mean = if inter.is_empty() { 1.0 } else { inter.iter().sum::() / inter.len() as f64 }; - let sep = if intra_mean > 1e-10 { inter_mean / intra_mean } else { inter_mean }; + let intra_mean = if intra.is_empty() { + 1.0 + } else { + intra.iter().sum::() / intra.len() as f64 + }; + let inter_mean = if inter.is_empty() { + 1.0 + } else { + inter.iter().sum::() / inter.len() as f64 + }; + let sep = if intra_mean > 1e-10 { + inter_mean / intra_mean + } else { + inter_mean + }; results.push(DimSweepResult { dim, @@ -764,7 +819,11 @@ fn find_sweet_spot(dim: usize, n_vectors: usize) -> Vec { for m in 0..per_cluster { let seed = (c * per_cluster + m) as u64 * 13 + 77; let noise = random_f64_vec(dim, seed); - let v: Vec = center.iter().zip(&noise).map(|(&c, &n)| c + n * spread).collect(); + let v: Vec = center + .iter() + .zip(&noise) + .map(|(&c, &n)| c + n * spread) + .collect(); vectors.push(v); labels.push(c); } @@ -783,9 +842,13 @@ fn find_sweet_spot(dim: usize, n_vectors: usize) -> Vec { let mut prev_bits = 0; for &plane_bits in &bit_widths { - if plane_bits > dim { continue; } + if plane_bits > dim { + continue; + } let n_words = plane_bits / 64; - if n_words == 0 { continue; } + if n_words == 0 { + continue; + } let planes: Vec<(Vec, Vec, Vec)> = vectors .iter() .map(|v| { @@ -894,9 +957,15 @@ fn native_binary_benchmark( let source_bits = plane_bits * 3; // Generate random SPO triples - let nodes_s: Vec> = (0..n_nodes).map(|i| random_plane(plane_words, i as u64 * 3 + 1)).collect(); - let nodes_p: Vec> = (0..n_nodes).map(|i| random_plane(plane_words, i as u64 * 3 + 2)).collect(); - let nodes_o: Vec> = (0..n_nodes).map(|i| random_plane(plane_words, i as u64 * 3 + 3)).collect(); + let nodes_s: Vec> = (0..n_nodes) + .map(|i| random_plane(plane_words, i as u64 * 3 + 1)) + .collect(); + let nodes_p: Vec> = (0..n_nodes) + .map(|i| random_plane(plane_words, i as u64 * 3 + 2)) + .collect(); + let nodes_o: Vec> = (0..n_nodes) + .map(|i| random_plane(plane_words, i as u64 * 3 + 3)) + .collect(); // Structured: first 50 share similar S, next 50 share similar P let base_s = random_plane(plane_words, 99999); @@ -983,10 +1052,14 @@ fn native_binary_benchmark( // Simple truncation baseline for &target_bits in &target_dims { - if target_bits > source_bits { continue; } + if target_bits > source_bits { + continue; + } let third = target_bits / 3; let third_words = third / 64; - if third_words == 0 { continue; } + if third_words == 0 { + continue; + } let remainder_words = (target_bits - 2 * third) / 64; let trunc_random: Vec> = (0..n_nodes) @@ -1025,7 +1098,9 @@ fn native_binary_benchmark( // SimHash baseline (random bit sampling from concatenated S+P+O) for &target_bits in &target_dims { - if target_bits > source_bits { continue; } + if target_bits > source_bits { + continue; + } // Sample target_bits random positions from 3×plane_bits let mut s = seed_state(77777); let indices: Vec = (0..target_bits) @@ -1090,19 +1165,18 @@ fn native_binary_benchmark( } /// Compute average Spearman ρ and recall@10 from bundle hamming vs exact distances. 
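// A hedged sketch of the recall@k metric that avg_metrics aggregates below:
// the fraction of the exact top-k neighbours that also appear in the
// approximate top-k. The crate's recall_at_k may handle ties differently;
// this only illustrates the definition.
fn recall_at_k(exact: &[f64], approx: &[f64], k: usize) -> f64 {
    use std::collections::HashSet;
    let top = |d: &[f64]| -> HashSet<usize> {
        let mut idx: Vec<usize> = (0..d.len()).collect();
        idx.sort_by(|&a, &b| d[a].partial_cmp(&d[b]).unwrap());
        idx.into_iter().take(k).collect()
    };
    let gt = top(exact);
    if gt.is_empty() {
        return 0.0;
    }
    let ap = top(approx);
    gt.intersection(&ap).count() as f64 / gt.len() as f64
}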
-fn avg_metrics( - exact_dists: &[Vec], - bundles: &[Vec], - n_queries: usize, - n_nodes: usize, -) -> (f64, f64) { +fn avg_metrics(exact_dists: &[Vec], bundles: &[Vec], n_queries: usize, n_nodes: usize) -> (f64, f64) { let mut rhos = Vec::new(); let mut recalls = Vec::new(); for q in 0..n_queries { - let gt: Vec = (0..n_nodes).filter(|&i| i != q) - .map(|i| exact_dists[q][i] as f64).collect(); - let ap: Vec = (0..n_nodes).filter(|&i| i != q) - .map(|i| hamming_u64(&bundles[q], &bundles[i]) as f64).collect(); + let gt: Vec = (0..n_nodes) + .filter(|&i| i != q) + .map(|i| exact_dists[q][i] as f64) + .collect(); + let ap: Vec = (0..n_nodes) + .filter(|&i| i != q) + .map(|i| hamming_u64(&bundles[q], &bundles[i]) as f64) + .collect(); rhos.push(spearman(>, &ap)); recalls.push(recall_at_k(>, &ap, 10)); } @@ -1112,19 +1186,18 @@ fn avg_metrics( } /// Compute avg metrics from raw u64 vectors (hamming distance). -fn avg_metrics_raw( - exact_dists: &[Vec], - compressed: &[Vec], - n_queries: usize, - n_nodes: usize, -) -> (f64, f64) { +fn avg_metrics_raw(exact_dists: &[Vec], compressed: &[Vec], n_queries: usize, n_nodes: usize) -> (f64, f64) { let mut rhos = Vec::new(); let mut recalls = Vec::new(); for q in 0..n_queries { - let gt: Vec = (0..n_nodes).filter(|&i| i != q) - .map(|i| exact_dists[q][i] as f64).collect(); - let ap: Vec = (0..n_nodes).filter(|&i| i != q) - .map(|i| hamming_u64(&compressed[q], &compressed[i]) as f64).collect(); + let gt: Vec = (0..n_nodes) + .filter(|&i| i != q) + .map(|i| exact_dists[q][i] as f64) + .collect(); + let ap: Vec = (0..n_nodes) + .filter(|&i| i != q) + .map(|i| hamming_u64(&compressed[q], &compressed[i]) as f64) + .collect(); rhos.push(spearman(>, &ap)); recalls.push(recall_at_k(>, &ap, 10)); } @@ -1154,21 +1227,23 @@ mod tests { eprintln!("\n╔══════════════════════════════════════════════════════════════════════════════╗"); eprintln!("║ COMPRESSION CURVE: {}D vectors, n={} ║", dim, n); eprintln!("╠══════════════════════════════════════════════════════════════════════════════╣"); - eprintln!("║ {:>12} │ {:>6} │ {:>8} │ {:>8} │ {:>7} │ {:>7} ║", - "Method", "Bits", "Ratio", "ρ", "R@10", "Purity"); + eprintln!( + "║ {:>12} │ {:>6} │ {:>8} │ {:>8} │ {:>7} │ {:>7} ║", + "Method", "Bits", "Ratio", "ρ", "R@10", "Purity" + ); eprintln!("╠══════════════════════════════════════════════════════════════════════════════╣"); for r in &results { - eprintln!("║ {:>12} │ {:>6} │ {:>7.1}× │ {:>8.4} │ {:>7.2} │ {:>7.2} ║", - r.method, r.bits_per_vector, r.compression_ratio, r.spearman_rho, - r.recall_at_10, r.cluster_purity); + eprintln!( + "║ {:>12} │ {:>6} │ {:>7.1}× │ {:>8.4} │ {:>7.2} │ {:>7.2} ║", + r.method, r.bits_per_vector, r.compression_ratio, r.spearman_rho, r.recall_at_10, r.cluster_purity + ); } eprintln!("╚══════════════════════════════════════════════════════════════════════════════╝"); // Sanity: all methods should have non-negative ρ for r in &results { - assert!(r.spearman_rho > -0.5, - "{} at {} bits has ρ={:.4}", r.method, r.bits_per_vector, r.spearman_rho); + assert!(r.spearman_rho > -0.5, "{} at {} bits has ρ={:.4}", r.method, r.bits_per_vector, r.spearman_rho); } } @@ -1185,22 +1260,28 @@ mod tests { eprintln!("\n╔══════════════════════════════════════════════════════════════════════════════╗"); eprintln!("║ COMPRESSION CURVE: {}D vectors (sentence embeddings), n={} ║", dim, n); eprintln!("╠══════════════════════════════════════════════════════════════════════════════╣"); - eprintln!("║ {:>12} │ {:>6} │ {:>8} │ {:>8} │ {:>7} │ {:>7} ║", - "Method", 
"Bits", "Ratio", "ρ", "R@10", "Purity"); + eprintln!( + "║ {:>12} │ {:>6} │ {:>8} │ {:>8} │ {:>7} │ {:>7} ║", + "Method", "Bits", "Ratio", "ρ", "R@10", "Purity" + ); eprintln!("╠══════════════════════════════════════════════════════════════════════════════╣"); let mut sorted = results.iter().collect::>(); sorted.sort_by(|a, b| a.bits_per_vector.cmp(&b.bits_per_vector)); for r in &sorted { - eprintln!("║ {:>12} │ {:>6} │ {:>7.1}× │ {:>8.4} │ {:>7.2} │ {:>7.2} ║", - r.method, r.bits_per_vector, r.compression_ratio, r.spearman_rho, - r.recall_at_10, r.cluster_purity); + eprintln!( + "║ {:>12} │ {:>6} │ {:>7.1}× │ {:>8.4} │ {:>7.2} │ {:>7.2} ║", + r.method, r.bits_per_vector, r.compression_ratio, r.spearman_rho, r.recall_at_10, r.cluster_purity + ); } eprintln!("╚══════════════════════════════════════════════════════════════════════════════╝"); // At 768d, PQ and SimHash should both work reasonably - let pq_best = results.iter().filter(|r| r.method == "PQ").map(|r| r.spearman_rho) + let pq_best = results + .iter() + .filter(|r| r.method == "PQ") + .map(|r| r.spearman_rho) .fold(0.0f64, f64::max); assert!(pq_best > 0.3, "PQ best ρ={:.4} too low at 768d", pq_best); } @@ -1217,28 +1298,34 @@ mod tests { eprintln!("\n╔════════════════════════════════════════════════════════════════════════════════════╗"); eprintln!("║ DIMENSION × SEPARATION QUALITY CURVE ║"); eprintln!("╠════════════════════════════════════════════════════════════════════════════════════╣"); - eprintln!("║ {:>4} │ {:>14} │ {:>6} │ {:>9} │ {:>9} │ {:>5} │ {:>6} ║", - "Dim", "Method", "Bits", "Intra", "Inter", "Sep", "ρ"); + eprintln!( + "║ {:>4} │ {:>14} │ {:>6} │ {:>9} │ {:>9} │ {:>5} │ {:>6} ║", + "Dim", "Method", "Bits", "Intra", "Inter", "Sep", "ρ" + ); eprintln!("╠════════════════════════════════════════════════════════════════════════════════════╣"); for r in &results { - eprintln!("║ {:>4} │ {:>14} │ {:>6} │ {:>9.1} │ {:>9.1} │ {:>5.2} │ {:>6.3} ║", - r.dim, r.method, r.bits, r.intra_cluster_mean, - r.inter_cluster_mean, r.separation_ratio, r.spearman_rho); + eprintln!( + "║ {:>4} │ {:>14} │ {:>6} │ {:>9.1} │ {:>9.1} │ {:>5.2} │ {:>6.3} ║", + r.dim, r.method, r.bits, r.intra_cluster_mean, r.inter_cluster_mean, r.separation_ratio, r.spearman_rho + ); } eprintln!("╚════════════════════════════════════════════════════════════════════════════════════╝"); // Higher dimensions should give better separation for all methods for method in &["BinaryQuant", "SimHash-256", "SPO-8K"] { - let method_results: Vec<&DimSweepResult> = results.iter() - .filter(|r| r.method == *method) - .collect(); + let method_results: Vec<&DimSweepResult> = results.iter().filter(|r| r.method == *method).collect(); if method_results.len() >= 2 { let first_sep = method_results[0].separation_ratio; let last_sep = method_results.last().unwrap().separation_ratio; - eprintln!(" {} separation: dim={} → {:.2}, dim={} → {:.2}", - method, method_results[0].dim, first_sep, - method_results.last().unwrap().dim, last_sep); + eprintln!( + " {} separation: dim={} → {:.2}, dim={} → {:.2}", + method, + method_results[0].dim, + first_sep, + method_results.last().unwrap().dim, + last_sep + ); } } } @@ -1256,8 +1343,7 @@ mod tests { eprintln!("\n╔══════════════════════════════════════════════════════════════════╗"); eprintln!("║ PRECISION SWEET SPOT: SPO Bundle at {}D, n={} ║", dim, n); eprintln!("╠══════════════════════════════════════════════════════════════════╣"); - eprintln!("║ {:>7} │ {:>8} │ {:>7} │ {:>12} │ {:>7} ║", - "Bits", "ρ", "R@10", "Δρ/bit (×1e6)", "Grade"); + 
eprintln!("║ {:>7} │ {:>8} │ {:>7} │ {:>12} │ {:>7} ║", "Bits", "ρ", "R@10", "Δρ/bit (×1e6)", "Grade"); eprintln!("╠══════════════════════════════════════════════════════════════════╣"); let mut peak_idx = 0; @@ -1275,16 +1361,24 @@ mod tests { "low" }; - eprintln!("║ {:>7} │ {:>8.4} │ {:>7.2} │ {:>12.2} │ {:>7} ║", - p.bits, p.spearman_rho, p.recall_at_10, p.marginal_gain * 1e6, grade); + eprintln!( + "║ {:>7} │ {:>8.4} │ {:>7.2} │ {:>12.2} │ {:>7} ║", + p.bits, + p.spearman_rho, + p.recall_at_10, + p.marginal_gain * 1e6, + grade + ); } eprintln!("╚══════════════════════════════════════════════════════════════════╝"); if !points.is_empty() { - eprintln!("\n >>> Sweet spot: {} bits (best marginal ρ gain per bit)", - points[peak_idx].bits); - eprintln!(" >>> Peak ρ: {:.4} at {} bits", - points.last().unwrap().spearman_rho, points.last().unwrap().bits); + eprintln!("\n >>> Sweet spot: {} bits (best marginal ρ gain per bit)", points[peak_idx].bits); + eprintln!( + " >>> Peak ρ: {:.4} at {} bits", + points.last().unwrap().spearman_rho, + points.last().unwrap().bits + ); } } @@ -1307,7 +1401,11 @@ mod tests { for m in 0..per_cluster { let seed = (c * per_cluster + m) as u64 * 19 + 7; let noise = random_f64_vec(dim, seed); - let v: Vec = center.iter().zip(&noise).map(|(&c, &n)| c + n * spread).collect(); + let v: Vec = center + .iter() + .zip(&noise) + .map(|(&c, &n)| c + n * spread) + .collect(); vectors.push(v); labels.push(c); } @@ -1331,7 +1429,9 @@ mod tests { // SimHash @ 1024 bits let sh = SimHashProjector::new(dim, budget, 12345); let sh_h: Vec> = vectors.iter().map(|v| sh.hash(v)).collect(); - let sh_d: Vec = (1..n).map(|i| hamming_bytes(&sh_h[0], &sh_h[i]) as f64).collect(); + let sh_d: Vec = (1..n) + .map(|i| hamming_bytes(&sh_h[0], &sh_h[i]) as f64) + .collect(); let sh_rho = spearman(>_dists, &sh_d); let sh_r10 = recall_at_k(>_dists, &sh_d, 10); let sh_pur = cluster_purity_knn(&sh_d, &labels, 0, 10); @@ -1340,7 +1440,9 @@ mod tests { // Random Projection @ 1024 bits let rp = RandomProjection::new(dim, budget, 54321); let rp_h: Vec> = vectors.iter().map(|v| rp.project_and_binarize(v)).collect(); - let rp_d: Vec = (1..n).map(|i| hamming_bytes(&rp_h[0], &rp_h[i]) as f64).collect(); + let rp_d: Vec = (1..n) + .map(|i| hamming_bytes(&rp_h[0], &rp_h[i]) as f64) + .collect(); let rp_rho = spearman(>_dists, &rp_d); let rp_r10 = recall_at_k(>_dists, &rp_d, 10); let rp_pur = cluster_purity_knn(&rp_d, &labels, 0, 10); @@ -1350,7 +1452,9 @@ mod tests { if dim % 128 == 0 { let pq = ProductQuantizer::train(dim, 128, 8, &vectors, 99999); let codes: Vec> = vectors.iter().map(|v| pq.encode(v)).collect(); - let pq_d: Vec = (1..n).map(|i| pq.asymmetric_dist(&vectors[0], &codes[i])).collect(); + let pq_d: Vec = (1..n) + .map(|i| pq.asymmetric_dist(&vectors[0], &codes[i])) + .collect(); let pq_rho = spearman(>_dists, &pq_d); let pq_r10 = recall_at_k(>_dists, &pq_d, 10); let pq_pur = cluster_purity_knn(&pq_d, &labels, 0, 10); @@ -1359,7 +1463,8 @@ mod tests { // SPO-Sign @ 1024 bits (sign-bit encoding) let n_words = budget / 64; - let planes: Vec<(Vec, Vec, Vec)> = vectors.iter() + let planes: Vec<(Vec, Vec, Vec)> = vectors + .iter() .map(|v| { let s = sign_bits(&v[..budget.min(dim)], n_words); let p_start = (dim / 3).min(dim - 1); @@ -1369,10 +1474,13 @@ mod tests { (s, p, o) }) .collect(); - let bundles: Vec> = planes.iter() + let bundles: Vec> = planes + .iter() .map(|(s, p, o)| spo_bundle(s, p, o, budget)) .collect(); - let spo_d: Vec = (1..n).map(|i| hamming_u64(&bundles[0], &bundles[i]) as 
f64).collect(); + let spo_d: Vec = (1..n) + .map(|i| hamming_u64(&bundles[0], &bundles[i]) as f64) + .collect(); let spo_rho = spearman(>_dists, &spo_d); let spo_r10 = recall_at_k(>_dists, &spo_d, 10); let spo_pur = cluster_purity_knn(&spo_d, &labels, 0, 10); @@ -1380,7 +1488,9 @@ mod tests { // BinaryQuant @ 256 bits (1 bit per dim, natural size) let bq: Vec> = vectors.iter().map(|v| binary_quantize(v)).collect(); - let bq_d: Vec = (1..n).map(|i| hamming_bytes(&bq[0], &bq[i]) as f64).collect(); + let bq_d: Vec = (1..n) + .map(|i| hamming_bytes(&bq[0], &bq[i]) as f64) + .collect(); let bq_rho = spearman(>_dists, &bq_d); let bq_r10 = recall_at_k(>_dists, &bq_d, 10); let bq_pur = cluster_purity_knn(&bq_d, &labels, 0, 10); @@ -1412,7 +1522,11 @@ mod tests { for m in 0..per_cluster { let seed = (c * per_cluster + m) as u64 * 23 + 11; let noise = random_f64_vec(dim, seed); - let v: Vec = center.iter().zip(&noise).map(|(&c, &n)| c + n * spread).collect(); + let v: Vec = center + .iter() + .zip(&noise) + .map(|(&c, &n)| c + n * spread) + .collect(); vectors.push(v); labels.push(c); } @@ -1421,8 +1535,7 @@ mod tests { eprintln!("\n╔══════════════════════════════════════════════════════════════════════╗"); eprintln!("║ MULTI-QUERY PRECISION ({} queries, {}D, n={}) ║", n_queries, dim, n); eprintln!("╠══════════════════════════════════════════════════════════════════════╣"); - eprintln!("║ {:>12} │ {:>6} │ {:>8} │ {:>8} │ {:>8} ║", - "Method", "Bits", "mean_ρ", "mean_R@10", "std_ρ"); + eprintln!("║ {:>12} │ {:>6} │ {:>8} │ {:>8} │ {:>8} ║", "Method", "Bits", "mean_ρ", "mean_R@10", "std_ρ"); eprintln!("╠══════════════════════════════════════════════════════════════════════╣"); struct MethodConfig { @@ -1431,14 +1544,38 @@ mod tests { } let configs = vec![ - MethodConfig { name: "SimHash", bits: 256 }, - MethodConfig { name: "SimHash", bits: 1024 }, - MethodConfig { name: "SimHash", bits: 4096 }, - MethodConfig { name: "SPO-Bndl", bits: 1024 }, - MethodConfig { name: "SPO-Bndl", bits: 4096 }, - MethodConfig { name: "SPO-Bndl", bits: 8192 }, - MethodConfig { name: "SPO-Bndl", bits: 16384 }, - MethodConfig { name: "BinQuant", bits: dim }, + MethodConfig { + name: "SimHash", + bits: 256, + }, + MethodConfig { + name: "SimHash", + bits: 1024, + }, + MethodConfig { + name: "SimHash", + bits: 4096, + }, + MethodConfig { + name: "SPO-Bndl", + bits: 1024, + }, + MethodConfig { + name: "SPO-Bndl", + bits: 4096, + }, + MethodConfig { + name: "SPO-Bndl", + bits: 8192, + }, + MethodConfig { + name: "SPO-Bndl", + bits: 16384, + }, + MethodConfig { + name: "BinQuant", + bits: dim, + }, ]; for cfg in &configs { @@ -1455,15 +1592,19 @@ mod tests { "SimHash" => { let proj = SimHashProjector::new(dim, cfg.bits, 12345); let hashes: Vec> = vectors.iter().map(|v| proj.hash(v)).collect(); - (0..n).filter(|&i| i != qi) + (0..n) + .filter(|&i| i != qi) .map(|i| hamming_bytes(&hashes[qi], &hashes[i]) as f64) .collect() } "SPO-Bndl" => { let bw = cfg.bits.min(dim); let nw = bw / 64; - if nw == 0 { continue; } - let planes: Vec<(Vec, Vec, Vec)> = vectors.iter() + if nw == 0 { + continue; + } + let planes: Vec<(Vec, Vec, Vec)> = vectors + .iter() .map(|v| { let s = sign_bits(&v[..bw.min(dim)], nw); let ps = (dim / 3).min(dim - 1); @@ -1473,16 +1614,19 @@ mod tests { (s, p, o) }) .collect(); - let bundles: Vec> = planes.iter() + let bundles: Vec> = planes + .iter() .map(|(s, p, o)| spo_bundle(s, p, o, bw)) .collect(); - (0..n).filter(|&i| i != qi) + (0..n) + .filter(|&i| i != qi) .map(|i| hamming_u64(&bundles[qi], &bundles[i]) as 
f64) .collect() } "BinQuant" => { let bq: Vec> = vectors.iter().map(|v| binary_quantize(v)).collect(); - (0..n).filter(|&i| i != qi) + (0..n) + .filter(|&i| i != qi) .map(|i| hamming_bytes(&bq[qi], &bq[i]) as f64) .collect() } @@ -1499,8 +1643,10 @@ mod tests { let mean_recall = recalls.iter().sum::() / recalls.len() as f64; let std_rho = (rhos.iter().map(|r| (r - mean_rho).powi(2)).sum::() / rhos.len() as f64).sqrt(); - eprintln!("║ {:>12} │ {:>6} │ {:>8.4} │ {:>8.2} │ {:>8.4} ║", - cfg.name, cfg.bits, mean_rho, mean_recall, std_rho); + eprintln!( + "║ {:>12} │ {:>6} │ {:>8.4} │ {:>8.2} │ {:>8.4} ║", + cfg.name, cfg.bits, mean_rho, mean_recall, std_rho + ); } eprintln!("╚══════════════════════════════════════════════════════════════════════╝"); @@ -1521,8 +1667,10 @@ mod tests { eprintln!("╠═══════════════════════════════════════════════════════════════════════════╣"); eprintln!("║ Source: {} bits ({} bytes) per vector ║", dim * 64, dim * 8); eprintln!("╠═══════════════════════════════════════════════════════════════════════════╣"); - eprintln!("║ {:>12} │ {:>8} │ {:>8} │ {:>6} │ {:>8} │ {:>7} ║", - "Method", "Abs(bits)", "Abs(KB)", "Ratio", "ρ", "R@10"); + eprintln!( + "║ {:>12} │ {:>8} │ {:>8} │ {:>6} │ {:>8} │ {:>7} ║", + "Method", "Abs(bits)", "Abs(KB)", "Ratio", "ρ", "R@10" + ); eprintln!("╠═══════════════════════════════════════════════════════════════════════════╣"); let mut sorted = results.iter().collect::>(); @@ -1530,9 +1678,10 @@ mod tests { for r in &sorted { let kb = r.bits_per_vector as f64 / 8192.0; - eprintln!("║ {:>12} │ {:>8} │ {:>7.3} │ {:>5.0}× │ {:>8.4} │ {:>7.2} ║", - r.method, r.bits_per_vector, kb, r.compression_ratio, - r.spearman_rho, r.recall_at_10); + eprintln!( + "║ {:>12} │ {:>8} │ {:>7.3} │ {:>5.0}× │ {:>8.4} │ {:>7.2} ║", + r.method, r.bits_per_vector, kb, r.compression_ratio, r.spearman_rho, r.recall_at_10 + ); } eprintln!("╚═══════════════════════════════════════════════════════════════════════════╝"); @@ -1564,48 +1713,78 @@ mod tests { eprintln!("║ NATIVE BINARY COMPRESSION: 3×16Kbit SPO planes → bundle (n={}) ║", n_nodes); eprintln!("║ Source: 49152 bits (6KB) per SPO triple ║"); eprintln!("╠══════════════════════════════════════════════════════════════════════════════════════════╣"); - eprintln!("║ {:>13} │ {:>6} │ {:>6} │ {:>8} │ {:>8} │ {:>7} │ {:>7} ║", - "Method", "Bits", "Ratio", "ρ_rand", "ρ_struct", "R@10_r", "R@10_s"); + eprintln!( + "║ {:>13} │ {:>6} │ {:>6} │ {:>8} │ {:>8} │ {:>7} │ {:>7} ║", + "Method", "Bits", "Ratio", "ρ_rand", "ρ_struct", "R@10_r", "R@10_s" + ); eprintln!("╠══════════════════════════════════════════════════════════════════════════════════════════╣"); let source_bits = plane_words * 64 * 3; let mut sorted = results.iter().collect::>(); - sorted.sort_by(|a, b| { - a.method.cmp(&b.method).then(a.bits.cmp(&b.bits).reverse()) - }); + sorted.sort_by(|a, b| a.method.cmp(&b.method).then(a.bits.cmp(&b.bits).reverse())); for r in &sorted { let ratio = source_bits as f64 / r.bits as f64; - eprintln!("║ {:>13} │ {:>6} │ {:>5.1}× │ {:>8.4} │ {:>8.4} │ {:>7.2} │ {:>7.2} ║", - r.method, r.bits, ratio, r.rho_random, r.rho_structured, - r.recall_at_10_random, r.recall_at_10_structured); + eprintln!( + "║ {:>13} │ {:>6} │ {:>5.1}× │ {:>8.4} │ {:>8.4} │ {:>7.2} │ {:>7.2} ║", + r.method, + r.bits, + ratio, + r.rho_random, + r.rho_structured, + r.recall_at_10_random, + r.recall_at_10_structured + ); } eprintln!("╚══════════════════════════════════════════════════════════════════════════════════════════╝"); // Find crossover point: where 
does cyclic bundle beat truncation? eprintln!("\n Method comparison at each bit budget:"); - let budgets: Vec = results.iter().map(|r| r.bits).collect::>() - .into_iter().collect::>(); + let budgets: Vec = results + .iter() + .map(|r| r.bits) + .collect::>() + .into_iter() + .collect::>(); let mut budgets_sorted = budgets; budgets_sorted.sort_by(|a, b| b.cmp(a)); for &bits in &budgets_sorted { - let bundle = results.iter().find(|r| r.method == "CyclicBundle" && r.bits == bits); - let trunc = results.iter().find(|r| r.method == "Truncation" && r.bits == bits); - let sample = results.iter().find(|r| r.method == "BitSample" && r.bits == bits); + let bundle = results + .iter() + .find(|r| r.method == "CyclicBundle" && r.bits == bits); + let trunc = results + .iter() + .find(|r| r.method == "Truncation" && r.bits == bits); + let sample = results + .iter() + .find(|r| r.method == "BitSample" && r.bits == bits); if let (Some(b), Some(t)) = (bundle, trunc) { - let winner = if b.rho_structured > t.rho_structured { "Bundle" } else { "Trunc" }; + let winner = if b.rho_structured > t.rho_structured { + "Bundle" + } else { + "Trunc" + }; let s_winner = if let Some(s) = sample { - if s.rho_structured > b.rho_structured.max(t.rho_structured) { "Sample" } else { winner } - } else { winner }; - eprintln!(" {}b: Bundle ρ_s={:.4} vs Trunc ρ_s={:.4} → {}", - bits, b.rho_structured, t.rho_structured, s_winner); + if s.rho_structured > b.rho_structured.max(t.rho_structured) { + "Sample" + } else { + winner + } + } else { + winner + }; + eprintln!( + " {}b: Bundle ρ_s={:.4} vs Trunc ρ_s={:.4} → {}", + bits, b.rho_structured, t.rho_structured, s_winner + ); } } // CyclicBundle should have reasonable ρ on structured data - let best_bundle_struct = results.iter() + let best_bundle_struct = results + .iter() .filter(|r| r.method == "CyclicBundle") .map(|r| r.rho_structured) .fold(0.0f64, f64::max); @@ -1662,8 +1841,10 @@ mod tests { eprintln!("\n╔══════════════════════════════════════════════════════════════════════╗"); eprintln!("║ NATIVE BINARY SWEET SPOT: Clustered SPO planes ║"); eprintln!("╠══════════════════════════════════════════════════════════════════════╣"); - eprintln!("║ {:>7} │ {:>6} │ {:>8} │ {:>7} │ {:>8} │ {:>12} ║", - "Bits", "Ratio", "ρ", "R@10", "Purity", "Δρ/bit×1e6"); + eprintln!( + "║ {:>7} │ {:>6} │ {:>8} │ {:>7} │ {:>8} │ {:>12} ║", + "Bits", "Ratio", "ρ", "R@10", "Purity", "Δρ/bit×1e6" + ); eprintln!("╠══════════════════════════════════════════════════════════════════════╣"); let bit_widths = [64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384]; @@ -1716,9 +1897,16 @@ mod tests { "" }; - eprintln!("║ {:>7} │ {:>5.1}× │ {:>8.4} │ {:>7.2} │ {:>8.2} │ {:>11.2}{} ║", - target_bits, source_bits as f64 / target_bits as f64, - rho, recall, purity, marginal * 1e6, marker); + eprintln!( + "║ {:>7} │ {:>5.1}× │ {:>8.4} │ {:>7.2} │ {:>8.2} │ {:>11.2}{} ║", + target_bits, + source_bits as f64 / target_bits as f64, + rho, + recall, + purity, + marginal * 1e6, + marker + ); prev_rho = rho; prev_bits = target_bits; @@ -1726,8 +1914,7 @@ mod tests { eprintln!("╚══════════════════════════════════════════════════════════════════════╝"); if sweet_spot > 0 { - eprintln!(" >>> Sweet spot: {} bits ({}× compression)", - sweet_spot, source_bits / sweet_spot); + eprintln!(" >>> Sweet spot: {} bits ({}× compression)", sweet_spot, source_bits / sweet_spot); } } } diff --git a/src/hpc/crystal_encoder.rs b/src/hpc/crystal_encoder.rs index f6979370..818a40aa 100644 --- a/src/hpc/crystal_encoder.rs +++ 
b/src/hpc/crystal_encoder.rs @@ -170,11 +170,7 @@ impl CrystalEncoder { /// CrystalEncoder::absorb_into_node(&fp, &mut node, Role::Subject); /// assert!(node.s.encounters() > 0); /// ``` - pub fn absorb_into_node( - fingerprint: &Fingerprint, - node: &mut Node, - role: Role, - ) { + pub fn absorb_into_node(fingerprint: &Fingerprint, node: &mut Node, role: Role) { let plane: &mut Plane = match role { Role::Subject => &mut node.s, Role::Predicate => &mut node.p, @@ -211,12 +207,7 @@ impl CrystalEncoder { /// assert!(node.s.encounters() > 0); /// assert!(!fp.is_zero()); /// ``` - pub fn encode_and_absorb( - &self, - embedding: &[f32], - node: &mut Node, - role: Role, - ) -> Fingerprint { + pub fn encode_and_absorb(&self, embedding: &[f32], node: &mut Node, role: Role) -> Fingerprint { let fp = self.encode_embedding(embedding); Self::absorb_into_node(&fp, node, role); fp @@ -245,11 +236,7 @@ impl CrystalEncoder { /// assert!(!results.is_empty()); /// assert_eq!(results[0].1, 0); // exact match should have distance 0 /// ``` - pub fn search_similar( - query: &mut Node, - database: &mut [Node], - top_k: usize, - ) -> Vec<(usize, u32)> { + pub fn search_similar(query: &mut Node, database: &mut [Node], top_k: usize) -> Vec<(usize, u32)> { let mut results: Vec<(usize, u32)> = database .iter_mut() .enumerate() @@ -301,11 +288,7 @@ impl CrystalEncoder { /// assert!(!results.is_empty()); /// ``` pub fn pipeline_encode_search( - encoder: &CrystalEncoder, - subject_emb: &[f32], - predicate_emb: &[f32], - object_emb: &[f32], - database: &mut [Node], + encoder: &CrystalEncoder, subject_emb: &[f32], predicate_emb: &[f32], object_emb: &[f32], database: &mut [Node], top_k: usize, ) -> (Node, Vec<(usize, u32)>) { let mut query_node = Node::new(); @@ -379,11 +362,7 @@ fn distillation_loss(student: &mut Node, teacher: &mut Node, lambda: f32) -> f32 /// let losses = distill(&[teacher], &mut student, 3); /// assert_eq!(losses.len(), 3); /// ``` -pub fn distill( - teacher_nodes: &[Node], - student_encoder: &mut CrystalEncoder, - epochs: usize, -) -> Vec { +pub fn distill(teacher_nodes: &[Node], student_encoder: &mut CrystalEncoder, epochs: usize) -> Vec { let lambda = 0.1f32; let dim = student_encoder.embedding_dim; let total_weights = student_encoder.projection.len(); @@ -394,7 +373,9 @@ pub fn distill( .iter() .map(|t| { let acc = t.s.acc(); - (0..dim).map(|i| acc[i % acc.len()] as f32 / 127.0).collect() + (0..dim) + .map(|i| acc[i % acc.len()] as f32 / 127.0) + .collect() }) .collect(); @@ -411,7 +392,9 @@ pub fn distill( let fp_s = student_encoder.encode_embedding(emb); let fp_p = { // Slightly rotated embedding for predicate - let shifted: Vec = emb.iter().enumerate() + let shifted: Vec = emb + .iter() + .enumerate() .map(|(i, &v)| if i % 2 == 0 { v } else { -v }) .collect(); student_encoder.encode_embedding(&shifted) @@ -438,7 +421,9 @@ pub fn distill( student_encoder.flip_weight(idx); let fp_s2 = student_encoder.encode_embedding(emb); - let shifted: Vec = emb.iter().enumerate() + let shifted: Vec = emb + .iter() + .enumerate() .map(|(i, &v)| if i % 2 == 0 { v } else { -v }) .collect(); let fp_p2 = student_encoder.encode_embedding(&shifted); @@ -476,38 +461,22 @@ pub fn distill( /// These are the universal semantic primitives proposed by Anna Wierzbicka. 
const NSM_PRIMES: [&str; 65] = [ // Substantives - "I", "you", "someone", "something", "people", "body", - // Determiners - "this", "the same", "other", "else", - // Quantifiers - "one", "two", "some", "all", "much", "many", - // Evaluators - "good", "bad", - // Descriptors - "big", "small", - // Mental predicates - "think", "know", "want", "feel", "see", "hear", - // Speech - "say", "words", "true", - // Actions/events/movement - "do", "happen", "move", - // Existence/possession - "there is", "have", - // Life/death - "live", "die", - // Time - "when", "now", "before", "after", "a long time", "a short time", "for some time", "moment", - // Space - "where", "here", "above", "below", "far", "near", "side", "inside", "touch", - // Logical concepts - "not", "maybe", "can", "because", "if", - // Intensifier/augmentor - "very", "more", - // Similarity - "like", "as", - // Taxonomy/partonomy - "kind of", "part of", - // Relational + "I", "you", "someone", "something", "people", "body", // Determiners + "this", "the same", "other", "else", // Quantifiers + "one", "two", "some", "all", "much", "many", // Evaluators + "good", "bad", // Descriptors + "big", "small", // Mental predicates + "think", "know", "want", "feel", "see", "hear", // Speech + "say", "words", "true", // Actions/events/movement + "do", "happen", "move", // Existence/possession + "there is", "have", // Life/death + "live", "die", // Time + "when", "now", "before", "after", "a long time", "a short time", "for some time", "moment", // Space + "where", "here", "above", "below", "far", "near", "side", "inside", "touch", // Logical concepts + "not", "maybe", "can", "because", "if", // Intensifier/augmentor + "very", "more", // Similarity + "like", "as", // Taxonomy/partonomy + "kind of", "part of", // Relational "way", ]; @@ -620,8 +589,8 @@ pub fn encode_word(word: &str, codebook: &NsmCodebook) -> Fingerprint // Select codebook entry: first 8 bytes of hash mod 65 let selector = u64::from_le_bytes([ - hash_bytes[0], hash_bytes[1], hash_bytes[2], hash_bytes[3], - hash_bytes[4], hash_bytes[5], hash_bytes[6], hash_bytes[7], + hash_bytes[0], hash_bytes[1], hash_bytes[2], hash_bytes[3], hash_bytes[4], hash_bytes[5], hash_bytes[6], + hash_bytes[7], ]); let prime_idx = (selector % NSM_PRIMES.len() as u64) as usize; let base = codebook.get_prime(prime_idx).clone(); @@ -675,8 +644,8 @@ pub fn encode_sentence(words: &[&str], codebook: &NsmCodebook) -> Node { #[cfg(test)] mod tests { - use super::*; use super::super::node::SPO; + use super::*; // -- Phase 1 tests ------------------------------------------------------- @@ -835,13 +804,7 @@ mod tests { let cb = NsmCodebook::new(); for i in 0..65 { for j in (i + 1)..65 { - assert_ne!( - cb.get_prime(i), - cb.get_prime(j), - "primes {} and {} should differ", - i, - j - ); + assert_ne!(cb.get_prime(i), cb.get_prime(j), "primes {} and {} should differ", i, j); } } } diff --git a/src/hpc/cyclic_bundle.rs b/src/hpc/cyclic_bundle.rs index 25bb1ea1..26304ee5 100644 --- a/src/hpc/cyclic_bundle.rs +++ b/src/hpc/cyclic_bundle.rs @@ -123,12 +123,7 @@ pub fn majority_vote_3(a: &[u64; N], b: &[u64; N], c: &[u64; N]) -> [u64; N] { } /// Bundle SPO triple: S stays at identity, P shifted by `shift`, O shifted by 2*shift. 
-pub fn bundle_spo( - s: &[u64; N], - p: &[u64; N], - o: &[u64; N], - shift: usize, -) -> [u64; N] { +pub fn bundle_spo(s: &[u64; N], p: &[u64; N], o: &[u64; N], shift: usize) -> [u64; N] { let p_shifted = cyclic_shift(p, shift); let o_shifted = cyclic_shift(o, 2 * shift); majority_vote_3(s, &p_shifted, &o_shifted) @@ -326,14 +321,8 @@ mod tests { ); // Expect ~75% accuracy for n=3 majority vote - assert!( - mean_all > 0.70, - "shift={shift}: mean accuracy {mean_all:.4} too low (expect ~0.75)" - ); - assert!( - mean_all < 0.80, - "shift={shift}: mean accuracy {mean_all:.4} unexpectedly high" - ); + assert!(mean_all > 0.70, "shift={shift}: mean accuracy {mean_all:.4} too low (expect ~0.75)"); + assert!(mean_all < 0.80, "shift={shift}: mean accuracy {mean_all:.4} unexpectedly high"); } } @@ -378,8 +367,10 @@ mod tests { vec2[i] ^= noisy[i]; // flip ~5% of bits } - let bundle_3130 = bundle_spo(&vec2, &random_128(trial as u64 + 100), &random_128(trial as u64 + 200), GOLDEN_SHIFT); - let bundle_3131 = bundle_spo(&vec2, &random_128(trial as u64 + 100), &random_128(trial as u64 + 200), GOLDEN_SHIFT_ODD); + let bundle_3130 = + bundle_spo(&vec2, &random_128(trial as u64 + 100), &random_128(trial as u64 + 200), GOLDEN_SHIFT); + let bundle_3131 = + bundle_spo(&vec2, &random_128(trial as u64 + 100), &random_128(trial as u64 + 200), GOLDEN_SHIFT_ODD); let rec_3130 = recover_s(&bundle_3130, GOLDEN_SHIFT); let rec_3131 = recover_s(&bundle_3131, GOLDEN_SHIFT_ODD); @@ -393,9 +384,7 @@ mod tests { let acc_3130 = 1.0 - avg_err_3130; let acc_3131 = 1.0 - avg_err_3131; - eprintln!( - "[Experiment 2] Noisy period-2: acc_3130={acc_3130:.4}, acc_3131={acc_3131:.4}" - ); + eprintln!("[Experiment 2] Noisy period-2: acc_3130={acc_3130:.4}, acc_3131={acc_3131:.4}"); // Test 3: Period-4 vectors let mut period4 = [0u64; N]; @@ -424,14 +413,8 @@ mod tests { // Both shifts should give reasonable accuracy on noisy data // The key question: does 3130 fail on periodic inputs? // With random P and O, even periodic S should bundle reasonably. - assert!( - acc_3130 > 0.60, - "shift=3130 accuracy {acc_3130:.4} on noisy period-2 is catastrophically low" - ); - assert!( - acc_3131 > 0.60, - "shift=3131 accuracy {acc_3131:.4} on noisy period-2 is catastrophically low" - ); + assert!(acc_3130 > 0.60, "shift=3130 accuracy {acc_3130:.4} on noisy period-2 is catastrophically low"); + assert!(acc_3131 > 0.60, "shift=3131 accuracy {acc_3131:.4} on noisy period-2 is catastrophically low"); } // ── EXPERIMENT 3: Ranking preservation ───────────────────────── @@ -486,10 +469,7 @@ mod tests { // With n=3 majority vote and independent P,O per node, the noise from // P,O reduces rank correlation. rho ~ 0.20-0.30 is expected for // random independent triples. The ranking is partially preserved. 
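        // For reference (assumption: the P and O planes are independent, uniformly
        // random bit patterns): a bundle bit disagrees with the corresponding S bit
        // only when both the shifted P bit and the shifted O bit differ from it,
        // which happens with probability 0.5 * 0.5 = 0.25. Each bundle bit therefore
        // matches S with probability 0.75, so roughly a quarter of the bits act as
        // noise relative to S, and rank correlation against pure S-distance is
        // expected to saturate well below 1.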
- assert!( - rho > 0.15, - "Spearman rho={rho:.4} too low — ranking not preserved at all" - ); + assert!(rho > 0.15, "Spearman rho={rho:.4} too low — ranking not preserved at all"); } // ── EXPERIMENT 4: Search quality (Recall@10) ────────────────── @@ -565,10 +545,7 @@ mod tests { // We expect decent recall since similar S-planes should produce // closer bundles - assert!( - recall >= 0.3, - "Recall@{k}={recall:.2} is too low for practical use" - ); + assert!(recall >= 0.3, "Recall@{k}={recall:.2} is too low for practical use"); } // ── EXPERIMENT 5: CLAM clustering on cyclic bundles ──────────── @@ -622,8 +599,7 @@ mod tests { .map(|j| (j, hamming_128(&bundles[i], &bundles[j]))) .collect(); bundle_dists.sort_by_key(|&(_, d)| d); - let bundle_knn: Vec = - bundle_dists.iter().take(kk).map(|&(j, _)| j).collect(); + let bundle_knn: Vec = bundle_dists.iter().take(kk).map(|&(j, _)| j).collect(); // k-NN by original S distance let mut orig_dists: Vec<(usize, u32)> = (0..total_nodes) @@ -631,8 +607,7 @@ mod tests { .map(|j| (j, hamming_128(&s_planes[i], &s_planes[j]))) .collect(); orig_dists.sort_by_key(|&(_, d)| d); - let orig_knn: Vec = - orig_dists.iter().take(kk).map(|&(j, _)| j).collect(); + let orig_knn: Vec = orig_dists.iter().take(kk).map(|&(j, _)| j).collect(); // Purity: fraction of k-NN sharing same label let bundle_purity = bundle_knn @@ -640,11 +615,7 @@ mod tests { .filter(|&&j| labels[j] == labels[i]) .count() as f64 / kk as f64; - let orig_purity = orig_knn - .iter() - .filter(|&&j| labels[j] == labels[i]) - .count() as f64 - / kk as f64; + let orig_purity = orig_knn.iter().filter(|&&j| labels[j] == labels[i]).count() as f64 / kk as f64; // k-NN recall: how many of the true k-NN are in bundle k-NN let knn_recall = orig_knn @@ -723,9 +694,7 @@ mod tests { let actual_p = acc_p_sum / num_trials as f64; let actual_o = acc_o_sum / num_trials as f64; - eprintln!( - " {p:.2} | {predicted:.4} | {actual_s:.4} | {actual_p:.4} | {actual_o:.4}" - ); + eprintln!(" {p:.2} | {predicted:.4} | {actual_s:.4} | {actual_p:.4} | {actual_o:.4}"); // Verify predicted vs actual are reasonably close // The formula P(correct) = 1 - p(1-p) is for two independent random diff --git a/src/hpc/deepnsm.rs b/src/hpc/deepnsm.rs index b855cab9..86daf045 100644 --- a/src/hpc/deepnsm.rs +++ b/src/hpc/deepnsm.rs @@ -15,37 +15,95 @@ use std::sync::LazyLock; #[repr(u8)] pub enum NsmPrime { // Substantives - I = 0, You, Someone, Something, Thing, Body, + I = 0, + You, + Someone, + Something, + Thing, + Body, // Relational - Kind, Part, + Kind, + Part, // Determiners - This, TheSame, Other, Else, Another, + This, + TheSame, + Other, + Else, + Another, // Quantifiers - One, Two, Some, All, Much, Many, Little, Few, + One, + Two, + Some, + All, + Much, + Many, + Little, + Few, // Evaluators - Good, Bad, + Good, + Bad, // Descriptors - Big, Small, + Big, + Small, // Mental - Think, Know, Want, DontWant, Feel, See, Hear, + Think, + Know, + Want, + DontWant, + Feel, + See, + Hear, // Speech - Say, Words, True, + Say, + Words, + True, // Actions - Do, Happen, Move, + Do, + Happen, + Move, // Existence - Be, ThereIs, BeSomeone, Mine, + Be, + ThereIs, + BeSomeone, + Mine, // Life - Live, Die, + Live, + Die, // Time - When, Time, Now, Before, After, ALongTime, AShortTime, ForSomeTime, Moment, + When, + Time, + Now, + Before, + After, + ALongTime, + AShortTime, + ForSomeTime, + Moment, // Space - Where, Place, Here, Above, Below, Far, Near, Side, Inside, Touch, Contact, + Where, + Place, + Here, + Above, + Below, + Far, + Near, + 
Side, + Inside, + Touch, + Contact, // Logical - Not, Maybe, Can, Because, If, + Not, + Maybe, + Can, + Because, + If, // Intensifier - Very, More, + Very, + More, // Similarity - Like, As, Way, + Like, + As, + Way, } // Total: 74 variants (indices 0..73) @@ -86,69 +144,104 @@ pub struct NsmEntry { static ALL_PRIMES: [NsmPrime; 74] = [ // Substantives (6) - NsmPrime::I, NsmPrime::You, NsmPrime::Someone, NsmPrime::Something, - NsmPrime::Thing, NsmPrime::Body, + NsmPrime::I, + NsmPrime::You, + NsmPrime::Someone, + NsmPrime::Something, + NsmPrime::Thing, + NsmPrime::Body, // Relational (2) - NsmPrime::Kind, NsmPrime::Part, + NsmPrime::Kind, + NsmPrime::Part, // Determiners (5) - NsmPrime::This, NsmPrime::TheSame, NsmPrime::Other, NsmPrime::Else, + NsmPrime::This, + NsmPrime::TheSame, + NsmPrime::Other, + NsmPrime::Else, NsmPrime::Another, // Quantifiers (8) - NsmPrime::One, NsmPrime::Two, NsmPrime::Some, - NsmPrime::All, NsmPrime::Much, NsmPrime::Many, NsmPrime::Little, + NsmPrime::One, + NsmPrime::Two, + NsmPrime::Some, + NsmPrime::All, + NsmPrime::Much, + NsmPrime::Many, + NsmPrime::Little, NsmPrime::Few, // Evaluators (2) - NsmPrime::Good, NsmPrime::Bad, + NsmPrime::Good, + NsmPrime::Bad, // Descriptors (2) - NsmPrime::Big, NsmPrime::Small, + NsmPrime::Big, + NsmPrime::Small, // Mental (7) - NsmPrime::Think, NsmPrime::Know, NsmPrime::Want, - NsmPrime::DontWant, NsmPrime::Feel, NsmPrime::See, NsmPrime::Hear, + NsmPrime::Think, + NsmPrime::Know, + NsmPrime::Want, + NsmPrime::DontWant, + NsmPrime::Feel, + NsmPrime::See, + NsmPrime::Hear, // Speech (3) - NsmPrime::Say, NsmPrime::Words, NsmPrime::True, + NsmPrime::Say, + NsmPrime::Words, + NsmPrime::True, // Actions (3) - NsmPrime::Do, NsmPrime::Happen, NsmPrime::Move, + NsmPrime::Do, + NsmPrime::Happen, + NsmPrime::Move, // Existence (4) - NsmPrime::Be, NsmPrime::ThereIs, NsmPrime::BeSomeone, NsmPrime::Mine, + NsmPrime::Be, + NsmPrime::ThereIs, + NsmPrime::BeSomeone, + NsmPrime::Mine, // Life (2) - NsmPrime::Live, NsmPrime::Die, + NsmPrime::Live, + NsmPrime::Die, // Time (9) - NsmPrime::When, NsmPrime::Time, NsmPrime::Now, NsmPrime::Before, - NsmPrime::After, NsmPrime::ALongTime, NsmPrime::AShortTime, - NsmPrime::ForSomeTime, NsmPrime::Moment, + NsmPrime::When, + NsmPrime::Time, + NsmPrime::Now, + NsmPrime::Before, + NsmPrime::After, + NsmPrime::ALongTime, + NsmPrime::AShortTime, + NsmPrime::ForSomeTime, + NsmPrime::Moment, // Space (11) - NsmPrime::Where, NsmPrime::Place, - NsmPrime::Here, NsmPrime::Above, NsmPrime::Below, NsmPrime::Far, - NsmPrime::Near, NsmPrime::Side, NsmPrime::Inside, NsmPrime::Touch, + NsmPrime::Where, + NsmPrime::Place, + NsmPrime::Here, + NsmPrime::Above, + NsmPrime::Below, + NsmPrime::Far, + NsmPrime::Near, + NsmPrime::Side, + NsmPrime::Inside, + NsmPrime::Touch, NsmPrime::Contact, // Logical (5) - NsmPrime::Not, NsmPrime::Maybe, NsmPrime::Can, - NsmPrime::Because, NsmPrime::If, + NsmPrime::Not, + NsmPrime::Maybe, + NsmPrime::Can, + NsmPrime::Because, + NsmPrime::If, // Intensifier (2) - NsmPrime::Very, NsmPrime::More, + NsmPrime::Very, + NsmPrime::More, // Similarity (3) - NsmPrime::Like, NsmPrime::As, NsmPrime::Way, + NsmPrime::Like, + NsmPrime::As, + NsmPrime::Way, ]; static PRIME_NAMES: [&str; 74] = [ - "I", "YOU", "SOMEONE", "SOMETHING", "THING", "BODY", - "KIND", "PART", - "THIS", "THE_SAME", "OTHER", "ELSE", "ANOTHER", - "ONE", "TWO", "SOME", "ALL", "MUCH", "MANY", "LITTLE", "FEW", - "GOOD", "BAD", - "BIG", "SMALL", - "THINK", "KNOW", "WANT", "DONT_WANT", "FEEL", "SEE", "HEAR", - "SAY", "WORDS", "TRUE", - "DO", 
"HAPPEN", "MOVE", - "BE", "THERE_IS", "BE_SOMEONE", "MINE", - "LIVE", "DIE", - "WHEN", "TIME", "NOW", "BEFORE", "AFTER", "A_LONG_TIME", "A_SHORT_TIME", - "FOR_SOME_TIME", "MOMENT", - "WHERE", "PLACE", "HERE", "ABOVE", "BELOW", "FAR", "NEAR", "SIDE", - "INSIDE", "TOUCH", "CONTACT", - "NOT", "MAYBE", "CAN", "BECAUSE", "IF", - "VERY", "MORE", - "LIKE", "AS", "WAY", + "I", "YOU", "SOMEONE", "SOMETHING", "THING", "BODY", "KIND", "PART", "THIS", "THE_SAME", "OTHER", "ELSE", + "ANOTHER", "ONE", "TWO", "SOME", "ALL", "MUCH", "MANY", "LITTLE", "FEW", "GOOD", "BAD", "BIG", "SMALL", "THINK", + "KNOW", "WANT", "DONT_WANT", "FEEL", "SEE", "HEAR", "SAY", "WORDS", "TRUE", "DO", "HAPPEN", "MOVE", "BE", + "THERE_IS", "BE_SOMEONE", "MINE", "LIVE", "DIE", "WHEN", "TIME", "NOW", "BEFORE", "AFTER", "A_LONG_TIME", + "A_SHORT_TIME", "FOR_SOME_TIME", "MOMENT", "WHERE", "PLACE", "HERE", "ABOVE", "BELOW", "FAR", "NEAR", "SIDE", + "INSIDE", "TOUCH", "CONTACT", "NOT", "MAYBE", "CAN", "BECAUSE", "IF", "VERY", "MORE", "LIKE", "AS", "WAY", ]; impl NsmPrime { @@ -160,33 +253,58 @@ impl NsmPrime { /// Category this prime belongs to. pub fn category(&self) -> NsmCategory { match *self { - NsmPrime::I | NsmPrime::You | NsmPrime::Someone - | NsmPrime::Something | NsmPrime::Thing | NsmPrime::Body => NsmCategory::Substantive, + NsmPrime::I + | NsmPrime::You + | NsmPrime::Someone + | NsmPrime::Something + | NsmPrime::Thing + | NsmPrime::Body => NsmCategory::Substantive, NsmPrime::Kind | NsmPrime::Part => NsmCategory::Relational, - NsmPrime::This | NsmPrime::TheSame | NsmPrime::Other - | NsmPrime::Else | NsmPrime::Another => NsmCategory::Determiner, - NsmPrime::One | NsmPrime::Two | NsmPrime::Some | NsmPrime::All - | NsmPrime::Much | NsmPrime::Many | NsmPrime::Little | NsmPrime::Few => { - NsmCategory::Quantifier + NsmPrime::This | NsmPrime::TheSame | NsmPrime::Other | NsmPrime::Else | NsmPrime::Another => { + NsmCategory::Determiner } + NsmPrime::One + | NsmPrime::Two + | NsmPrime::Some + | NsmPrime::All + | NsmPrime::Much + | NsmPrime::Many + | NsmPrime::Little + | NsmPrime::Few => NsmCategory::Quantifier, NsmPrime::Good | NsmPrime::Bad => NsmCategory::Evaluator, NsmPrime::Big | NsmPrime::Small => NsmCategory::Descriptor, - NsmPrime::Think | NsmPrime::Know | NsmPrime::Want | NsmPrime::DontWant - | NsmPrime::Feel | NsmPrime::See | NsmPrime::Hear => NsmCategory::Mental, + NsmPrime::Think + | NsmPrime::Know + | NsmPrime::Want + | NsmPrime::DontWant + | NsmPrime::Feel + | NsmPrime::See + | NsmPrime::Hear => NsmCategory::Mental, NsmPrime::Say | NsmPrime::Words | NsmPrime::True => NsmCategory::Speech, NsmPrime::Do | NsmPrime::Happen | NsmPrime::Move => NsmCategory::Action, - NsmPrime::Be | NsmPrime::ThereIs | NsmPrime::BeSomeone | NsmPrime::Mine => { - NsmCategory::Existence - } + NsmPrime::Be | NsmPrime::ThereIs | NsmPrime::BeSomeone | NsmPrime::Mine => NsmCategory::Existence, NsmPrime::Live | NsmPrime::Die => NsmCategory::Life, - NsmPrime::When | NsmPrime::Time | NsmPrime::Now | NsmPrime::Before - | NsmPrime::After | NsmPrime::ALongTime | NsmPrime::AShortTime - | NsmPrime::ForSomeTime | NsmPrime::Moment => NsmCategory::Time, - NsmPrime::Where | NsmPrime::Place | NsmPrime::Here | NsmPrime::Above - | NsmPrime::Below | NsmPrime::Far | NsmPrime::Near | NsmPrime::Side - | NsmPrime::Inside | NsmPrime::Touch | NsmPrime::Contact => NsmCategory::Space, - NsmPrime::Not | NsmPrime::Maybe | NsmPrime::Can | NsmPrime::Because - | NsmPrime::If => NsmCategory::Logical, + NsmPrime::When + | NsmPrime::Time + | NsmPrime::Now + | NsmPrime::Before + | 
NsmPrime::After + | NsmPrime::ALongTime + | NsmPrime::AShortTime + | NsmPrime::ForSomeTime + | NsmPrime::Moment => NsmCategory::Time, + NsmPrime::Where + | NsmPrime::Place + | NsmPrime::Here + | NsmPrime::Above + | NsmPrime::Below + | NsmPrime::Far + | NsmPrime::Near + | NsmPrime::Side + | NsmPrime::Inside + | NsmPrime::Touch + | NsmPrime::Contact => NsmCategory::Space, + NsmPrime::Not | NsmPrime::Maybe | NsmPrime::Can | NsmPrime::Because | NsmPrime::If => NsmCategory::Logical, NsmPrime::Very | NsmPrime::More => NsmCategory::Intensifier, NsmPrime::Like | NsmPrime::As | NsmPrime::Way => NsmCategory::Similarity, } @@ -763,8 +881,7 @@ mod tests { #[test] fn test_category_coverage() { - let categories: HashSet = - NsmPrime::all().iter().map(|p| p.category()).collect(); + let categories: HashSet = NsmPrime::all().iter().map(|p| p.category()).collect(); // 16 categories assert_eq!(categories.len(), 16); assert!(categories.contains(&NsmCategory::Substantive)); @@ -836,11 +953,7 @@ mod tests { #[test] fn test_vocabulary_has_entries() { let vocab = nsm_vocabulary(); - assert!( - vocab.len() >= 200, - "vocabulary should have ≥200 entries, got {}", - vocab.len() - ); + assert!(vocab.len() >= 200, "vocabulary should have ≥200 entries, got {}", vocab.len()); } #[test] @@ -856,29 +969,13 @@ mod tests { let dog = nsm_lookup("dog").unwrap(); let cat_primes: HashSet = cat.iter().map(|(p, _)| *p).collect(); let dog_primes: HashSet = dog.iter().map(|(p, _)| *p).collect(); - assert!( - cat_primes.contains(&NsmPrime::Something), - "cat should have SOMETHING" - ); - assert!( - cat_primes.contains(&NsmPrime::Live), - "cat should have LIVE" - ); - assert!( - dog_primes.contains(&NsmPrime::Something), - "dog should have SOMETHING" - ); - assert!( - dog_primes.contains(&NsmPrime::Live), - "dog should have LIVE" - ); + assert!(cat_primes.contains(&NsmPrime::Something), "cat should have SOMETHING"); + assert!(cat_primes.contains(&NsmPrime::Live), "cat should have LIVE"); + assert!(dog_primes.contains(&NsmPrime::Something), "dog should have SOMETHING"); + assert!(dog_primes.contains(&NsmPrime::Live), "dog should have LIVE"); // Shared primes let shared: HashSet<_> = cat_primes.intersection(&dog_primes).collect(); - assert!( - shared.len() >= 2, - "cat and dog should share ≥2 primes, shared: {:?}", - shared - ); + assert!(shared.len() >= 2, "cat and dog should share ≥2 primes, shared: {:?}", shared); } #[test] @@ -886,13 +983,7 @@ mod tests { for prime in NsmPrime::all() { let name = prime.name(); let recovered = NsmPrime::from_name(name); - assert_eq!( - recovered, - Some(*prime), - "roundtrip failed for {:?} (name={})", - prime, - name - ); + assert_eq!(recovered, Some(*prime), "roundtrip failed for {:?} (name={})", prime, name); } } } @@ -904,16 +995,15 @@ mod tests { /// The full NSM primes set including multi-word primes (from Python utils.py). 
static NSM_PRIMES_SET: LazyLock> = LazyLock::new(|| { [ - "i", "you", "someone", "people", "something", "thing", "body", "kind", "part", - "this", "the same", "other", "else", "another", "one", "two", "some", "all", - "much", "many", "little", "few", "good", "bad", "big", "small", "think", "know", - "want", "don't want", "feel", "see", "hear", "say", "words", "true", "do", - "happen", "move", "there", "is", "be", "mine", "live", "die", "when", "time", - "now", "before", "after", "a long time", "a short time", "for some time", - "moment", "where", "place", "here", "above", "below", "far", "near", "side", - "inside", "touch", "not", "maybe", "can", "because", "if", "very", "more", - "like", "as", "way", "said", - ].into_iter().collect() + "i", "you", "someone", "people", "something", "thing", "body", "kind", "part", "this", "the same", "other", + "else", "another", "one", "two", "some", "all", "much", "many", "little", "few", "good", "bad", "big", "small", + "think", "know", "want", "don't want", "feel", "see", "hear", "say", "words", "true", "do", "happen", "move", + "there", "is", "be", "mine", "live", "die", "when", "time", "now", "before", "after", "a long time", + "a short time", "for some time", "moment", "where", "place", "here", "above", "below", "far", "near", "side", + "inside", "touch", "not", "maybe", "can", "because", "if", "very", "more", "like", "as", "way", "said", + ] + .into_iter() + .collect() }); /// Check if a word is an NSM semantic prime. @@ -924,19 +1014,19 @@ pub fn is_nsm_prime(word: &str) -> bool { /// English stopwords excluding NSM primes. `LazyLock` one-time init. static STOP_WORDS: LazyLock> = LazyLock::new(|| { let sw: HashSet<&str> = [ - "a", "an", "and", "are", "at", "been", "but", "by", "did", "does", - "doing", "down", "during", "each", "for", "from", "further", "had", - "has", "having", "he", "her", "herself", "him", "himself", "his", - "how", "in", "into", "it", "its", "itself", "just", "me", "my", - "myself", "no", "nor", "of", "off", "on", "once", "only", "or", - "our", "ours", "ourselves", "out", "over", "own", "re", "s", "she", - "should", "so", "such", "t", "than", "that", "the", "their", - "theirs", "them", "themselves", "then", "these", "they", "those", - "through", "to", "too", "under", "until", "up", "ve", "was", "we", - "were", "what", "which", "while", "who", "whom", "why", "will", - "with", "won", "would", "your", "yours", "yourself", "yourselves", - ].into_iter().collect(); - sw.into_iter().filter(|w| !NSM_PRIMES_SET.contains(*w)).collect() + "a", "an", "and", "are", "at", "been", "but", "by", "did", "does", "doing", "down", "during", "each", "for", + "from", "further", "had", "has", "having", "he", "her", "herself", "him", "himself", "his", "how", "in", + "into", "it", "its", "itself", "just", "me", "my", "myself", "no", "nor", "of", "off", "on", "once", "only", + "or", "our", "ours", "ourselves", "out", "over", "own", "re", "s", "she", "should", "so", "such", "t", "than", + "that", "the", "their", "theirs", "them", "themselves", "then", "these", "they", "those", "through", "to", + "too", "under", "until", "up", "ve", "was", "we", "were", "what", "which", "while", "who", "whom", "why", + "will", "with", "won", "would", "your", "yours", "yourself", "yourselves", + ] + .into_iter() + .collect(); + sw.into_iter() + .filter(|w| !NSM_PRIMES_SET.contains(*w)) + .collect() }); /// Check if a word is a stopword (but not an NSM prime). 
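The two LazyLock sets above feed the legality metrics defined later in this file: every token of an explication is counted as a prime, a stopword, or a "molecule" (any other content word), and the prime/molecule ratios reward explications written mostly in primes. Below is a minimal sketch of that per-token classification, assuming only the `is_nsm_prime` and `is_stop_word` predicates from this hunk; the `TokenCounts` struct and `classify_tokens` helper are illustrative names, not part of the crate.

/// Illustrative tally mirroring the counts used by `Explication::legality_score`.
struct TokenCounts {
    primes: usize,
    stop_words: usize,
    molecules: usize,
}

fn classify_tokens(text: &str) -> TokenCounts {
    // Same normalization as the legality scorer: lowercase, strip punctuation.
    let clean: String = text
        .to_lowercase()
        .chars()
        .filter(|c| c.is_alphanumeric() || c.is_whitespace())
        .collect();
    let mut counts = TokenCounts { primes: 0, stop_words: 0, molecules: 0 };
    for tok in clean.split_whitespace() {
        if is_nsm_prime(tok) {
            counts.primes += 1;
        } else if is_stop_word(tok) {
            counts.stop_words += 1;
        } else {
            counts.molecules += 1; // non-prime content word ("semantic molecule")
        }
    }
    counts
}

// Example: in "someone did something bad to this person",
// "someone", "something", "bad", "this" count as primes,
// "did" and "to" as stopwords, and "person" as a molecule.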
@@ -961,7 +1051,13 @@ pub struct Prediction { impl Prediction { pub fn new(prediction: &str) -> Self { - Self { prediction: prediction.to_string(), answer_logprob: 0.0, answer_ranks: Vec::new(), is_match: false, lines_removed: 0 } + Self { + prediction: prediction.to_string(), + answer_logprob: 0.0, + answer_ranks: Vec::new(), + is_match: false, + lines_removed: 0, + } } } @@ -982,7 +1078,18 @@ pub struct SubstitutabilityScore { impl SubstitutabilityScore { pub fn new(model: &str) -> Self { - Self { model: model.to_string(), baselines: Vec::new(), exp_baselines: Vec::new(), minimality: Vec::new(), entailments: Vec::new(), adj_score: 0.0, avg_delta_log: 0.0, avg_min_delta_log: 0.0, avg_ent_delta_log: 0.0, total_match: 0 } + Self { + model: model.to_string(), + baselines: Vec::new(), + exp_baselines: Vec::new(), + minimality: Vec::new(), + entailments: Vec::new(), + adj_score: 0.0, + avg_delta_log: 0.0, + avg_min_delta_log: 0.0, + avg_ent_delta_log: 0.0, + total_match: 0, + } } } @@ -1010,44 +1117,86 @@ pub struct Explication { impl Explication { pub fn new(text: &str) -> Self { Self { - text: text.to_string(), target_word: String::new(), - length: 0, primes: 0, stop_words_count: 0, molecules: 0, - unique_molecules: 0, uses_original_word: false, primes_ratio: 0.0, - molecules_ratio: 0.0, sub_scores: Vec::new(), avg_delta: 0.0, - avg_delta_min: 0.0, avg_delta_ent: 0.0, score_exp: 0.0, total_score: 0.0, + text: text.to_string(), + target_word: String::new(), + length: 0, + primes: 0, + stop_words_count: 0, + molecules: 0, + unique_molecules: 0, + uses_original_word: false, + primes_ratio: 0.0, + molecules_ratio: 0.0, + sub_scores: Vec::new(), + avg_delta: 0.0, + avg_delta_min: 0.0, + avg_delta_ent: 0.0, + score_exp: 0.0, + total_score: 0.0, } } /// Score legality against a target word (circularity via stem matching). pub fn legality_score(&mut self, word: &str) { - let clean: String = self.text.to_lowercase().chars() - .filter(|c| c.is_alphanumeric() || c.is_whitespace()).collect(); + let clean: String = self + .text + .to_lowercase() + .chars() + .filter(|c| c.is_alphanumeric() || c.is_whitespace()) + .collect(); let tokens: Vec<&str> = clean.split_whitespace().collect(); self.target_word = word.to_string(); self.length = tokens.len(); self.primes = tokens.iter().filter(|t| is_nsm_prime(t)).count(); self.stop_words_count = tokens.iter().filter(|t| is_stop_word(t)).count(); - let mols: Vec<&&str> = tokens.iter().filter(|t| !is_nsm_prime(t) && !is_stop_word(t)).collect(); + let mols: Vec<&&str> = tokens + .iter() + .filter(|t| !is_nsm_prime(t) && !is_stop_word(t)) + .collect(); self.molecules = mols.len(); self.unique_molecules = mols.iter().collect::>().len(); let wl = word.to_lowercase(); let stem = if wl.len() >= 4 { &wl[..4] } else { &wl }; - self.uses_original_word = tokens.iter().any(|t| *t == wl || (t.len() >= 4 && t.starts_with(stem))); - self.primes_ratio = if self.length > 0 { self.primes as f32 / self.length as f32 } else { 0.0 }; - self.molecules_ratio = if self.length > 0 { self.molecules as f32 / self.length as f32 } else { 0.0 }; + self.uses_original_word = tokens + .iter() + .any(|t| *t == wl || (t.len() >= 4 && t.starts_with(stem))); + self.primes_ratio = if self.length > 0 { + self.primes as f32 / self.length as f32 + } else { + 0.0 + }; + self.molecules_ratio = if self.length > 0 { + self.molecules as f32 / self.length as f32 + } else { + 0.0 + }; } /// Compute averages from substitutability sub-scores. 
pub fn calculate_averages(&mut self) { - if self.sub_scores.is_empty() { return; } + if self.sub_scores.is_empty() { + return; + } let n = self.sub_scores.len() as f32; self.avg_delta = self.sub_scores.iter().map(|s| s.avg_delta_log).sum::() / n; - self.avg_delta_min = self.sub_scores.iter().map(|s| s.avg_min_delta_log).sum::() / n; - self.avg_delta_ent = self.sub_scores.iter().map(|s| s.avg_ent_delta_log).sum::() / n; + self.avg_delta_min = self + .sub_scores + .iter() + .map(|s| s.avg_min_delta_log) + .sum::() + / n; + self.avg_delta_ent = self + .sub_scores + .iter() + .map(|s| s.avg_ent_delta_log) + .sum::() + / n; self.score_exp = self.sub_scores.iter().map(|s| s.adj_score).sum::() / n; self.total_score = if !self.uses_original_word { 2.0 * (self.score_exp + 10.0 * self.primes_ratio - 10.0 * self.molecules_ratio) - } else { 0.0 }; + } else { + 0.0 + }; } /// Truncated versions with lines removed from the end. @@ -1067,17 +1216,39 @@ pub struct AmbiguousExample { } impl AmbiguousExample { - pub fn new(text: &str) -> Self { Self { text: text.to_string(), source: None } } + pub fn new(text: &str) -> Self { + Self { + text: text.to_string(), + source: None, + } + } /// Truncated versions removing non-UNK sentences. pub fn get_truncated(&self, max_remove: usize) -> Vec { - let sents: Vec<&str> = self.text.split('.').map(|s| s.trim()).filter(|s| !s.is_empty()).collect(); - let non_unk: Vec = sents.iter().enumerate().filter(|(_, s)| !s.contains("")).map(|(i, _)| i).collect(); - (0..max_remove.min(non_unk.len())).map(|i| { - let exclude: HashSet = non_unk[..=i].iter().copied().collect(); - let kept: Vec<&str> = sents.iter().enumerate().filter(|(j, _)| !exclude.contains(j)).map(|(_, s)| *s).collect(); - AmbiguousExample::new(&kept.join(". ")) - }).collect() + let sents: Vec<&str> = self + .text + .split('.') + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .collect(); + let non_unk: Vec = sents + .iter() + .enumerate() + .filter(|(_, s)| !s.contains("")) + .map(|(i, _)| i) + .collect(); + (0..max_remove.min(non_unk.len())) + .map(|i| { + let exclude: HashSet = non_unk[..=i].iter().copied().collect(); + let kept: Vec<&str> = sents + .iter() + .enumerate() + .filter(|(j, _)| !exclude.contains(j)) + .map(|(_, s)| *s) + .collect(); + AmbiguousExample::new(&kept.join(". 
")) + }) + .collect() } } @@ -1094,13 +1265,32 @@ pub struct ModelResult { impl ModelResult { pub fn new(model_name: &str) -> Self { - Self { model_name: model_name.to_string(), num_examples: 0, explications: Vec::new(), avg_primes_ratio: 0.0, avg_molecules_ratio: 0.0, avg_total_score: 0.0 } + Self { + model_name: model_name.to_string(), + num_examples: 0, + explications: Vec::new(), + avg_primes_ratio: 0.0, + avg_molecules_ratio: 0.0, + avg_total_score: 0.0, + } } pub fn calculate_averages(&mut self) { let n = self.explications.len() as f32; - if n == 0.0 { return; } - self.avg_primes_ratio = self.explications.iter().map(|e| e.primes_ratio).sum::() / n; - self.avg_molecules_ratio = self.explications.iter().map(|e| e.molecules_ratio).sum::() / n; + if n == 0.0 { + return; + } + self.avg_primes_ratio = self + .explications + .iter() + .map(|e| e.primes_ratio) + .sum::() + / n; + self.avg_molecules_ratio = self + .explications + .iter() + .map(|e| e.molecules_ratio) + .sum::() + / n; self.avg_total_score = self.explications.iter().map(|e| e.total_score).sum::() / n; } } @@ -1111,7 +1301,12 @@ impl ModelResult { pub fn load_nsm_codebook(codebook_bytes: &[u8]) -> super::cam_pq::CamCodebook { use super::cam_pq::{CamCodebook, SubspaceCodebook, NUM_CENTROIDS, NUM_SUBSPACES}; let expected = NUM_SUBSPACES * NUM_CENTROIDS * 16 * 4; - assert_eq!(codebook_bytes.len(), expected, "codebook_pq.bin: expected {expected} bytes, got {}", codebook_bytes.len()); + assert_eq!( + codebook_bytes.len(), + expected, + "codebook_pq.bin: expected {expected} bytes, got {}", + codebook_bytes.len() + ); let mut codebooks: Vec = Vec::with_capacity(NUM_SUBSPACES); for s in 0..NUM_SUBSPACES { let mut centroids = Vec::with_capacity(NUM_CENTROIDS); @@ -1119,35 +1314,64 @@ pub fn load_nsm_codebook(codebook_bytes: &[u8]) -> super::cam_pq::CamCodebook { let mut centroid = Vec::with_capacity(16); for d in 0..16 { let off = (s * NUM_CENTROIDS * 16 + c * 16 + d) * 4; - centroid.push(f32::from_le_bytes([codebook_bytes[off], codebook_bytes[off+1], codebook_bytes[off+2], codebook_bytes[off+3]])); + centroid.push(f32::from_le_bytes([ + codebook_bytes[off], + codebook_bytes[off + 1], + codebook_bytes[off + 2], + codebook_bytes[off + 3], + ])); } centroids.push(centroid); } - codebooks.push(SubspaceCodebook { centroids, subspace_dim: 16 }); + codebooks.push(SubspaceCodebook { + centroids, + subspace_dim: 16, + }); + } + CamCodebook { + codebooks: codebooks.try_into().unwrap(), + total_dim: 96, + subspace_dim: 16, } - CamCodebook { codebooks: codebooks.try_into().unwrap(), total_dim: 96, subspace_dim: 16 } } /// Load CAM codes (`cam_codes.bin`): N words × 6 bytes. pub fn load_cam_codes(bytes: &[u8]) -> Vec { assert_eq!(bytes.len() % 6, 0); - bytes.chunks_exact(6).map(|c| { let mut fp = [0u8; 6]; fp.copy_from_slice(c); fp }).collect() + bytes + .chunks_exact(6) + .map(|c| { + let mut fp = [0u8; 6]; + fp.copy_from_slice(c); + fp + }) + .collect() } // ── 36-bit SPO triple ─────────────────────────────────────────────────────── /// 36-bit SPO triple packed in u64. 12-bit subject + predicate + object. 
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] -pub struct SpoTriple { packed: u64 } +pub struct SpoTriple { + packed: u64, +} impl SpoTriple { pub fn new(subject: u16, predicate: u16, object: u16) -> Self { debug_assert!(subject < 4096 && predicate < 4096 && object < 4096); - Self { packed: ((subject as u64) << 24) | ((predicate as u64) << 12) | object as u64 } + Self { + packed: ((subject as u64) << 24) | ((predicate as u64) << 12) | object as u64, + } + } + pub fn subject(&self) -> u16 { + ((self.packed >> 24) & 0xFFF) as u16 + } + pub fn predicate(&self) -> u16 { + ((self.packed >> 12) & 0xFFF) as u16 + } + pub fn object(&self) -> u16 { + (self.packed & 0xFFF) as u16 } - pub fn subject(&self) -> u16 { ((self.packed >> 24) & 0xFFF) as u16 } - pub fn predicate(&self) -> u16 { ((self.packed >> 12) & 0xFFF) as u16 } - pub fn object(&self) -> u16 { (self.packed & 0xFFF) as u16 } } // ── Prompt templates ──────────────────────────────────────────────────────── @@ -1156,20 +1380,38 @@ impl SpoTriple { pub const NSM_EXPLICATION_SYS_INST: &str = "You are a linguist specializing in semantic analysis using the Natural Semantic Metalanguage (NSM) approach. NSM reduces lexicons to universal semantic primes. Paraphrase the word's meaning using NSM primes."; /// Recovery prompt: predict masked word. -pub const RECOVERY_PROMPT_SYS_INST: &str = "Read the passage with a missing word indicated by . Predict the missing word. Output only your prediction."; +pub const RECOVERY_PROMPT_SYS_INST: &str = + "Read the passage with a missing word indicated by . Predict the missing word. Output only your prediction."; /// Chat message for prompt construction. #[derive(Clone, Debug)] -pub struct ChatMessage { pub role: String, pub content: String } +pub struct ChatMessage { + pub role: String, + pub content: String, +} /// Build explication prompt with optional few-shot. 
-pub fn build_explication_prompt(word: &str, examples: &[&str], few_shot: &[(String, String)], max: Option) -> Vec { - let mut msgs = vec![ChatMessage { role: "system".into(), content: NSM_EXPLICATION_SYS_INST.into() }]; +pub fn build_explication_prompt( + word: &str, examples: &[&str], few_shot: &[(String, String)], max: Option, +) -> Vec { + let mut msgs = vec![ChatMessage { + role: "system".into(), + content: NSM_EXPLICATION_SYS_INST.into(), + }]; for (u, a) in &few_shot[..max.unwrap_or(few_shot.len()).min(few_shot.len())] { - msgs.push(ChatMessage { role: "user".into(), content: u.clone() }); - msgs.push(ChatMessage { role: "assistant".into(), content: a.clone() }); + msgs.push(ChatMessage { + role: "user".into(), + content: u.clone(), + }); + msgs.push(ChatMessage { + role: "assistant".into(), + content: a.clone(), + }); } - msgs.push(ChatMessage { role: "user".into(), content: format!("Word: {word}\nExamples:\n{}\nParaphrase:", examples.join("\n\n")) }); + msgs.push(ChatMessage { + role: "user".into(), + content: format!("Word: {word}\nExamples:\n{}\nParaphrase:", examples.join("\n\n")), + }); msgs } @@ -1180,8 +1422,14 @@ pub fn build_recover_prompt(ambig: &AmbiguousExample, exp: Option<&Explication>) None => format!("Passage: {}\nMissing Word:", ambig.text), }; vec![ - ChatMessage { role: "system".into(), content: RECOVERY_PROMPT_SYS_INST.into() }, - ChatMessage { role: "user".into(), content: user }, + ChatMessage { + role: "system".into(), + content: RECOVERY_PROMPT_SYS_INST.into(), + }, + ChatMessage { + role: "user".into(), + content: user, + }, ] } @@ -1265,10 +1513,10 @@ mod eval_tests { #[test] fn test_cam_codes_load() { - let bytes = vec![1,2,3,4,5,6, 7,8,9,10,11,12]; + let bytes = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; let codes = load_cam_codes(&bytes); assert_eq!(codes.len(), 2); - assert_eq!(codes[0], [1,2,3,4,5,6]); + assert_eq!(codes[0], [1, 2, 3, 4, 5, 6]); } #[test] diff --git a/src/hpc/distance.rs b/src/hpc/distance.rs index 545a6f22..7ef3c6d7 100644 --- a/src/hpc/distance.rs +++ b/src/hpc/distance.rs @@ -38,11 +38,7 @@ pub(crate) mod simd_impl { /// # Safety /// Caller must ensure AVX2 is available. #[target_feature(enable = "avx2")] - pub(crate) unsafe fn squared_distances_avx2( - query: [f32; 3], - points: &[[f32; 3]], - out: &mut Vec, - ) { + pub(crate) unsafe fn squared_distances_avx2(query: [f32; 3], points: &[[f32; 3]], out: &mut Vec) { let n = points.len(); out.clear(); out.reserve(n); @@ -114,11 +110,7 @@ pub fn squared_distances_f32(query: [f32; 3], points: &[[f32; 3]]) -> Vec { } /// Filter points by max squared distance. Returns indices of survivors. -pub fn filter_by_radius_sq( - query: [f32; 3], - points: &[[f32; 3]], - radius_sq: f32, -) -> Vec { +pub fn filter_by_radius_sq(query: [f32; 3], points: &[[f32; 3]], radius_sq: f32) -> Vec { let dists = squared_distances_f32(query, points); dists .iter() @@ -129,11 +121,7 @@ pub fn filter_by_radius_sq( /// Find K nearest points (f32). Returns `(indices, squared_distances)` sorted /// ascending by distance. 
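// Hedged usage sketch (not taken from the diff) for the radius filter above,
// assuming it returns the indices of surviving points as its doc comment states.
#[cfg(test)]
mod radius_filter_sketch {
    use super::filter_by_radius_sq;

    #[test]
    fn keeps_only_points_inside_the_radius() {
        let query = [0.0f32, 0.0, 0.0];
        let points = vec![[1.0, 0.0, 0.0], [3.0, 0.0, 0.0]];
        // Squared distances are 1.0 and 9.0, so radius_sq = 4.0 keeps index 0 only.
        assert_eq!(filter_by_radius_sq(query, &points, 4.0), vec![0]);
    }
}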
-pub fn knn_f32( - query: [f32; 3], - points: &[[f32; 3]], - k: usize, -) -> (Vec, Vec) { +pub fn knn_f32(query: [f32; 3], points: &[[f32; 3]], k: usize) -> (Vec, Vec) { let dists = squared_distances_f32(query, points); let mut indexed: Vec<(usize, f32)> = dists.into_iter().enumerate().collect(); indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(core::cmp::Ordering::Equal)); @@ -156,11 +144,7 @@ pub fn squared_distances_f64(query: [f64; 3], points: &[[f64; 3]]) -> Vec { } /// Filter f64 points by squared-distance radius. Returns survivor indices. -pub fn filter_by_radius_sq_f64( - query: [f64; 3], - points: &[[f64; 3]], - radius_sq: f64, -) -> Vec { +pub fn filter_by_radius_sq_f64(query: [f64; 3], points: &[[f64; 3]], radius_sq: f64) -> Vec { let dists = squared_distances_f64(query, points); dists .iter() @@ -171,11 +155,7 @@ pub fn filter_by_radius_sq_f64( /// Find K nearest points (f64). Returns `(indices, squared_distances)` sorted /// ascending by distance. -pub fn knn_f64( - query: [f64; 3], - points: &[[f64; 3]], - k: usize, -) -> (Vec, Vec) { +pub fn knn_f64(query: [f64; 3], points: &[[f64; 3]], k: usize) -> (Vec, Vec) { let dists = squared_distances_f64(query, points); let mut indexed: Vec<(usize, f64)> = dists.into_iter().enumerate().collect(); indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(core::cmp::Ordering::Equal)); @@ -216,10 +196,7 @@ mod tests { assert_eq!(result.len(), points.len()); for (i, &d) in result.iter().enumerate() { let expected = sq_dist_f32(query, points[i]); - assert!( - approx_eq_f32(d, expected), - "mismatch at {i}: {d} vs {expected}" - ); + assert!(approx_eq_f32(d, expected), "mismatch at {i}: {d} vs {expected}"); } } @@ -235,10 +212,7 @@ mod tests { let result = squared_distances_f64(query, &points); for (i, &d) in result.iter().enumerate() { let expected = sq_dist_f64(query, points[i]); - assert!( - approx_eq_f64(d, expected), - "mismatch at {i}: {d} vs {expected}" - ); + assert!(approx_eq_f64(d, expected), "mismatch at {i}: {d} vs {expected}"); } } @@ -286,12 +260,7 @@ mod tests { #[test] fn test_knn_f32() { let query = [0.0f32, 0.0, 0.0]; - let points = vec![ - [3.0, 0.0, 0.0], - [1.0, 0.0, 0.0], - [2.0, 0.0, 0.0], - [0.5, 0.0, 0.0], - ]; + let points = vec![[3.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.5, 0.0, 0.0]]; let (idx, dist) = knn_f32(query, &points, 2); assert_eq!(idx, vec![3, 1]); // 0.25, 1.0 assert!(approx_eq_f32(dist[0], 0.25)); @@ -301,12 +270,7 @@ mod tests { #[test] fn test_knn_f64() { let query = [0.0f64, 0.0, 0.0]; - let points = vec![ - [3.0, 0.0, 0.0], - [1.0, 0.0, 0.0], - [2.0, 0.0, 0.0], - [0.5, 0.0, 0.0], - ]; + let points = vec![[3.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.5, 0.0, 0.0]]; let (idx, dist) = knn_f64(query, &points, 2); assert_eq!(idx, vec![3, 1]); assert!(approx_eq_f64(dist[0], 0.25)); diff --git a/src/hpc/dn_tree.rs b/src/hpc/dn_tree.rs index 164f6fd4..1f472583 100644 --- a/src/hpc/dn_tree.rs +++ b/src/hpc/dn_tree.rs @@ -56,11 +56,7 @@ impl SplitMix64 { /// Create a random `GraphHV` from the local RNG. 
#[cfg(test)] fn random_graphhv(rng: &mut SplitMix64) -> GraphHV { - let mut channels = [ - Fingerprint::<256>::zero(), - Fingerprint::<256>::zero(), - Fingerprint::<256>::zero(), - ]; + let mut channels = [Fingerprint::<256>::zero(), Fingerprint::<256>::zero(), Fingerprint::<256>::zero()]; for ch in &mut channels { for w in ch.words.iter_mut() { *w = rng.next_u64(); @@ -109,13 +105,7 @@ fn is_zero(hv: &GraphHV) -> bool { /// For each bit position, the existing summary bit is kept with probability /// `1 - lr * boost`, and replaced by the input bit with probability `lr * boost`. /// This is implemented per-word using a stochastic mask. -fn bundle_into( - current: &GraphHV, - hv: &GraphHV, - lr: f64, - boost: f64, - rng: &mut SplitMix64, -) -> GraphHV { +fn bundle_into(current: &GraphHV, hv: &GraphHV, lr: f64, boost: f64, rng: &mut SplitMix64) -> GraphHV { let effective_lr = (lr * boost).min(1.0); let mut result = current.clone(); @@ -125,8 +115,7 @@ fn bundle_into( // Approximate by AND-ing multiple random words (each AND halves probability). let mask = make_probability_mask(effective_lr, rng); // Where mask is 1: take from hv; where 0: keep current. - result.channels[ch].words[w] = - (current.channels[ch].words[w] & !mask) | (hv.channels[ch].words[w] & mask); + result.channels[ch].words[w] = (current.channels[ch].words[w] & !mask) | (hv.channels[ch].words[w] & mask); } } @@ -330,12 +319,11 @@ impl DNTree { ); // Determine BTSP boost for this update - let btsp_boost = - if self.config.btsp_gate_prob > 0.0 && rng.next_f64() < self.config.btsp_gate_prob { - self.config.btsp_boost - } else { - 1.0 - }; + let btsp_boost = if self.config.btsp_gate_prob > 0.0 && rng.next_f64() < self.config.btsp_gate_prob { + self.config.btsp_boost + } else { + 1.0 + }; let lr = self.config.learning_rate; let growth = self.config.growth_factor; @@ -409,8 +397,7 @@ impl DNTree { .iter() .filter(|&&c| self.nodes[c].access_count > 0) .map(|&c| { - let sim = - partial_similarity(query, &self.summaries[c], self.config.partial_bits); + let sim = partial_similarity(query, &self.summaries[c], self.config.partial_bits); (c, sim) }) .collect(); @@ -420,8 +407,7 @@ impl DNTree { } // Sort by similarity (descending) - child_sims - .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + child_sims.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); // Early exit: if best child exceeds threshold, prune beam let best_sim = child_sims[0].1; @@ -592,11 +578,7 @@ mod tests { assert!(!is_zero(&tree.summaries[0])); let sim = similarity(&hv, &tree.summaries[0]); - assert!( - sim > 0.5, - "After 20 updates, summary should resemble input: sim={:.4}", - sim - ); + assert!(sim > 0.5, "After 20 updates, summary should resemble input: sim={:.4}", sim); } #[test] @@ -617,12 +599,7 @@ mod tests { tree.update(10, &hv, &mut rng); } - assert!( - tree.num_nodes() > initial_nodes, - "Expected split: {} -> {}", - initial_nodes, - tree.num_nodes() - ); + assert!(tree.num_nodes() > initial_nodes, "Expected split: {} -> {}", initial_nodes, tree.num_nodes()); } #[test] @@ -718,11 +695,7 @@ mod tests { } let sim = similarity(&hv, &tree.summaries[0]); - assert!( - sim > 0.5, - "BTSP boost should accelerate learning: sim={:.4}", - sim - ); + assert!(sim > 0.5, "BTSP boost should accelerate learning: sim={:.4}", sim); } #[test] diff --git a/src/hpc/fft.rs b/src/hpc/fft.rs index 30b9df99..fa3326cc 100644 --- a/src/hpc/fft.rs +++ b/src/hpc/fft.rs @@ -311,8 +311,7 @@ mod tests { let norm_before: f32 = 
data.iter().map(|x| x * x).sum::().sqrt(); wht_f32(&mut data); let norm_after: f32 = data.iter().map(|x| x * x).sum::().sqrt(); - assert!((norm_before - norm_after).abs() < 1e-4, - "energy: {} vs {}", norm_before, norm_after); + assert!((norm_before - norm_after).abs() < 1e-4, "energy: {} vs {}", norm_before, norm_after); } #[test] @@ -323,12 +322,14 @@ mod tests { // Norm preservation at 1024-d (hits SIMD path) let n_orig: f32 = original.iter().map(|x| x * x).sum::().sqrt(); let n_wht: f32 = data.iter().map(|x| x * x).sum::().sqrt(); - assert!((n_orig - n_wht).abs() / n_orig < 1e-4, - "SIMD WHT norm: {} vs {}", n_orig, n_wht); + assert!((n_orig - n_wht).abs() / n_orig < 1e-4, "SIMD WHT norm: {} vs {}", n_orig, n_wht); // Self-inverse wht_f32(&mut data); - let max_err = original.iter().zip(data.iter()) - .map(|(a, b)| (a - b).abs()).fold(0.0f32, f32::max); + let max_err = original + .iter() + .zip(data.iter()) + .map(|(a, b)| (a - b).abs()) + .fold(0.0f32, f32::max); assert!(max_err < 1e-3, "SIMD self-inverse max_err: {}", max_err); } @@ -341,5 +342,4 @@ mod tests { // DC component: sum = 10 assert!((output[0] - 10.0).abs() < 1e-4); } - } diff --git a/src/hpc/fingerprint.rs b/src/hpc/fingerprint.rs index 66679575..333f2047 100644 --- a/src/hpc/fingerprint.rs +++ b/src/hpc/fingerprint.rs @@ -39,9 +39,7 @@ impl Fingerprint { /// All-ones fingerprint. #[inline] pub fn ones() -> Self { - Self { - words: [u64::MAX; N], - } + Self { words: [u64::MAX; N] } } /// Create from a word array. @@ -52,12 +50,7 @@ impl Fingerprint { /// Create from a byte slice. Panics if `bytes.len() < N * 8`. pub fn from_bytes(bytes: &[u8]) -> Self { - assert!( - bytes.len() >= N * 8, - "need at least {} bytes, got {}", - N * 8, - bytes.len() - ); + assert!(bytes.len() >= N * 8, "need at least {} bytes, got {}", N * 8, bytes.len()); let mut words = [0u64; N]; for i in 0..N { let offset = i * 8; @@ -150,7 +143,9 @@ impl Fingerprint { #[inline] pub fn bind(&self, other: &Self) -> Self { let mut words = [0u64; N]; - for i in 0..N { words[i] = self.words[i] ^ other.words[i]; } + for i in 0..N { + words[i] = self.words[i] ^ other.words[i]; + } Self { words } } @@ -158,7 +153,9 @@ impl Fingerprint { #[inline] pub fn and(&self, other: &Self) -> Self { let mut words = [0u64; N]; - for i in 0..N { words[i] = self.words[i] & other.words[i]; } + for i in 0..N { + words[i] = self.words[i] & other.words[i]; + } Self { words } } @@ -166,7 +163,9 @@ impl Fingerprint { #[inline] pub fn not(&self) -> Self { let mut words = [0u64; N]; - for i in 0..N { words[i] = !self.words[i]; } + for i in 0..N { + words[i] = !self.words[i]; + } Self { words } } @@ -207,12 +206,15 @@ impl Fingerprint { /// half of the input fingerprints have it set. 
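// Toy sketch (not taken from the diff) of the strict-majority rule described
// above, on single-word fingerprints: a bit survives only when more than half
// of the inputs set it.
#[cfg(test)]
mod bundle_majority_sketch {
    use super::Fingerprint;

    #[test]
    fn two_of_three_votes_set_the_bit() {
        let a = Fingerprint::<1> { words: [0b011] };
        let b = Fingerprint::<1> { words: [0b101] };
        let c = Fingerprint::<1> { words: [0b110] };
        // Every bit is set in exactly two of the three inputs, so all three survive.
        assert_eq!(Fingerprint::bundle(&[&a, &b, &c]).words[0], 0b111);
    }
}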
pub fn bundle(items: &[&Self]) -> Self { let n = items.len(); - if n == 0 { return Self::zero(); } + if n == 0 { + return Self::zero(); + } let threshold = n / 2; let mut result = [0u64; N]; for w in 0..N { for bit in 0..64 { - let count: usize = items.iter() + let count: usize = items + .iter() .filter(|fp| (fp.words[w] >> bit) & 1 == 1) .count(); if count > threshold { @@ -233,7 +235,9 @@ impl Fingerprint { #[inline] pub fn or(&self, other: &Self) -> Self { let mut words = [0u64; N]; - for i in 0..N { words[i] = self.words[i] | other.words[i]; } + for i in 0..N { + words[i] = self.words[i] | other.words[i]; + } Self { words } } @@ -251,7 +255,9 @@ impl Fingerprint { pub fn permute(&self, positions: i32) -> Self { let total = Self::BITS as i32; let shift = ((positions % total) + total) % total; - if shift == 0 { return self.clone(); } + if shift == 0 { + return self.clone(); + } let mut result = Self::zero(); for i in 0..Self::BITS { if self.get_bit(i) { @@ -469,7 +475,12 @@ pub struct VectorConfig { impl VectorConfig { const fn from_width(w: VectorWidth) -> Self { let words = w as usize; - VectorConfig { width: w, words, bits: words * 64, bytes: words * 8 } + VectorConfig { + width: w, + words, + bits: words * 64, + bytes: words * 8, + } } } @@ -525,15 +536,9 @@ mod tests { #[test] fn test_xor_associative() { - let a = Fingerprint::<4> { - words: [1, 2, 3, 4], - }; - let b = Fingerprint::<4> { - words: [5, 6, 7, 8], - }; - let c = Fingerprint::<4> { - words: [9, 10, 11, 12], - }; + let a = Fingerprint::<4> { words: [1, 2, 3, 4] }; + let b = Fingerprint::<4> { words: [5, 6, 7, 8] }; + let c = Fingerprint::<4> { words: [9, 10, 11, 12] }; let ab_c = &(&a ^ &b) ^ &c; let a_bc = &a ^ &(&b ^ &c); assert_eq!(ab_c, a_bc); @@ -541,12 +546,8 @@ mod tests { #[test] fn test_hamming_distance() { - let a = Fingerprint::<2> { - words: [0xFF, 0x00], - }; - let b = Fingerprint::<2> { - words: [0x00, 0x00], - }; + let a = Fingerprint::<2> { words: [0xFF, 0x00] }; + let b = Fingerprint::<2> { words: [0x00, 0x00] }; assert_eq!(a.hamming_distance(&b), 8); } @@ -563,9 +564,7 @@ mod tests { let a = Fingerprint::<1> { words: [0xFF] }; assert_eq!(a.popcount(), 8); - let b = Fingerprint::<2> { - words: [0xFF, 0xFF], - }; + let b = Fingerprint::<2> { words: [0xFF, 0xFF] }; assert_eq!(b.popcount(), 16); } @@ -598,12 +597,8 @@ mod tests { #[test] fn test_xor_assign() { - let a = Fingerprint::<2> { - words: [0xFF, 0x00], - }; - let b = Fingerprint::<2> { - words: [0x0F, 0xF0], - }; + let a = Fingerprint::<2> { words: [0xFF, 0x00] }; + let b = Fingerprint::<2> { words: [0x0F, 0xF0] }; let mut c = a.clone(); c ^= &b; assert_eq!(c, &a ^ &b); diff --git a/src/hpc/framebuffer.rs b/src/hpc/framebuffer.rs index 69c60194..d90255e2 100644 --- a/src/hpc/framebuffer.rs +++ b/src/hpc/framebuffer.rs @@ -48,9 +48,9 @@ impl PaletteTier { /// Auto-detect from the active SIMD lane width. 
pub fn detect() -> Self { match PREFERRED_F32_LANES { - 16 => Self::Full16, // AVX-512 / AMX - 8 => Self::Mid8, // AVX2 - _ => Self::Low4, // NEON (4), scalar (≤4) + 16 => Self::Full16, // AVX-512 / AMX + 8 => Self::Mid8, // AVX2 + _ => Self::Low4, // NEON (4), scalar (≤4) } } @@ -189,10 +189,18 @@ impl Framebuffer { if x0 >= 0 && y0 >= 0 && (x0 as usize) < self.width && (y0 as usize) < self.height { self.pixels[y0 as usize * self.width + x0 as usize] = color; } - if x0 == x1 && y0 == y1 { break; } + if x0 == x1 && y0 == y1 { + break; + } let e2 = 2 * err; - if e2 >= dy { err += dy; x0 += sx; } - if e2 <= dx { err += dx; y0 += sy; } + if e2 >= dy { + err += dy; + x0 += sx; + } + if e2 <= dx { + err += dx; + y0 += sy; + } } let (lx, rx) = (x0.min(x1).max(0) as usize, (x0.max(x1) as usize + 1).min(self.width)); let (ly, ry) = (y0.min(y1).max(0) as usize, (y0.max(y1) as usize + 1).min(self.height)); @@ -291,9 +299,7 @@ pub fn build_mipmap_pyramid(fb: &Framebuffer, min_dim: usize) -> Vec<(Vec, u /// works; replace with perspective when q2 has a camera matrix. #[inline] pub fn project_ortho( - pos_x: f32, pos_y: f32, - scale: f32, offset_x: f32, offset_y: f32, - screen_w: usize, screen_h: usize, + pos_x: f32, pos_y: f32, scale: f32, offset_x: f32, offset_y: f32, screen_w: usize, screen_h: usize, ) -> (usize, usize) { let sx = ((pos_x * scale + offset_x) as usize).min(screen_w.saturating_sub(1)); let sy = ((pos_y * scale + offset_y) as usize).min(screen_h.saturating_sub(1)); @@ -307,13 +313,8 @@ use crate::hpc::renderer::RenderFrame; /// `edges` is a list of (source_idx, target_idx) pairs into the frame's /// node arrays. `color_fn` maps node index → palette color. pub fn compose_neo4j( - fb: &mut Framebuffer, - frame: &RenderFrame, - edges: &[(usize, usize)], - scale: f32, - offset: (f32, f32), - node_color: u8, - edge_color: u8, + fb: &mut Framebuffer, frame: &RenderFrame, edges: &[(usize, usize)], scale: f32, offset: (f32, f32), + node_color: u8, edge_color: u8, ) { fb.clear(); let w = fb.width; @@ -321,35 +322,26 @@ pub fn compose_neo4j( // Edges first (so nodes overdraw on top). for &(src, tgt) in edges { - if src >= frame.len || tgt >= frame.len { continue; } - let (sx0, sy0) = project_ortho( - frame.positions[src * 3], frame.positions[src * 3 + 1], - scale, offset.0, offset.1, w, h, - ); - let (sx1, sy1) = project_ortho( - frame.positions[tgt * 3], frame.positions[tgt * 3 + 1], - scale, offset.0, offset.1, w, h, - ); + if src >= frame.len || tgt >= frame.len { + continue; + } + let (sx0, sy0) = + project_ortho(frame.positions[src * 3], frame.positions[src * 3 + 1], scale, offset.0, offset.1, w, h); + let (sx1, sy1) = + project_ortho(frame.positions[tgt * 3], frame.positions[tgt * 3 + 1], scale, offset.0, offset.1, w, h); fb.draw_line(sx0 as i32, sy0 as i32, sx1 as i32, sy1 as i32, edge_color); } // Nodes as dot sprites. for i in 0..frame.len { - let (sx, sy) = project_ortho( - frame.positions[i * 3], frame.positions[i * 3 + 1], - scale, offset.0, offset.1, w, h, - ); + let (sx, sy) = + project_ortho(frame.positions[i * 3], frame.positions[i * 3 + 1], scale, offset.0, offset.1, w, h); fb.plot_dot(sx, sy, node_color); } } /// Compose an MRI density heatmap view. 
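// Worked example (not taken from the diff) of the clamping projection above:
// coordinates are scaled, offset, truncated to usize, then clamped to the last
// row/column.
#[cfg(test)]
mod project_ortho_sketch {
    use super::project_ortho;

    #[test]
    fn clamps_to_screen_bounds() {
        // (10 * 2 + 5, 4 * 2 + 5) = (25, 13) fits inside a 64×64 target,
        assert_eq!(project_ortho(10.0, 4.0, 2.0, 5.0, 5.0, 64, 64), (25, 13));
        // while 100 * 2 + 5 = 205 is clamped to width - 1 = 63.
        assert_eq!(project_ortho(100.0, 4.0, 2.0, 5.0, 5.0, 64, 64), (63, 13));
    }
}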
-pub fn compose_mri( - fb: &mut Framebuffer, - frame: &RenderFrame, - scale: f32, - offset: (f32, f32), -) { +pub fn compose_mri(fb: &mut Framebuffer, frame: &RenderFrame, scale: f32, offset: (f32, f32)) { fb.clear(); let w = fb.width; let h = fb.height; @@ -357,10 +349,8 @@ pub fn compose_mri( let mut xs = Vec::with_capacity(frame.len); let mut ys = Vec::with_capacity(frame.len); for i in 0..frame.len { - let (sx, sy) = project_ortho( - frame.positions[i * 3], frame.positions[i * 3 + 1], - scale, offset.0, offset.1, w, h, - ); + let (sx, sy) = + project_ortho(frame.positions[i * 3], frame.positions[i * 3 + 1], scale, offset.0, offset.1, w, h); xs.push(sx); ys.push(sy); } @@ -377,8 +367,8 @@ mod tests { let tier = PaletteTier::detect(); match PREFERRED_F32_LANES { 16 => assert_eq!(tier, PaletteTier::Full16), - 8 => assert_eq!(tier, PaletteTier::Mid8), - _ => assert_eq!(tier, PaletteTier::Low4), + 8 => assert_eq!(tier, PaletteTier::Mid8), + _ => assert_eq!(tier, PaletteTier::Low4), } } @@ -495,8 +485,12 @@ mod tests { let mut frame = RenderFrame::with_capacity(16); // Two nodes frame.len = 2; - frame.positions[0] = 10.0; frame.positions[1] = 10.0; frame.positions[2] = 0.0; - frame.positions[3] = 50.0; frame.positions[4] = 50.0; frame.positions[5] = 0.0; + frame.positions[0] = 10.0; + frame.positions[1] = 10.0; + frame.positions[2] = 0.0; + frame.positions[3] = 50.0; + frame.positions[4] = 50.0; + frame.positions[5] = 0.0; let edges = vec![(0, 1)]; compose_neo4j(&mut fb, &frame, &edges, 1.0, (0.0, 0.0), 5, 2); // Node 0 should have a dot around (10, 10). @@ -576,8 +570,8 @@ impl WobbleState { if speed > self.inject_threshold { // Perpendicular to velocity direction → organic wobble let norm = speed.recip(); - self.displace[i * 2] += -vy * norm * self.amplitude; - self.displace[i * 2 + 1] += vx * norm * self.amplitude; + self.displace[i * 2] += -vy * norm * self.amplitude; + self.displace[i * 2 + 1] += vx * norm * self.amplitude; } } // Decay all @@ -591,10 +585,7 @@ impl WobbleState { pub fn adjust(&self, sx: usize, sy: usize, node_idx: usize) -> (usize, usize) { let dx = self.displace.get(node_idx * 2).copied().unwrap_or(0.0); let dy = self.displace.get(node_idx * 2 + 1).copied().unwrap_or(0.0); - ( - (sx as f32 + dx).max(0.0) as usize, - (sy as f32 + dy).max(0.0) as usize, - ) + ((sx as f32 + dx).max(0.0) as usize, (sy as f32 + dy).max(0.0) as usize) } } @@ -665,7 +656,7 @@ pub type Glyph = [u8; 5]; /// Missing chars render as a filled block. pub static GLYPH_ATLAS: [Glyph; 128] = { let mut atlas = [[0x7Fu8; 5]; 128]; // default = filled block - // Space + // Space atlas[b' ' as usize] = [0, 0, 0, 0, 0]; // Digits 0-9 atlas[b'0' as usize] = [0x3E, 0x51, 0x49, 0x45, 0x3E]; @@ -787,14 +778,8 @@ impl FlybyCache { /// one full loop over `n_frames`. Scale determines the orbital radius /// in world units; zoom_range controls the min/max camera zoom. 
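// Hedged usage sketch (not taken from the diff) for `prerender` below, mirroring
// the visual tests further down: 8 keyframes orbiting at radius 10.0 with zoom
// sweeping between 0.5 and 2.0, node color 5, edge color 2. Assumes a freshly
// allocated RenderFrame has len 0, so the orbit is still sampled n_frames times
// even with nothing to draw.
#[cfg(test)]
mod flyby_prerender_sketch {
    use super::*;
    use crate::hpc::renderer::RenderFrame;

    #[test]
    fn produces_one_keyframe_per_requested_frame() {
        let fb_template = Framebuffer::with_tier(64, 64, PaletteTier::Full16);
        let frame = RenderFrame::with_capacity(16);
        let cache = FlybyCache::prerender(&fb_template, &frame, &[], 8, 10.0, (0.5, 2.0), 5, 2);
        assert_eq!(cache.len(), 8);
        assert!(!cache.is_empty());
    }
}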
pub fn prerender( - fb_template: &Framebuffer, - frame: &RenderFrame, - edges: &[(usize, usize)], - n_frames: usize, - orbit_radius: f32, - zoom_range: (f32, f32), - node_color: u8, - edge_color: u8, + fb_template: &Framebuffer, frame: &RenderFrame, edges: &[(usize, usize)], n_frames: usize, orbit_radius: f32, + zoom_range: (f32, f32), node_color: u8, edge_color: u8, ) -> Self { let mut frames = Vec::with_capacity(n_frames); let w = fb_template.width; @@ -812,15 +797,29 @@ impl FlybyCache { let mut fb = Framebuffer::with_tier(w, h, tier); compose_neo4j( - &mut fb, frame, edges, - cam_zoom, (-cam_x * cam_zoom + w as f32 / 2.0, - -cam_y * cam_zoom + h as f32 / 2.0), - node_color, edge_color, + &mut fb, + frame, + edges, + cam_zoom, + (-cam_x * cam_zoom + w as f32 / 2.0, -cam_y * cam_zoom + h as f32 / 2.0), + node_color, + edge_color, ); let (packed, bpp) = fb.pack(); - frames.push(FlybyFrame { packed, bpp, cam_x, cam_y, cam_zoom }); + frames.push(FlybyFrame { + packed, + bpp, + cam_x, + cam_y, + cam_zoom, + }); + } + Self { + frames, + cursor: 0, + width: w, + height: h, } - Self { frames, cursor: 0, width: w, height: h } } /// Advance the cursor and return the next keyframe (looping). @@ -854,10 +853,14 @@ impl FlybyCache { /// Frame count. /// Frame count. - pub fn len(&self) -> usize { self.frames.len() } + pub fn len(&self) -> usize { + self.frames.len() + } /// True when no keyframes have been pre-rendered. - pub fn is_empty(&self) -> bool { self.frames.is_empty() } + pub fn is_empty(&self) -> bool { + self.frames.is_empty() + } } // ───────────────────────────────────────────────────────────────────── @@ -866,17 +869,8 @@ impl FlybyCache { /// Full Neo4j-style compose with wobble, neuron fire, and labels. pub fn compose_neo4j_full( - fb: &mut Framebuffer, - frame: &RenderFrame, - edges: &[(usize, usize)], - scale: f32, - offset: (f32, f32), - wobble: &WobbleState, - fire: &FireState, - labels: &[&str], - node_base_color: u8, - edge_color: u8, - label_color: u8, + fb: &mut Framebuffer, frame: &RenderFrame, edges: &[(usize, usize)], scale: f32, offset: (f32, f32), + wobble: &WobbleState, fire: &FireState, labels: &[&str], node_base_color: u8, edge_color: u8, label_color: u8, ) { fb.clear(); let w = fb.width; @@ -885,15 +879,13 @@ pub fn compose_neo4j_full( // 1. Edges (drawn first so nodes overdraw). for &(src, tgt) in edges { - if src >= frame.len || tgt >= frame.len { continue; } - let (sx0, sy0) = project_ortho( - frame.positions[src * 3], frame.positions[src * 3 + 1], - scale, offset.0, offset.1, w, h, - ); - let (sx1, sy1) = project_ortho( - frame.positions[tgt * 3], frame.positions[tgt * 3 + 1], - scale, offset.0, offset.1, w, h, - ); + if src >= frame.len || tgt >= frame.len { + continue; + } + let (sx0, sy0) = + project_ortho(frame.positions[src * 3], frame.positions[src * 3 + 1], scale, offset.0, offset.1, w, h); + let (sx1, sy1) = + project_ortho(frame.positions[tgt * 3], frame.positions[tgt * 3 + 1], scale, offset.0, offset.1, w, h); let (wx0, wy0) = wobble.adjust(sx0, sy0, src); let (wx1, wy1) = wobble.adjust(sx1, sy1, tgt); fb.draw_line(wx0 as i32, wy0 as i32, wx1 as i32, wy1 as i32, edge_color); @@ -901,10 +893,8 @@ pub fn compose_neo4j_full( // 2. Nodes as dot sprites with fire boost. 
for i in 0..frame.len { - let (sx, sy) = project_ortho( - frame.positions[i * 3], frame.positions[i * 3 + 1], - scale, offset.0, offset.1, w, h, - ); + let (sx, sy) = + project_ortho(frame.positions[i * 3], frame.positions[i * 3 + 1], scale, offset.0, offset.1, w, h); let (wx, wy) = wobble.adjust(sx, sy, i); let boost = fire.color_boost(i, pal_max); let color = (node_base_color + boost).min(pal_max); @@ -913,11 +903,11 @@ pub fn compose_neo4j_full( // 3. Labels (drawn last so text is on top). for (i, &label) in labels.iter().enumerate().take(frame.len) { - if label.is_empty() { continue; } - let (sx, sy) = project_ortho( - frame.positions[i * 3], frame.positions[i * 3 + 1], - scale, offset.0, offset.1, w, h, - ); + if label.is_empty() { + continue; + } + let (sx, sy) = + project_ortho(frame.positions[i * 3], frame.positions[i * 3 + 1], scale, offset.0, offset.1, w, h); let (wx, wy) = wobble.adjust(sx, sy, i); let label_y = wy + fb.tier.sprite_size() / 2 + 1; fb.draw_label(wx.saturating_sub(label.len() * 3), label_y, label, label_color); @@ -985,13 +975,13 @@ mod visual_tests { fn flyby_cache_loops_seamlessly() { let mut frame = RenderFrame::with_capacity(16); frame.len = 2; - frame.positions[0] = 10.0; frame.positions[1] = 10.0; - frame.positions[3] = 20.0; frame.positions[4] = 20.0; + frame.positions[0] = 10.0; + frame.positions[1] = 10.0; + frame.positions[3] = 20.0; + frame.positions[4] = 20.0; let edges = vec![(0, 1)]; let fb_template = Framebuffer::with_tier(64, 64, PaletteTier::Full16); - let mut cache = FlybyCache::prerender( - &fb_template, &frame, &edges, 8, 10.0, (0.5, 2.0), 5, 2, - ); + let mut cache = FlybyCache::prerender(&fb_template, &frame, &edges, 8, 10.0, (0.5, 2.0), 5, 2); assert_eq!(cache.len(), 8); // Play through more than one loop — should not panic. for _ in 0..20 { @@ -1005,11 +995,10 @@ mod visual_tests { fn flyby_seek_nearest_finds_closest_frame() { let mut frame = RenderFrame::with_capacity(16); frame.len = 1; - frame.positions[0] = 32.0; frame.positions[1] = 32.0; + frame.positions[0] = 32.0; + frame.positions[1] = 32.0; let fb_template = Framebuffer::with_tier(64, 64, PaletteTier::Full16); - let mut cache = FlybyCache::prerender( - &fb_template, &frame, &[], 16, 10.0, (1.0, 1.0), 5, 2, - ); + let mut cache = FlybyCache::prerender(&fb_template, &frame, &[], 16, 10.0, (1.0, 1.0), 5, 2); cache.seek_nearest(32.0, 32.0); let f = &cache.frames[cache.cursor]; let dx = f.cam_x - 32.0; @@ -1022,17 +1011,16 @@ mod visual_tests { let mut fb = Framebuffer::with_tier(128, 128, PaletteTier::Full16); let mut frame = RenderFrame::with_capacity(16); frame.len = 2; - frame.positions[0] = 30.0; frame.positions[1] = 30.0; - frame.positions[3] = 90.0; frame.positions[4] = 90.0; + frame.positions[0] = 30.0; + frame.positions[1] = 30.0; + frame.positions[3] = 90.0; + frame.positions[4] = 90.0; let edges = vec![(0, 1)]; let wobble = WobbleState::new(16); let mut fire = FireState::new(16); fire.fire(0, 255); let labels = vec!["NODE0", "NODE1"]; - compose_neo4j_full( - &mut fb, &frame, &edges, 1.0, (0.0, 0.0), - &wobble, &fire, &labels, 3, 1, 7, - ); + compose_neo4j_full(&mut fb, &frame, &edges, 1.0, (0.0, 0.0), &wobble, &fire, &labels, 3, 1, 7); // Node 0 should be brighter (fire boost) than base color 3. let node0_pixel = fb.pixels[30 * 128 + 30]; assert!(node0_pixel >= 3, "node0 should have at least base color"); @@ -1059,11 +1047,7 @@ mod visual_tests { /// 3×3 box-blur diffusion: each pixel = average of itself + 8 neighbors. 
/// In-place via double buffer (src → dst, then swap pointers). /// Palette-safe: result is clamped to [0, max_palette]. -pub fn diffuse_step( - src: &[u8], dst: &mut [u8], - width: usize, height: usize, - max_palette: u8, -) { +pub fn diffuse_step(src: &[u8], dst: &mut [u8], width: usize, height: usize, max_palette: u8) { for y in 0..height { for x in 0..width { let mut sum: u16 = 0; @@ -1141,7 +1125,9 @@ impl PyramidShader { /// Inject heat at L1 coordinates (0..64, 0..64). pub fn inject(&mut self, x: usize, y: usize, intensity: u8) { if x < 64 && y < 64 { - self.l1[y * 64 + x] = self.l1[y * 64 + x].saturating_add(intensity).min(self.palette_max); + self.l1[y * 64 + x] = self.l1[y * 64 + x] + .saturating_add(intensity) + .min(self.palette_max); } } @@ -1169,19 +1155,19 @@ impl PyramidShader { // 2. Cascade: L1 upscales into L2, L2 into L3, L3 into L4. // Additive blend (saturating) so existing diffusion + upscaled signal combine. - let (up1, _, _) = upscale_2x(&self.l1, 64, 64); // 128² - let (up1b, _, _) = upscale_2x(&up1, 128, 128); // 256² + let (up1, _, _) = upscale_2x(&self.l1, 64, 64); // 128² + let (up1b, _, _) = upscale_2x(&up1, 128, 128); // 256² for (dst, src) in self.l2.iter_mut().zip(up1b.iter()) { *dst = dst.saturating_add(*src).min(self.palette_max); } - let (up2, _, _) = upscale_2x(&self.l2, 256, 256); // 512² - let (up2b, _, _) = upscale_2x(&up2, 512, 512); // 1024² + let (up2, _, _) = upscale_2x(&self.l2, 256, 256); // 512² + let (up2b, _, _) = upscale_2x(&up2, 512, 512); // 1024² for (dst, src) in self.l3.iter_mut().zip(up2b.iter()) { *dst = dst.saturating_add(*src).min(self.palette_max); } - let (up3, _, _) = upscale_2x(&self.l3, 1024, 1024); // 2048² + let (up3, _, _) = upscale_2x(&self.l3, 1024, 1024); // 2048² for (dst, src) in self.l4.iter_mut().zip(up3.iter()) { *dst = dst.saturating_add(*src).min(self.palette_max); } @@ -1223,10 +1209,8 @@ impl PyramidShader { /// Nearest-neighbor scale-blit from src (src_w × src_h) into a region /// of the framebuffer at (dst_x, dst_y) with size (dst_w × dst_h). fn blit_scaled( - src: &[u8], src_w: usize, src_h: usize, - fb: &mut Framebuffer, - dst_x: usize, dst_y: usize, - dst_w: usize, dst_h: usize, + src: &[u8], src_w: usize, src_h: usize, fb: &mut Framebuffer, dst_x: usize, dst_y: usize, dst_w: usize, + dst_h: usize, ) { for dy in 0..dst_h { let sy = (dy * src_h) / dst_h; @@ -1306,10 +1290,10 @@ mod pyramid_tests { // Center should have decreased (averaged with zero neighbors). assert!(dst[8 * 16 + 8] < 15); // At least one neighbor should be nonzero. 
- let neighbor_sum: u16 = [ - dst[7 * 16 + 8], dst[9 * 16 + 8], - dst[8 * 16 + 7], dst[8 * 16 + 9], - ].iter().map(|&v| v as u16).sum(); + let neighbor_sum: u16 = [dst[7 * 16 + 8], dst[9 * 16 + 8], dst[8 * 16 + 7], dst[8 * 16 + 9]] + .iter() + .map(|&v| v as u16) + .sum(); assert!(neighbor_sum > 0, "diffusion should spread to neighbors"); } } diff --git a/src/hpc/gguf.rs b/src/hpc/gguf.rs index d64a4798..9644e8a7 100644 --- a/src/hpc/gguf.rs +++ b/src/hpc/gguf.rs @@ -90,10 +90,10 @@ impl GgmlType { Self::F32 => 4, Self::F16 | Self::BF16 => 2, Self::F64 => 8, - Self::Q4_0 => 18, // 2 (scale) + 32/2 (nibbles) = 18 - Self::Q4_1 => 20, // 2 (scale) + 2 (min) + 32/2 = 20 - Self::Q8_0 => 34, // 2 (scale) + 32 (int8s) = 34 - Self::Q4_K => 144, // Complex block structure + Self::Q4_0 => 18, // 2 (scale) + 32/2 (nibbles) = 18 + Self::Q4_1 => 20, // 2 (scale) + 2 (min) + 32/2 = 20 + Self::Q8_0 => 34, // 2 (scale) + 32 (int8s) = 34 + Self::Q4_K => 144, // Complex block structure _ => 0, } } @@ -168,7 +168,12 @@ pub fn read_gguf_header(reader: &mut R) -> Result(reader: &mut R) -> Result( - reader: &mut R, - gguf: &GgufFile, - tensor: &TensorInfo, + reader: &mut R, gguf: &GgufFile, tensor: &TensorInfo, ) -> Result, String> { let abs_offset = gguf.tensor_data_offset + tensor.offset; - reader.seek(SeekFrom::Start(abs_offset)).map_err(|e| e.to_string())?; + reader + .seek(SeekFrom::Start(abs_offset)) + .map_err(|e| e.to_string())?; let n_elements = tensor.element_count() as usize; @@ -199,14 +204,16 @@ pub fn read_tensor_f32( GgmlType::F32 => { let mut buf = vec![0u8; n_elements * 4]; reader.read_exact(&mut buf).map_err(|e| e.to_string())?; - Ok(buf.chunks_exact(4) + Ok(buf + .chunks_exact(4) .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])) .collect()) } GgmlType::F16 => { let mut buf = vec![0u8; n_elements * 2]; reader.read_exact(&mut buf).map_err(|e| e.to_string())?; - Ok(buf.chunks_exact(2) + Ok(buf + .chunks_exact(2) .map(|c| { let bits = u16::from_le_bytes([c[0], c[1]]); f16_to_f32(bits) @@ -218,25 +225,15 @@ pub fn read_tensor_f32( reader.read_exact(&mut buf).map_err(|e| e.to_string())?; // Reinterpret u8 pairs as BF16 (same repr) and batch-convert via quantized.rs // SAFETY: BF16 is #[repr(transparent)] over u16, same layout as [u8; 2] LE pairs. - let bf16_slice: &[super::quantized::BF16] = unsafe { - std::slice::from_raw_parts( - buf.as_ptr() as *const super::quantized::BF16, - n_elements, - ) - }; + let bf16_slice: &[super::quantized::BF16] = + unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const super::quantized::BF16, n_elements) }; let mut result = vec![0.0f32; n_elements]; super::quantized::bf16_to_f32_slice(bf16_slice, &mut result); Ok(result) } - GgmlType::Q8_0 => { - dequantize_q8_0(reader, n_elements) - } - GgmlType::Q4_0 => { - dequantize_q4_0(reader, n_elements) - } - GgmlType::Q4_K => { - dequantize_q4_k(reader, n_elements) - } + GgmlType::Q8_0 => dequantize_q8_0(reader, n_elements), + GgmlType::Q4_0 => dequantize_q4_0(reader, n_elements), + GgmlType::Q4_K => dequantize_q4_k(reader, n_elements), other => Err(format!("Unsupported dtype for dequantization: {:?}", other)), } } @@ -248,7 +245,8 @@ pub fn find_tensor<'a>(gguf: &'a GgufFile, pattern: &str) -> Option<&'a TensorIn /// List all tensor names and shapes. 
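// Worked size check (not taken from the diff) for the Q8_0 block layout above:
// 32 elements cost 34 bytes (2-byte scale + 32 int8s), the same `... * 34 / 32`
// expression the header test below uses for the second tensor offset.
#[cfg(test)]
mod q8_0_size_sketch {
    #[test]
    fn q8_0_bytes_for_a_4096_square_tensor() {
        let elements: u64 = 4096 * 4096;
        // 16_777_216 elements → 524_288 blocks → 17_825_792 bytes (~3.8× smaller than f32).
        assert_eq!(elements / 32 * 34, 17_825_792);
        assert_eq!(elements * 4, 67_108_864);
    }
}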
pub fn list_tensors(gguf: &GgufFile) -> Vec<(String, Vec, GgmlType)> { - gguf.tensors.iter() + gguf.tensors + .iter() .map(|t| (t.name.clone(), t.dimensions.clone(), t.dtype)) .collect() } @@ -279,16 +277,44 @@ fn read_string(r: &mut R) -> Result { fn read_metadata_value(r: &mut R, value_type: u32) -> Result { match value_type { - 0 => { let mut b = [0u8; 1]; r.read_exact(&mut b).map_err(|e| e.to_string())?; Ok(b[0].to_string()) } // u8 - 1 => { let mut b = [0u8; 1]; r.read_exact(&mut b).map_err(|e| e.to_string())?; Ok((b[0] as i8).to_string()) } // i8 - 2 => { let mut b = [0u8; 2]; r.read_exact(&mut b).map_err(|e| e.to_string())?; Ok(u16::from_le_bytes(b).to_string()) } // u16 - 3 => { let mut b = [0u8; 2]; r.read_exact(&mut b).map_err(|e| e.to_string())?; Ok(i16::from_le_bytes(b).to_string()) } // i16 + 0 => { + let mut b = [0u8; 1]; + r.read_exact(&mut b).map_err(|e| e.to_string())?; + Ok(b[0].to_string()) + } // u8 + 1 => { + let mut b = [0u8; 1]; + r.read_exact(&mut b).map_err(|e| e.to_string())?; + Ok((b[0] as i8).to_string()) + } // i8 + 2 => { + let mut b = [0u8; 2]; + r.read_exact(&mut b).map_err(|e| e.to_string())?; + Ok(u16::from_le_bytes(b).to_string()) + } // u16 + 3 => { + let mut b = [0u8; 2]; + r.read_exact(&mut b).map_err(|e| e.to_string())?; + Ok(i16::from_le_bytes(b).to_string()) + } // i16 4 => Ok(read_u32(r)?.to_string()), // u32 - 5 => { let v = read_u32(r)?; Ok((v as i32).to_string()) } // i32 - 6 => { let mut b = [0u8; 4]; r.read_exact(&mut b).map_err(|e| e.to_string())?; Ok(f32::from_le_bytes(b).to_string()) } // f32 - 7 => { let mut b = [0u8; 1]; r.read_exact(&mut b).map_err(|e| e.to_string())?; Ok((b[0] != 0).to_string()) } // bool - 8 => read_string(r), // string - 9 => { // array + 5 => { + let v = read_u32(r)?; + Ok((v as i32).to_string()) + } // i32 + 6 => { + let mut b = [0u8; 4]; + r.read_exact(&mut b).map_err(|e| e.to_string())?; + Ok(f32::from_le_bytes(b).to_string()) + } // f32 + 7 => { + let mut b = [0u8; 1]; + r.read_exact(&mut b).map_err(|e| e.to_string())?; + Ok((b[0] != 0).to_string()) + } // bool + 8 => read_string(r), // string + 9 => { + // array let elem_type = read_u32(r)?; let count = read_u64(r)?; // Skip array elements (we don't need them for tensor loading) @@ -298,8 +324,15 @@ fn read_metadata_value(r: &mut R, value_type: u32) -> Result Ok(read_u64(r)?.to_string()), // u64 - 11 => { let v = read_u64(r)?; Ok((v as i64).to_string()) } // i64 - 12 => { let mut b = [0u8; 8]; r.read_exact(&mut b).map_err(|e| e.to_string())?; Ok(f64::from_le_bytes(b).to_string()) } // f64 + 11 => { + let v = read_u64(r)?; + Ok((v as i64).to_string()) + } // i64 + 12 => { + let mut b = [0u8; 8]; + r.read_exact(&mut b).map_err(|e| e.to_string())?; + Ok(f64::from_le_bytes(b).to_string()) + } // f64 _ => Err(format!("Unknown metadata value type: {}", value_type)), } } @@ -571,7 +604,9 @@ mod tests { append_tensor_info(&mut buf, "blk.0.attn_q.weight", &[4096, 4096], 8, 0); append_tensor_info(&mut buf, "blk.0.attn_k.weight", &[4096, 1024], 8, 4096 * 4096 * 34 / 32); - while buf.len() % 32 != 0 { buf.push(0); } + while buf.len() % 32 != 0 { + buf.push(0); + } let mut cursor = Cursor::new(&buf); let gguf = read_gguf_header(&mut cursor).unwrap(); diff --git a/src/hpc/gguf_indexer.rs b/src/hpc/gguf_indexer.rs index e3edec99..a453cf4d 100644 --- a/src/hpc/gguf_indexer.rs +++ b/src/hpc/gguf_indexer.rs @@ -19,7 +19,7 @@ //! Supports: F32, F16, BF16, Q8_0, Q4_0, Q4_K (via gguf.rs dequant). 
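// Hedged format note (not taken from the diff), assembled from the writer calls
// later in this file: the index starts with the magic b"BGZ7" and a little-endian
// u32 tensor count, followed by one record per tensor of
// [name_len: u32][name bytes][layer_type: u8][n_rows: u32][n_cols: u32][Base17 × n_rows].
#[cfg(test)]
mod bgz17_header_sketch {
    #[test]
    fn magic_and_count_take_eight_bytes() {
        let mut header = Vec::new();
        header.extend_from_slice(b"BGZ7");
        header.extend_from_slice(&3u32.to_le_bytes()); // e.g. three tensors follow
        assert_eq!(header.len(), 8);
    }
}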
use super::bgz17_bridge::Base17; -use super::gguf::{self, GgufFile, TensorInfo, GgmlType}; +use super::gguf::{self, GgmlType, GgufFile, TensorInfo}; use std::io::{Read, Seek, SeekFrom, Write}; // ============================================================================ @@ -69,19 +69,30 @@ pub fn classify_tensor(name: &str, dims: &[u64]) -> LayerType { } // Attention projections - if name.contains("attn") || name.contains("self_attn") - || name.contains("attn_q") || name.contains("attn_k") - || name.contains("attn_v") || name.contains("attn_output") - || name.contains("q_proj") || name.contains("k_proj") - || name.contains("v_proj") || name.contains("o_proj") + if name.contains("attn") + || name.contains("self_attn") + || name.contains("attn_q") + || name.contains("attn_k") + || name.contains("attn_v") + || name.contains("attn_output") + || name.contains("q_proj") + || name.contains("k_proj") + || name.contains("v_proj") + || name.contains("o_proj") { return LayerType::Attention; } // Feed-forward - if name.contains("ffn") || name.contains("mlp") || name.contains("fc1") - || name.contains("fc2") || name.contains("gate") || name.contains("up_proj") - || name.contains("down_proj") || name.contains("w1") || name.contains("w2") + if name.contains("ffn") + || name.contains("mlp") + || name.contains("fc1") + || name.contains("fc2") + || name.contains("gate") + || name.contains("up_proj") + || name.contains("down_proj") + || name.contains("w1") + || name.contains("w2") || name.contains("w3") { return LayerType::FeedForward; @@ -210,10 +221,7 @@ fn gather_bf16_x8(buf: &[u16], offsets: &[usize; 8]) -> crate::simd::F64x8 { /// /// Memory: 17 × F64x8 accumulators on stack = 17 × 64 = 1088 bytes. pub fn project_8rows_bf16_simd( - buf: &[u16], - row_starts: &[usize; 8], - n_cols: usize, - octave_stride: usize, + buf: &[u16], row_starts: &[usize; 8], n_cols: usize, octave_stride: usize, ) -> [Base17; 8] { use crate::simd::F64x8; @@ -230,10 +238,14 @@ pub fn project_8rows_bf16_simd( let col = octave * BASE_DIM + GOLDEN_POS[bi] as usize; if col < n_cols { let offsets: [usize; 8] = [ - row_starts[0] + col, row_starts[1] + col, - row_starts[2] + col, row_starts[3] + col, - row_starts[4] + col, row_starts[5] + col, - row_starts[6] + col, row_starts[7] + col, + row_starts[0] + col, + row_starts[1] + col, + row_starts[2] + col, + row_starts[3] + col, + row_starts[4] + col, + row_starts[5] + col, + row_starts[6] + col, + row_starts[7] + col, ]; sums[bi] += gather_bf16_x8(buf, &offsets); counts[bi] += 1; @@ -259,10 +271,14 @@ pub fn project_8rows_bf16_simd( } [ - Base17 { dims: dims_x8[0] }, Base17 { dims: dims_x8[1] }, - Base17 { dims: dims_x8[2] }, Base17 { dims: dims_x8[3] }, - Base17 { dims: dims_x8[4] }, Base17 { dims: dims_x8[5] }, - Base17 { dims: dims_x8[6] }, Base17 { dims: dims_x8[7] }, + Base17 { dims: dims_x8[0] }, + Base17 { dims: dims_x8[1] }, + Base17 { dims: dims_x8[2] }, + Base17 { dims: dims_x8[3] }, + Base17 { dims: dims_x8[4] }, + Base17 { dims: dims_x8[5] }, + Base17 { dims: dims_x8[6] }, + Base17 { dims: dims_x8[7] }, ] } @@ -304,12 +320,7 @@ pub fn project_1row_bf16_strided(row: &[u16], octave_stride: usize) -> Base17 { /// /// Per sampled octave: 17 positions × 8 bf16_to_f64 gathers → 17 vaddpd. /// For 5120-col rows at stride=16: 19 octaves × 17 = 323 vaddpd per 8-row batch. 
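// Arithmetic behind the 323 figure above (one plausible reading of the sampling
// loop in `project_8rows_bf16_simd`, not taken from the diff): each octave spans
// BASE_DIM = 17 columns, so a 5120-column row holds roughly 301 octaves; stepping
// by octave_stride = 16 visits octaves 0, 16, ..., 288, i.e. 19 of them, and
// 19 octaves × 17 golden positions = 323 accumulations per 8-row batch.
#[cfg(test)]
mod octave_count_sketch {
    #[test]
    fn stride_16_samples_19_octaves_of_a_5120_column_row() {
        let (n_cols, base_dim, stride) = (5120usize, 17usize, 16usize);
        let mut sampled = 0usize;
        let mut octave = 0usize;
        while octave * base_dim < n_cols {
            sampled += 1;
            octave += stride;
        }
        assert_eq!(sampled, 19);
        assert_eq!(sampled * base_dim, 323);
    }
}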
-pub fn project_tensor_bf16_simd( - buf: &[u16], - n_rows: usize, - n_cols: usize, - octave_stride: usize, -) -> Vec { +pub fn project_tensor_bf16_simd(buf: &[u16], n_rows: usize, n_cols: usize, octave_stride: usize) -> Vec { let mut result = Vec::with_capacity(n_rows); let full_batches = n_rows / 8; @@ -317,10 +328,14 @@ pub fn project_tensor_bf16_simd( for batch in 0..full_batches { let base_row = batch * 8; let row_starts: [usize; 8] = [ - (base_row + 0) * n_cols, (base_row + 1) * n_cols, - (base_row + 2) * n_cols, (base_row + 3) * n_cols, - (base_row + 4) * n_cols, (base_row + 5) * n_cols, - (base_row + 6) * n_cols, (base_row + 7) * n_cols, + (base_row + 0) * n_cols, + (base_row + 1) * n_cols, + (base_row + 2) * n_cols, + (base_row + 3) * n_cols, + (base_row + 4) * n_cols, + (base_row + 5) * n_cols, + (base_row + 6) * n_cols, + (base_row + 7) * n_cols, ]; let b17s = project_8rows_bf16_simd(buf, &row_starts, n_cols, octave_stride); result.extend_from_slice(&b17s); @@ -339,9 +354,7 @@ pub fn project_tensor_bf16_simd( /// Helper: tensor dimensions → (rows, cols) without needing data. fn tensor_to_rows_dims(dims: &[u64], layer_type: &LayerType) -> (usize, usize) { match layer_type { - LayerType::Conv2D if dims.len() == 4 => { - (dims[0] as usize, (dims[1] * dims[2] * dims[3]) as usize) - } + LayerType::Conv2D if dims.len() == 4 => (dims[0] as usize, (dims[1] * dims[2] * dims[3]) as usize), _ if dims.len() >= 2 => { let rows = dims[0] as usize; let cols: usize = dims[1..].iter().map(|&d| d as usize).product(); @@ -376,10 +389,7 @@ fn layer_type_index(lt: &LayerType) -> usize { /// /// Falls back to f32 path for non-BF16 dtypes. pub fn stream_index_gguf_bf16( - reader: &mut R, - writer: &mut W, - octave_stride: usize, - callback: Option<&dyn Fn(&str, &LayerType, usize, usize)>, + reader: &mut R, writer: &mut W, octave_stride: usize, callback: Option<&dyn Fn(&str, &LayerType, usize, usize)>, ) -> Result { let header = gguf::read_gguf_header(reader)?; stream_index_gguf_bf16_with_header(reader, writer, &header, octave_stride, callback) @@ -391,17 +401,16 @@ pub fn stream_index_gguf_bf16( /// - `tensor_data_offset`: absolute byte offset where tensor data starts /// - `tensors`: Vec with name, dimensions, dtype, offset (relative to data start) pub fn stream_index_gguf_bf16_with_header( - reader: &mut R, - writer: &mut W, - header: &gguf::GgufFile, - octave_stride: usize, + reader: &mut R, writer: &mut W, header: &gguf::GgufFile, octave_stride: usize, callback: Option<&dyn Fn(&str, &LayerType, usize, usize)>, ) -> Result { let mut stats = IndexStats::default(); stats.tensors_total = header.tensors.len(); writer.write_all(b"BGZ7").map_err(|e| e.to_string())?; - writer.write_all(&(header.tensors.len() as u32).to_le_bytes()).map_err(|e| e.to_string())?; + writer + .write_all(&(header.tensors.len() as u32).to_le_bytes()) + .map_err(|e| e.to_string())?; // Reusable buffer — capped at 128 MB (64M u16 elements). // Tensors larger than this are read in row batches. 
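// Worked note (not taken from the diff) on the cap above, assuming the
// `MAX_BUF_ELEMS` constant referenced below is 64 × 1024 × 1024 elements ("64M"):
// at 2 bytes per BF16 element that is 128 MiB, so e.g. a 5120 × 13824 expert
// tensor (70_778_880 elements) is too large for one read and falls back to
// row-batched streaming.
#[cfg(test)]
mod buffer_cap_sketch {
    #[test]
    fn large_expert_tensors_exceed_the_reuse_buffer() {
        const ASSUMED_MAX_BUF_ELEMS: usize = 64 * 1024 * 1024; // hypothetical value
        assert_eq!(ASSUMED_MAX_BUF_ELEMS * 2, 128 * 1024 * 1024);
        assert!(5120usize * 13824 > ASSUMED_MAX_BUF_ELEMS);
    }
}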
@@ -437,7 +446,9 @@ pub fn stream_index_gguf_bf16_with_header( // Seek to tensor start let abs_offset = header.tensor_data_offset + tensor.offset; - reader.seek(std::io::SeekFrom::Start(abs_offset)).map_err(|e| e.to_string())?; + reader + .seek(std::io::SeekFrom::Start(abs_offset)) + .map_err(|e| e.to_string())?; let mut rows: Vec = Vec::with_capacity(n_rows); let mut rows_done: usize = 0; @@ -448,19 +459,13 @@ pub fn stream_index_gguf_bf16_with_header( let batch_elems = batch_n * n_cols; // Read batch bytes into reusable buffer - let byte_slice = unsafe { - std::slice::from_raw_parts_mut( - bf16_buf.as_mut_ptr() as *mut u8, - batch_elems * 2, - ) - }; + let byte_slice = + unsafe { std::slice::from_raw_parts_mut(bf16_buf.as_mut_ptr() as *mut u8, batch_elems * 2) }; reader.read_exact(byte_slice).map_err(|e| e.to_string())?; // Project this batch if octave_stride > 1 { - let batch_b17 = project_tensor_bf16_simd( - &bf16_buf[..batch_elems], batch_n, n_cols, octave_stride - ); + let batch_b17 = project_tensor_bf16_simd(&bf16_buf[..batch_elems], batch_n, n_cols, octave_stride); rows.extend_from_slice(&batch_b17); } else { for r in 0..batch_n { @@ -473,8 +478,12 @@ pub fn stream_index_gguf_bf16_with_header( // Progress for large tensors (every chunk) if is_large && rows_done < n_rows { - eprintln!(" ... {}/{} rows ({:.0}%)", - rows_done, n_rows, rows_done as f64 / n_rows as f64 * 100.0); + eprintln!( + " ... {}/{} rows ({:.0}%)", + rows_done, + n_rows, + rows_done as f64 / n_rows as f64 * 100.0 + ); } } @@ -500,7 +509,9 @@ pub fn stream_index_gguf_bf16_with_header( stats.tensors_indexed += 1; let buf_bytes = chunk_elems as u64 * 2; - if buf_bytes > stats.peak_tensor_bytes { stats.peak_tensor_bytes = buf_bytes; } + if buf_bytes > stats.peak_tensor_bytes { + stats.peak_tensor_bytes = buf_bytes; + } // Shrink buffer if it grew past the cap (shouldn't, but defensive) if bf16_buf.len() > MAX_BUF_ELEMS { @@ -588,14 +599,17 @@ impl CompressedTensor { /// Compression ratio. 
pub fn ratio(&self) -> f64 { - if self.compressed_bytes() == 0 { return 0.0; } + if self.compressed_bytes() == 0 { + return 0.0; + } self.original_bytes() as f64 / self.compressed_bytes() as f64 } /// Serialize to bytes: [name_len:u32][name][layer_type:u8][n_rows:u32][n_cols:u32][base17 × n_rows] pub fn write_to(&self, w: &mut W) -> Result<(), String> { let name_bytes = self.name.as_bytes(); - w.write_all(&(name_bytes.len() as u32).to_le_bytes()).map_err(|e| e.to_string())?; + w.write_all(&(name_bytes.len() as u32).to_le_bytes()) + .map_err(|e| e.to_string())?; w.write_all(name_bytes).map_err(|e| e.to_string())?; let lt_byte: u8 = match self.layer_type { @@ -607,8 +621,10 @@ impl CompressedTensor { LayerType::Skip => 5, }; w.write_all(&[lt_byte]).map_err(|e| e.to_string())?; - w.write_all(&(self.n_rows as u32).to_le_bytes()).map_err(|e| e.to_string())?; - w.write_all(&(self.n_cols as u32).to_le_bytes()).map_err(|e| e.to_string())?; + w.write_all(&(self.n_rows as u32).to_le_bytes()) + .map_err(|e| e.to_string())?; + w.write_all(&(self.n_cols as u32).to_le_bytes()) + .map_err(|e| e.to_string())?; for b17 in &self.rows { w.write_all(&b17.to_bytes()).map_err(|e| e.to_string())?; @@ -704,9 +720,7 @@ fn tensor_to_rows(data: &[f32], dims: &[u64], layer_type: &LayerType) -> (usize, let cols: usize = dims[1..].iter().map(|&d| d as usize).product(); (rows, cols) } - _ => { - (1, data.len()) - } + _ => (1, data.len()), } } @@ -728,7 +742,9 @@ pub struct IndexStats { impl IndexStats { pub fn overall_ratio(&self) -> f64 { - if self.compressed_bytes == 0 { return 0.0; } + if self.compressed_bytes == 0 { + return 0.0; + } self.original_bytes as f64 / self.compressed_bytes as f64 } } @@ -739,9 +755,7 @@ impl IndexStats { /// For Llama 4 Scout: largest expert = 5120 × 13824 × 4 = ~270 MB. /// Total RAM: ~300 MB regardless of model size. 
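// Arithmetic check (not taken from the diff) on the figure above:
// 5120 × 13824 × 4 bytes = 283_115_520 bytes = 270 MiB exactly, the "~270 MB"
// quoted for the largest expert.
#[cfg(test)]
mod scout_expert_size_sketch {
    #[test]
    fn largest_expert_is_270_mib_as_f32() {
        let bytes: u64 = 5120 * 13824 * 4;
        assert_eq!(bytes, 283_115_520);
        assert_eq!(bytes / (1024 * 1024), 270);
    }
}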
pub fn stream_index_gguf( - reader: &mut R, - writer: &mut W, - callback: Option<&dyn Fn(&str, &LayerType, usize, usize)>, + reader: &mut R, writer: &mut W, callback: Option<&dyn Fn(&str, &LayerType, usize, usize)>, ) -> Result { let gguf = gguf::read_gguf_header(reader)?; let mut stats = IndexStats::default(); @@ -749,7 +763,9 @@ pub fn stream_index_gguf( // Write file header: magic + tensor count writer.write_all(b"BGZ7").map_err(|e| e.to_string())?; - writer.write_all(&(gguf.tensors.len() as u32).to_le_bytes()).map_err(|e| e.to_string())?; + writer + .write_all(&(gguf.tensors.len() as u32).to_le_bytes()) + .map_err(|e| e.to_string())?; for tensor in &gguf.tensors { let layer_type = classify_tensor(&tensor.name, &tensor.dimensions); @@ -970,7 +986,9 @@ mod tests { buf.extend_from_slice(&(t2_offset as u64).to_le_bytes()); // Pad to alignment (32 bytes) - while buf.len() % 32 != 0 { buf.push(0); } + while buf.len() % 32 != 0 { + buf.push(0); + } // Tensor 1 data: 64×64 f32 for i in 0..(64 * 64) { @@ -1006,7 +1024,10 @@ mod tests { let path = "/tmp/openchat/openchat-3.5-0106.Q8_0.gguf"; let file = match std::fs::File::open(path) { Ok(f) => f, - Err(_) => { eprintln!("SKIP: {} not found", path); return; } + Err(_) => { + eprintln!("SKIP: {} not found", path); + return; + } }; let input_size = file.metadata().map(|m| m.len()).unwrap_or(0); let mut reader = BufReader::new(file); @@ -1020,10 +1041,10 @@ mod tests { &mut writer, Some(&|name, layer_type, orig, comp| { let ratio = if comp > 0 { orig as f64 / comp as f64 } else { 0.0 }; - eprintln!(" {:50} {:12?} {:>10} → {:>8} ({:.0}×)", - name, layer_type, orig, comp, ratio); + eprintln!(" {:50} {:12?} {:>10} → {:>8} ({:.0}×)", name, layer_type, orig, comp, ratio); }), - ).expect("stream_index_gguf"); + ) + .expect("stream_index_gguf"); drop(writer); let out_size = std::fs::metadata(out_path).map(|m| m.len()).unwrap_or(0); @@ -1032,8 +1053,10 @@ mod tests { eprintln!("=== OpenChat 3.5 Q8_0 → bgz17 Results ==="); eprintln!(" Input: {:.2} GB ({})", input_size as f64 / 1e9, path); eprintln!(" Output: {:.2} MB ({})", out_size as f64 / 1e6, out_path); - eprintln!(" Tensors: {} total, {} indexed, {} skipped", - stats.tensors_total, stats.tensors_indexed, stats.tensors_skipped); + eprintln!( + " Tensors: {} total, {} indexed, {} skipped", + stats.tensors_total, stats.tensors_indexed, stats.tensors_skipped + ); eprintln!(" Original (f32): {:.2} MB", stats.original_bytes as f64 / 1e6); eprintln!(" Compressed: {:.2} MB", stats.compressed_bytes as f64 / 1e6); eprintln!(" Overall ratio: {:.1}×", stats.overall_ratio()); @@ -1045,8 +1068,14 @@ mod tests { let (count, orig, comp) = stats.by_type[i]; if count > 0 { let ratio = if comp > 0 { orig as f64 / comp as f64 } else { 0.0 }; - eprintln!(" {:<12} {:>3} tensors: {:>10.2} MB → {:>8.2} MB ({:.1}×)", - name, count, orig as f64 / 1e6, comp as f64 / 1e6, ratio); + eprintln!( + " {:<12} {:>3} tensors: {:>10.2} MB → {:>8.2} MB ({:.1}×)", + name, + count, + orig as f64 / 1e6, + comp as f64 / 1e6, + ratio + ); } } @@ -1057,7 +1086,7 @@ mod tests { #[test] #[ignore] // Streams from HuggingFace — requires network + time fn test_stream_index_llama4_scout_from_hf() { - use super::super::http_reader::{HttpRangeReader, resolve_hf_url}; + use super::super::http_reader::{resolve_hf_url, HttpRangeReader}; use std::io::BufWriter; let repo = "unsloth/Llama-4-Scout-17B-16E-Instruct-GGUF"; @@ -1066,7 +1095,10 @@ mod tests { eprintln!("Resolving {} / {} ...", repo, filename); let (url, size) = match resolve_hf_url(repo, filename) 
{ Ok(r) => r, - Err(e) => { eprintln!("SKIP: {}", e); return; } + Err(e) => { + eprintln!("SKIP: {}", e); + return; + } }; eprintln!(" URL resolved, size: {:.2} GB", size as f64 / 1e9); @@ -1082,10 +1114,10 @@ mod tests { &mut writer, Some(&|name, layer_type, orig, comp| { let ratio = if comp > 0 { orig as f64 / comp as f64 } else { 0.0 }; - eprintln!(" {:60} {:12?} {:>12} → {:>8} ({:.0}×)", - name, layer_type, orig, comp, ratio); + eprintln!(" {:60} {:12?} {:>12} → {:>8} ({:.0}×)", name, layer_type, orig, comp, ratio); }), - ).expect("stream_index_gguf"); + ) + .expect("stream_index_gguf"); drop(writer); let out_size = std::fs::metadata(out_path).map(|m| m.len()).unwrap_or(0); @@ -1095,8 +1127,7 @@ mod tests { eprintln!(" Source: {:.2} GB ({})", size as f64 / 1e9, filename); eprintln!(" Output: {:.2} MB ({})", out_size as f64 / 1e6, out_path); eprintln!(" Downloaded: {:.2} GB", reader.bytes_downloaded() as f64 / 1e9); - eprintln!(" Tensors: {} indexed, {} skipped", - stats.tensors_indexed, stats.tensors_skipped); + eprintln!(" Tensors: {} indexed, {} skipped", stats.tensors_indexed, stats.tensors_skipped); eprintln!(" Original (f32): {:.2} GB", stats.original_bytes as f64 / 1e9); eprintln!(" Compressed: {:.2} MB", stats.compressed_bytes as f64 / 1e6); eprintln!(" Ratio: {:.1}×", stats.overall_ratio()); @@ -1107,8 +1138,14 @@ mod tests { let (count, orig, comp) = stats.by_type[i]; if count > 0 { let ratio = if comp > 0 { orig as f64 / comp as f64 } else { 0.0 }; - eprintln!(" {:<12} {:>3} tensors: {:>10.2} GB → {:>8.2} MB ({:.1}×)", - name, count, orig as f64 / 1e9, comp as f64 / 1e6, ratio); + eprintln!( + " {:<12} {:>3} tensors: {:>10.2} GB → {:>8.2} MB ({:.1}×)", + name, + count, + orig as f64 / 1e9, + comp as f64 / 1e6, + ratio + ); } } @@ -1125,9 +1162,7 @@ mod tests { use std::io::BufWriter; let repo = "unsloth/Llama-4-Scout-17B-16E-Instruct-GGUF"; - let filename = format!( - "BF16/Llama-4-Scout-17B-16E-Instruct-BF16-{:05}-of-00005.gguf", shard - ); + let filename = format!("BF16/Llama-4-Scout-17B-16E-Instruct-BF16-{:05}-of-00005.gguf", shard); let octave_stride: usize = 16; eprintln!("Streaming shard {}/5: {}", shard, filename); @@ -1146,10 +1181,10 @@ mod tests { octave_stride, Some(&|name, layer_type, orig, comp| { let ratio = if comp > 0 { orig as f64 / comp as f64 } else { 0.0 }; - eprintln!(" {:60} {:12?} {:>12} → {:>8} ({:.0}×)", - name, layer_type, orig, comp, ratio); + eprintln!(" {:60} {:12?} {:>12} → {:>8} ({:.0}×)", name, layer_type, orig, comp, ratio); }), - ).expect("stream_index_gguf_bf16"); + ) + .expect("stream_index_gguf_bf16"); drop(writer); let out_size = std::fs::metadata(&out_path).map(|m| m.len()).unwrap_or(0); @@ -1158,8 +1193,7 @@ mod tests { eprintln!("=== Llama 4 Scout BF16 Shard {}/5 → bgz17 (BF16-direct) ===", shard); eprintln!(" Output: {:.2} MB ({})", out_size as f64 / 1e6, out_path); eprintln!(" Downloaded: {:.2} GB", reader.bytes_downloaded() as f64 / 1e9); - eprintln!(" Tensors: {} indexed, {} skipped", - stats.tensors_indexed, stats.tensors_skipped); + eprintln!(" Tensors: {} indexed, {} skipped", stats.tensors_indexed, stats.tensors_skipped); eprintln!(" Original (f32): {:.2} GB", stats.original_bytes as f64 / 1e9); eprintln!(" Compressed: {:.2} MB", stats.compressed_bytes as f64 / 1e6); eprintln!(" Ratio: {:.1}×", stats.overall_ratio()); @@ -1170,8 +1204,14 @@ mod tests { let (count, orig, comp) = stats.by_type[i]; if count > 0 { let ratio = if comp > 0 { orig as f64 / comp as f64 } else { 0.0 }; - eprintln!(" {:<12} {:>3} tensors: {:>10.2} GB → 
{:>8.2} MB ({:.1}×)", - name, count, orig as f64 / 1e9, comp as f64 / 1e6, ratio); + eprintln!( + " {:<12} {:>3} tensors: {:>10.2} GB → {:>8.2} MB ({:.1}×)", + name, + count, + orig as f64 / 1e9, + comp as f64 / 1e6, + ratio + ); } } @@ -1181,19 +1221,29 @@ mod tests { #[test] #[ignore] - fn test_stream_index_llama4_bf16_shard1() { run_llama4_shard(1); } + fn test_stream_index_llama4_bf16_shard1() { + run_llama4_shard(1); + } #[test] #[ignore] - fn test_stream_index_llama4_bf16_shard2() { run_llama4_shard(2); } + fn test_stream_index_llama4_bf16_shard2() { + run_llama4_shard(2); + } #[test] #[ignore] - fn test_stream_index_llama4_bf16_shard3() { run_llama4_shard(3); } + fn test_stream_index_llama4_bf16_shard3() { + run_llama4_shard(3); + } #[test] #[ignore] - fn test_stream_index_llama4_bf16_shard4() { run_llama4_shard(4); } + fn test_stream_index_llama4_bf16_shard4() { + run_llama4_shard(4); + } #[test] #[ignore] - fn test_stream_index_llama4_bf16_shard5() { run_llama4_shard(5); } + fn test_stream_index_llama4_bf16_shard5() { + run_llama4_shard(5); + } // ── BF16-direct optimization tests ── @@ -1215,8 +1265,7 @@ mod tests { for i in 0..17 { let diff = (full.dims[i] as i32 - strided.dims[i] as i32).abs(); - assert!(diff <= 1, "bin {} differs by {}: full={}, strided={}", - i, diff, full.dims[i], strided.dims[i]); + assert!(diff <= 1, "bin {} differs by {}: full={}, strided={}", i, diff, full.dims[i], strided.dims[i]); } } @@ -1224,7 +1273,10 @@ mod tests { fn test_bf16_direct_matches_f32_path() { // Same data in BF16 and f32 should produce identical Base17 let f32_row: Vec = (0..4096).map(|i| (i as f32) * 0.001).collect(); - let bf16_row: Vec = f32_row.iter().map(|&v| (v.to_bits() >> 16) as u16).collect(); + let bf16_row: Vec = f32_row + .iter() + .map(|&v| (v.to_bits() >> 16) as u16) + .collect(); let from_f32 = project_row_to_base17(&f32_row); let from_bf16 = project_row_bf16_direct(&bf16_row); @@ -1248,8 +1300,14 @@ mod tests { for r in 1..n_rows { for bin in 0..BASE_DIM { let diff = (simd_results[0].dims[bin] as i32 - simd_results[r].dims[bin] as i32).abs(); - assert!(diff == 0, "row {} bin {} differs: {} vs {}", - r, bin, simd_results[0].dims[bin], simd_results[r].dims[bin]); + assert!( + diff == 0, + "row {} bin {} differs: {} vs {}", + r, + bin, + simd_results[0].dims[bin], + simd_results[r].dims[bin] + ); } } } @@ -1271,8 +1329,15 @@ mod tests { let scalar = project_1row_bf16_strided(&buf[start..start + n_cols], 16); for bin in 0..BASE_DIM { let diff = (simd_results[r].dims[bin] as i32 - scalar.dims[bin] as i32).abs(); - assert!(diff <= 1, "row {} bin {} simd={} scalar={} diff={}", - r, bin, simd_results[r].dims[bin], scalar.dims[bin], diff); + assert!( + diff <= 1, + "row {} bin {} simd={} scalar={} diff={}", + r, + bin, + simd_results[r].dims[bin], + scalar.dims[bin], + diff + ); } } } @@ -1296,15 +1361,15 @@ mod tests { let repo = "unsloth/Llama-4-Maverick-17B-128E-Instruct-GGUF"; let shards: [(u8, &str, u64); 18] = [ - ( 1, "BF16/Llama-4-Maverick-17B-128E-Instruct-BF16-00001-of-00018.gguf", 46_166_870_240), - ( 2, "BF16/Llama-4-Maverick-17B-128E-Instruct-BF16-00002-of-00018.gguf", 42_949_673_376), - ( 3, "BF16/Llama-4-Maverick-17B-128E-Instruct-BF16-00003-of-00018.gguf", 42_949_673_376), - ( 4, "BF16/Llama-4-Maverick-17B-128E-Instruct-BF16-00004-of-00018.gguf", 42_949_673_376), - ( 5, "BF16/Llama-4-Maverick-17B-128E-Instruct-BF16-00005-of-00018.gguf", 47_943_931_840), - ( 6, "BF16/Llama-4-Maverick-17B-128E-Instruct-BF16-00006-of-00018.gguf", 42_949_673_376), - ( 7, 
"BF16/Llama-4-Maverick-17B-128E-Instruct-BF16-00007-of-00018.gguf", 42_949_673_376), - ( 8, "BF16/Llama-4-Maverick-17B-128E-Instruct-BF16-00008-of-00018.gguf", 42_949_673_376), - ( 9, "BF16/Llama-4-Maverick-17B-128E-Instruct-BF16-00009-of-00018.gguf", 47_922_960_288), + (1, "BF16/Llama-4-Maverick-17B-128E-Instruct-BF16-00001-of-00018.gguf", 46_166_870_240), + (2, "BF16/Llama-4-Maverick-17B-128E-Instruct-BF16-00002-of-00018.gguf", 42_949_673_376), + (3, "BF16/Llama-4-Maverick-17B-128E-Instruct-BF16-00003-of-00018.gguf", 42_949_673_376), + (4, "BF16/Llama-4-Maverick-17B-128E-Instruct-BF16-00004-of-00018.gguf", 42_949_673_376), + (5, "BF16/Llama-4-Maverick-17B-128E-Instruct-BF16-00005-of-00018.gguf", 47_943_931_840), + (6, "BF16/Llama-4-Maverick-17B-128E-Instruct-BF16-00006-of-00018.gguf", 42_949_673_376), + (7, "BF16/Llama-4-Maverick-17B-128E-Instruct-BF16-00007-of-00018.gguf", 42_949_673_376), + (8, "BF16/Llama-4-Maverick-17B-128E-Instruct-BF16-00008-of-00018.gguf", 42_949_673_376), + (9, "BF16/Llama-4-Maverick-17B-128E-Instruct-BF16-00009-of-00018.gguf", 47_922_960_288), (10, "BF16/Llama-4-Maverick-17B-128E-Instruct-BF16-00010-of-00018.gguf", 42_949_673_376), (11, "BF16/Llama-4-Maverick-17B-128E-Instruct-BF16-00011-of-00018.gguf", 42_949_673_376), (12, "BF16/Llama-4-Maverick-17B-128E-Instruct-BF16-00012-of-00018.gguf", 47_912_433_568), @@ -1340,8 +1405,8 @@ mod tests { eprintln!("━━━ Shard {:02}/18 ━━━", shard_num); - let mut reader = HttpRangeReader::from_hf(repo, filename, 256 * 1024 * 1024) - .expect("failed to resolve HF URL"); + let mut reader = + HttpRangeReader::from_hf(repo, filename, 256 * 1024 * 1024).expect("failed to resolve HF URL"); let out = std::fs::File::create(&out_path).expect("create output"); let mut writer = BufWriter::new(out); @@ -1352,26 +1417,36 @@ mod tests { octave_stride, Some(&|name, layer_type, orig, comp| { let ratio = if comp > 0 { orig as f64 / comp as f64 } else { 0.0 }; - eprintln!(" {:60} {:12?} {:>12} → {:>8} ({:.0}×)", - name, layer_type, orig, comp, ratio); + eprintln!(" {:60} {:12?} {:>12} → {:>8} ({:.0}×)", name, layer_type, orig, comp, ratio); }), - ).unwrap_or_else(|e| panic!("stream_index_gguf_bf16 shard {} failed: {}", shard_num, e)); + ) + .unwrap_or_else(|e| panic!("stream_index_gguf_bf16 shard {} failed: {}", shard_num, e)); drop(writer); let out_size = std::fs::metadata(&out_path).map(|m| m.len()).unwrap_or(0); - eprintln!(" Shard {:02}: {:.2} GB → {:.2} MB ({:.0}×) peak_buf={:.1} MB", - shard_num, *size as f64 / 1e9, out_size as f64 / 1e6, + eprintln!( + " Shard {:02}: {:.2} GB → {:.2} MB ({:.0}×) peak_buf={:.1} MB", + shard_num, + *size as f64 / 1e9, + out_size as f64 / 1e6, stats.overall_ratio(), - stats.peak_tensor_bytes as f64 / 1e6); + stats.peak_tensor_bytes as f64 / 1e6 + ); let type_names = ["Attention", "FeedForward", "Conv2D", "Norm", "Embedding", "Skip"]; for (j, name) in type_names.iter().enumerate() { let (count, orig, comp) = stats.by_type[j]; if count > 0 { let ratio = if comp > 0 { orig as f64 / comp as f64 } else { 0.0 }; - eprintln!(" {:<12} {:>3} tensors: {:>10.2} GB → {:>8.2} MB ({:.0}×)", - name, count, orig as f64 / 1e9, comp as f64 / 1e6, ratio); + eprintln!( + " {:<12} {:>3} tensors: {:>10.2} GB → {:>8.2} MB ({:.0}×)", + name, + count, + orig as f64 / 1e9, + comp as f64 / 1e6, + ratio + ); grand_by_type[j].0 += count; grand_by_type[j].1 += orig; grand_by_type[j].2 += comp; @@ -1410,8 +1485,7 @@ mod tests { eprintln!(" Source (BF16): {:>10.2} GB", grand_total_source as f64 / 1e9); eprintln!(" Original (f32): {:>10.2} 
GB", grand_total_original as f64 / 1e9); eprintln!(" Compressed: {:>10.2} MB", grand_total_compressed as f64 / 1e6); - eprintln!(" Overall ratio: {:>10.0}×", - grand_total_original as f64 / grand_total_compressed.max(1) as f64); + eprintln!(" Overall ratio: {:>10.0}×", grand_total_original as f64 / grand_total_compressed.max(1) as f64); eprintln!(" Tensors indexed: {}", grand_total_tensors); let type_names = ["Attention", "FeedForward", "Conv2D", "Norm", "Embedding", "Skip"]; @@ -1419,8 +1493,14 @@ mod tests { let (count, orig, comp) = grand_by_type[j]; if count > 0 { let ratio = if comp > 0 { orig as f64 / comp as f64 } else { 0.0 }; - eprintln!(" {:<12} {:>4} tensors: {:>10.2} GB → {:>8.2} MB ({:.0}×)", - name, count, orig as f64 / 1e9, comp as f64 / 1e6, ratio); + eprintln!( + " {:<12} {:>4} tensors: {:>10.2} GB → {:>8.2} MB ({:.0}×)", + name, + count, + orig as f64 / 1e9, + comp as f64 / 1e6, + ratio + ); } } eprintln!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"); diff --git a/src/hpc/gpt2/api.rs b/src/hpc/gpt2/api.rs index a5a2a564..967b5ffe 100644 --- a/src/hpc/gpt2/api.rs +++ b/src/hpc/gpt2/api.rs @@ -5,9 +5,9 @@ //! - `/v1/embeddings` — token embeddings via wte //! - `/v1/models` — model info -use crate::hpc::models::api_types::*; use super::inference::{GeneratedToken, Gpt2Engine}; use super::weights::*; +use crate::hpc::models::api_types::*; /// Stateless API wrapper around Gpt2Engine. pub struct Gpt2Api { @@ -17,7 +17,10 @@ pub struct Gpt2Api { impl Gpt2Api { pub fn new(weights: Gpt2Weights) -> Self { - Self { engine: Gpt2Engine::new(weights), request_counter: 0 } + Self { + engine: Gpt2Engine::new(weights), + request_counter: 0, + } } /// `/v1/completions` @@ -35,14 +38,20 @@ impl Gpt2Api { FinishReason::Length }; - let text = generated.iter().map(|t| format!("[{}]", t.token_id)).collect::(); - let logprobs: Vec = generated.iter().map(|t| LogprobInfo { - token: format!("{}", t.token_id), - token_id: t.token_id, - logprob: t.logprob, - bytes: None, - top_logprobs: Vec::new(), - }).collect(); + let text = generated + .iter() + .map(|t| format!("[{}]", t.token_id)) + .collect::(); + let logprobs: Vec = generated + .iter() + .map(|t| LogprobInfo { + token: format!("{}", t.token_id), + token_id: t.token_id, + logprob: t.logprob, + bytes: None, + top_logprobs: Vec::new(), + }) + .collect(); let use_logprobs = req.logprobs.is_some(); @@ -71,19 +80,27 @@ impl Gpt2Api { _ => req.input_tokens.clone().unwrap_or_default(), }; - let data: Vec = token_ids.iter().enumerate().map(|(idx, &tid)| { - let offset = tid as usize * EMBED_DIM; - let mut emb = self.engine.weights().wte[offset..offset + EMBED_DIM].to_vec(); - if let Some(dim) = req.dimensions { - emb.truncate(dim); - } - EmbeddingData::new(idx, emb) - }).collect(); + let data: Vec = token_ids + .iter() + .enumerate() + .map(|(idx, &tid)| { + let offset = tid as usize * EMBED_DIM; + let mut emb = self.engine.weights().wte[offset..offset + EMBED_DIM].to_vec(); + if let Some(dim) = req.dimensions { + emb.truncate(dim); + } + EmbeddingData::new(idx, emb) + }) + .collect(); EmbeddingResponse::new( "gpt2".into(), data, - Usage { prompt_tokens: token_ids.len(), completion_tokens: 0, total_tokens: token_ids.len() }, + Usage { + prompt_tokens: token_ids.len(), + completion_tokens: 0, + total_tokens: token_ids.len(), + }, ) } diff --git a/src/hpc/gpt2/inference.rs b/src/hpc/gpt2/inference.rs index 2c828ec4..1cf06a7f 100644 --- a/src/hpc/gpt2/inference.rs +++ b/src/hpc/gpt2/inference.rs @@ -263,10 +263,7 @@ impl 
Gpt2Engine { if let Some(rt) = rt { if use_attn_table && t < self.token_history.len() { let key_token = self.token_history[t]; - let palette_sim = rt.heel_similarity( - current_token as usize, - key_token as usize, - ); + let palette_sim = rt.heel_similarity(current_token as usize, key_token as usize); // Blend: 90% matmul score + 10% palette shortcut scores[t] = scores[t] * 0.9 + palette_sim * 0.1 * scale; } @@ -284,12 +281,10 @@ impl Gpt2Engine { if scores[t] > 0.05 && t < self.token_history.len() { let key_token = self.token_history[t]; let edge = rt.pack_spo_edge( - current_token as usize, - head, // predicate = attention head - key_token as usize, - scores[t], // frequency = attention weight - 0.3, // initial confidence (low) - self.seq_len as u16, // temporal position + current_token as usize, head, // predicate = attention head + key_token as usize, scores[t], // frequency = attention weight + 0.3, // initial confidence (low) + self.seq_len as u16, // temporal position ); self.causal_edges.push(AttentionEdge { layer: layer_idx as u8, diff --git a/src/hpc/gpt2/weights.rs b/src/hpc/gpt2/weights.rs index f3ed4542..23a2d651 100644 --- a/src/hpc/gpt2/weights.rs +++ b/src/hpc/gpt2/weights.rs @@ -61,20 +61,18 @@ impl Gpt2Weights { // Safetensors format: [header_size:u64_le][header_json][tensor_data] let file = std::fs::read(path).map_err(|e| e.to_string())?; - let header_size = u64::from_le_bytes([ - file[0], file[1], file[2], file[3], - file[4], file[5], file[6], file[7], - ]) as usize; + let header_size = + u64::from_le_bytes([file[0], file[1], file[2], file[3], file[4], file[5], file[6], file[7]]) as usize; - let header_json = std::str::from_utf8(&file[8..8 + header_size]) - .map_err(|e| e.to_string())?; + let header_json = std::str::from_utf8(&file[8..8 + header_size]).map_err(|e| e.to_string())?; // Parse tensor metadata from JSON header let data_start = 8 + header_size; let tensors = parse_safetensors_header(header_json)?; let read_tensor = |name: &str| -> Result, String> { - let info = tensors.get(name) + let info = tensors + .get(name) .ok_or_else(|| format!("Missing tensor: {}", name))?; let start = data_start + info.offset; let end = start + info.size; @@ -112,7 +110,11 @@ impl Gpt2Weights { } let mut weights = Gpt2Weights { - wte, wpe, layers, ln_f_weight, ln_f_bias, + wte, + wpe, + layers, + ln_f_weight, + ln_f_bias, }; weights.transpose_weights_for_simd(); Ok(weights) @@ -180,14 +182,18 @@ fn parse_safetensors_header(json: &str) -> Result, S let arr_start = search_start + bracket_start + 1; if let Some(bracket_end) = json[arr_start..].find(']') { let arr = &json[arr_start..arr_start + bracket_end]; - let nums: Vec = arr.split(',') + let nums: Vec = arr + .split(',') .filter_map(|s| s.trim().parse().ok()) .collect(); if nums.len() == 2 { - tensors.insert(key.to_string(), TensorMeta { - offset: nums[0], - size: nums[1] - nums[0], - }); + tensors.insert( + key.to_string(), + TensorMeta { + offset: nums[0], + size: nums[1] - nums[0], + }, + ); } } } diff --git a/src/hpc/graph.rs b/src/hpc/graph.rs index 69a58347..51b600a9 100644 --- a/src/hpc/graph.rs +++ b/src/hpc/graph.rs @@ -3,9 +3,9 @@ //! Encodes directed edges as XOR bindings: edge = src ⊕ verb ⊕ tgt. //! Supports causality checking and verb inference. -use crate::imp_prelude::*; -use super::hdc::HdcOps; use super::bitwise::BitwiseOps; +use super::hdc::HdcOps; +use crate::imp_prelude::*; /// A VerbCodebook maps verb names to binary hypervectors. 
/// @@ -66,7 +66,9 @@ impl VerbCodebook { let mut v = Array::zeros(self.base_dim); let mut state = (offset as u64).wrapping_mul(2654435761); for byte in v.iter_mut() { - state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407); + state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); *byte = (state >> 33) as u8; } Some(v) @@ -79,10 +81,7 @@ impl VerbCodebook { /// # Errors /// Returns `Err` if the verb is not in the codebook. pub fn try_encode_edge( - &self, - src: &Array, - verb: &str, - tgt: &Array, + &self, src: &Array, verb: &str, tgt: &Array, ) -> Result, &'static str> { let verb_vec = self.verb_vector(verb).ok_or("Unknown verb")?; let offset = self.offset(verb).unwrap_or(1); @@ -91,21 +90,13 @@ impl VerbCodebook { } /// Encode an edge (panics on unknown verb). - pub fn encode_edge( - &self, - src: &Array, - verb: &str, - tgt: &Array, - ) -> Array { + pub fn encode_edge(&self, src: &Array, verb: &str, tgt: &Array) -> Array { self.try_encode_edge(src, verb, tgt).unwrap() } /// Decode target: tgt = edge ⊕ permute(src) ⊕ verb_vec pub fn try_decode_target( - &self, - edge: &Array, - src: &Array, - verb: &str, + &self, edge: &Array, src: &Array, verb: &str, ) -> Result, &'static str> { let verb_vec = self.verb_vector(verb).ok_or("Unknown verb")?; let offset = self.offset(verb).unwrap_or(1); @@ -114,24 +105,14 @@ impl VerbCodebook { } /// Decode target (panics on unknown verb). - pub fn decode_target( - &self, - edge: &Array, - src: &Array, - verb: &str, - ) -> Array { + pub fn decode_target(&self, edge: &Array, src: &Array, verb: &str) -> Array { self.try_decode_target(edge, src, verb).unwrap() } /// Causality asymmetry: measures how well edge(src→tgt) differs from edge(tgt→src). /// /// Returns a value between 0 (symmetric) and 1 (fully asymmetric). - pub fn causality_asymmetry( - &self, - src: &Array, - verb: &str, - tgt: &Array, - ) -> f64 { + pub fn causality_asymmetry(&self, src: &Array, verb: &str, tgt: &Array) -> f64 { let forward = self.encode_edge(src, verb, tgt); let backward = self.encode_edge(tgt, verb, src); let dist = forward.hamming_distance(&backward); @@ -141,10 +122,7 @@ impl VerbCodebook { /// Full causality check: returns (forward_edge, backward_edge, asymmetry). pub fn causality_check( - &self, - src: &Array, - verb: &str, - tgt: &Array, + &self, src: &Array, verb: &str, tgt: &Array, ) -> (Array, Array, f64) { let forward = self.encode_edge(src, verb, tgt); let backward = self.encode_edge(tgt, verb, src); @@ -158,9 +136,7 @@ impl VerbCodebook { /// Find edges with low causality asymmetry (potentially non-causal). pub fn find_non_causal_edges( - &self, - edges: &[(Array, &str, Array)], - threshold: f64, + &self, edges: &[(Array, &str, Array)], threshold: f64, ) -> Vec<(usize, f64)> { edges .iter() @@ -180,10 +156,7 @@ impl VerbCodebook { /// /// Returns (verb_name, verb_offset, hamming_distance). pub fn infer_verb( - &self, - edge: &Array, - src: &Array, - candidates: &[Array], + &self, edge: &Array, src: &Array, candidates: &[Array], ) -> Option<(String, usize, u64)> { if candidates.is_empty() { return None; @@ -206,19 +179,13 @@ impl VerbCodebook { } /// Encode an edge using explicit verb vector (no codebook needed). -pub fn encode_edge_explicit( - src: &Array, - verb_vec: &Array, - tgt: &Array, -) -> Array { +pub fn encode_edge_explicit(src: &Array, verb_vec: &Array, tgt: &Array) -> Array { src.hdc_bind(verb_vec).hdc_bind(tgt) } /// Decode target using explicit verb vector. 
pub fn decode_target_explicit( - edge: &Array, - src: &Array, - verb_vec: &Array, + edge: &Array, src: &Array, verb_vec: &Array, ) -> Array { edge.hdc_bind(src).hdc_bind(verb_vec) } diff --git a/src/hpc/hdc.rs b/src/hpc/hdc.rs index 003fd2af..8529bd5e 100644 --- a/src/hpc/hdc.rs +++ b/src/hpc/hdc.rs @@ -39,7 +39,8 @@ pub trait HdcOps { } impl HdcOps for ArrayBase -where S: Data +where + S: Data, { fn hdc_bind(&self, other: &Self) -> Array { let n = self.len().min(other.len()); diff --git a/src/hpc/heel_f64x8.rs b/src/hpc/heel_f64x8.rs index 2cb57cad..87ff42bb 100644 --- a/src/hpc/heel_f64x8.rs +++ b/src/hpc/heel_f64x8.rs @@ -41,11 +41,7 @@ pub fn heel_plane_distances(a: &[u64; 8], b: &[u64; 8]) -> [f64; 8] { /// Full pipeline: 8 HEEL planes → Hamming per plane → weighted F64x8 dot → scalar. #[inline] -pub fn heel_weighted_hamming( - a_planes: &[u64; 8], - b_planes: &[u64; 8], - weights: &[f64; 8], -) -> f64 { +pub fn heel_weighted_hamming(a_planes: &[u64; 8], b_planes: &[u64; 8], weights: &[f64; 8]) -> f64 { let dists = heel_plane_distances(a_planes, b_planes); heel_weighted_distance(&dists, weights) } @@ -122,9 +118,9 @@ pub fn cosine_f64_simd(a: &[f64], b: &[f64]) -> f64 { for i in 0..chunks { let va = F64x8::from_slice(&a[i * 8..]); let vb = F64x8::from_slice(&b[i * 8..]); - dot_acc = va.mul_add(vb, dot_acc); // dot += a*b - na_acc = va.mul_add(va, na_acc); // na += a*a - nb_acc = vb.mul_add(vb, nb_acc); // nb += b*b + dot_acc = va.mul_add(vb, dot_acc); // dot += a*b + na_acc = va.mul_add(va, na_acc); // na += a*a + nb_acc = vb.mul_add(vb, nb_acc); // nb += b*b } let mut dot = dot_acc.reduce_sum(); @@ -139,7 +135,11 @@ pub fn cosine_f64_simd(a: &[f64], b: &[f64]) -> f64 { } let denom = (na * nb).sqrt(); - if denom < 1e-12 { 0.0 } else { dot / denom } + if denom < 1e-12 { + 0.0 + } else { + dot / denom + } } /// SIMD cosine similarity on f32 slices (converts to f64 internally for precision). @@ -185,7 +185,11 @@ pub fn cosine_f32_to_f64_simd(a: &[f32], b: &[f32]) -> f64 { } let denom = (na * nb).sqrt(); - if denom < 1e-12 { 0.0 } else { dot / denom } + if denom < 1e-12 { + 0.0 + } else { + dot / denom + } } #[cfg(test)] @@ -212,7 +216,9 @@ mod tests { fn plane_distances_self_zero() { let planes = [0x1234u64; 8]; let dists = heel_plane_distances(&planes, &planes); - for d in &dists { assert_eq!(*d, 0.0); } + for d in &dists { + assert_eq!(*d, 0.0); + } } #[test] @@ -220,7 +226,9 @@ mod tests { let a = [0u64; 8]; let b = [u64::MAX; 8]; let dists = heel_plane_distances(&a, &b); - for d in &dists { assert_eq!(*d, 64.0); } + for d in &dists { + assert_eq!(*d, 64.0); + } } #[test] @@ -279,8 +287,7 @@ mod tests { let nb: f64 = b.iter().map(|x| x * x).sum(); let scalar_cos = dot / (na * nb).sqrt(); - assert!((simd_cos - scalar_cos).abs() < 1e-10, - "SIMD {:.12} vs scalar {:.12}", simd_cos, scalar_cos); + assert!((simd_cos - scalar_cos).abs() < 1e-10, "SIMD {:.12} vs scalar {:.12}", simd_cos, scalar_cos); } #[test] @@ -294,8 +301,7 @@ mod tests { let cos_f64 = cosine_f64_simd(&a_f64, &b_f64); let cos_f32 = cosine_f32_to_f64_simd(&a_f32, &b_f32); - assert!((cos_f64 - cos_f32).abs() < 1e-6, - "f32 {:.10} vs f64 {:.10}", cos_f32, cos_f64); + assert!((cos_f64 - cos_f32).abs() < 1e-6, "f32 {:.10} vs f64 {:.10}", cos_f32, cos_f64); } #[test] diff --git a/src/hpc/holo.rs b/src/hpc/holo.rs index 292af5a1..60ea7d99 100644 --- a/src/hpc/holo.rs +++ b/src/hpc/holo.rs @@ -20,7 +20,6 @@ use std::f64::consts::PI; // Binding = addition mod 256 (VPADDB). Unbinding = subtraction mod 256 (VPSUBB). 
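[Editorial note on the holo.rs hunks that follow: the comments describe phase binding as byte-wise addition mod 256 and unbinding as subtraction mod 256, mapped onto VPADDB/VPSUBB. Below is a minimal stand-alone sketch of that idea only — the `phase_bind` / `phase_unbind` helpers are hypothetical scalar stand-ins, not the crate's `phase_bind_i8` API.]

```rust
/// Sketch only: byte-wise phase binding (add mod 256) and unbinding
/// (subtract mod 256). Subtraction is the exact inverse of addition,
/// so unbinding recovers the original phase vector bit-for-bit.
fn phase_bind(a: &[u8], b: &[u8]) -> Vec<u8> {
    a.iter().zip(b).map(|(&x, &y)| x.wrapping_add(y)).collect()
}

fn phase_unbind(bound: &[u8], key: &[u8]) -> Vec<u8> {
    bound.iter().zip(key).map(|(&x, &y)| x.wrapping_sub(y)).collect()
}

fn main() {
    let value = [10u8, 200, 37, 255];
    let key = [3u8, 100, 250, 1];
    let bound = phase_bind(&value, &key);
    // Exact round-trip, independent of the 2048-byte container size
    // used by the real containers.
    assert_eq!(phase_unbind(&bound, &key), value);
}
```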
// Unlike binary XOR, phase operations preserve spatial locality. - // ------------------------------------------------------------------------- // Operation 1: phase_bind_i8 // ------------------------------------------------------------------------- @@ -91,11 +90,7 @@ pub fn wasserstein_sorted_i8(a: &[u8], b: &[u8]) -> u64 { /// Stage 2: sample 1/4, reject at 2σ /// Stage 3: full Wasserstein on survivors pub fn wasserstein_search_adaptive( - query: &[u8], - database: &[u8], - vec_len: usize, - n: usize, - max_distance: u64, + query: &[u8], database: &[u8], vec_len: usize, n: usize, max_distance: u64, ) -> Vec<(usize, u64)> { let mut results = Vec::new(); let sample_16 = vec_len / 16; @@ -642,10 +637,7 @@ mod phase_tests { let (sb, _) = sort_phase_vector(&b); let w = wasserstein_sorted_i8(&sa, &sb); - assert!( - w > 0, - "distinct random vectors should have nonzero Wasserstein" - ); + assert!(w > 0, "distinct random vectors should have nonzero Wasserstein"); } #[test] @@ -736,8 +728,6 @@ mod phase_tests { // which uses **u8** (unsigned, each byte = an angle on [0°, 360°)). // Binary containers (META, BTREE) remain unchanged. - - // ============================================================================ // Constants // ============================================================================ @@ -746,9 +736,7 @@ mod phase_tests { /// /// If f1=5 and f2=10, then f2 is the 2nd harmonic of f1 — they interfere. /// Fibonacci spacing avoids integer-ratio relationships between any pair. -pub const CARRIER_FREQUENCIES: [u16; 16] = [ - 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1024, -]; +pub const CARRIER_FREQUENCIES: [u16; 16] = [1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1024]; /// Per-carrier amplitude. With 16 carriers superimposed in i8 (-128..+127): /// max amplitude per carrier = 127 / 16 ≈ 7 @@ -849,13 +837,7 @@ impl CarrierBasis { /// /// Uses float per element for phase precision. Maps to VCVTDQ2PS + VMULPS + /// VCVTPS2DQ on AVX-512 (~128 instructions for 2048 bytes). -pub fn carrier_encode( - container: &mut [i8], - basis: &CarrierBasis, - freq_idx: u8, - phase_offset: f32, - amplitude: f32, -) { +pub fn carrier_encode(container: &mut [i8], basis: &CarrierBasis, freq_idx: u8, phase_offset: f32, amplitude: f32) { assert_eq!(container.len(), 2048); assert!((freq_idx as usize) < CARRIER_FREQUENCIES.len()); @@ -1090,23 +1072,15 @@ impl CarrierRecord { /// 4-channel hybrid sweep. /// META + BTREE: Hamming (same as CogRecordV3). /// CAM + EMBED: carrier L1 distance. - pub fn hybrid_sweep( - &self, - other: &Self, - thresholds: &CarrierThresholds, - ) -> Option { + pub fn hybrid_sweep(&self, other: &Self, thresholds: &CarrierThresholds) -> Option { // Stage 1: META — binary Hamming (cheapest rejection) - let meta_dist = - super::bitwise::hamming_distance_raw(self.meta.as_slice(), other.meta.as_slice()); + let meta_dist = super::bitwise::hamming_distance_raw(self.meta.as_slice(), other.meta.as_slice()); if meta_dist > thresholds.meta_hamming { return None; } // Stage 2: BTREE — binary Hamming - let btree_dist = super::bitwise::hamming_distance_raw( - self.btree.as_slice(), - other.btree.as_slice(), - ); + let btree_dist = super::bitwise::hamming_distance_raw(self.btree.as_slice(), other.btree.as_slice()); if btree_dist > thresholds.btree_hamming { return None; } @@ -1132,11 +1106,7 @@ impl CarrierRecord { } /// Batch hybrid sweep against a database of CarrierRecords. 
- pub fn hybrid_search( - &self, - database: &[Self], - thresholds: &CarrierThresholds, - ) -> Vec<(usize, CarrierDistances)> { + pub fn hybrid_search(&self, database: &[Self], thresholds: &CarrierThresholds) -> Vec<(usize, CarrierDistances)> { database .iter() .enumerate() @@ -1240,8 +1210,7 @@ mod carrier_tests { .zip(basis.basis_cos[j].iter()) .map(|(&a, &b)| a as i64 * b as i64) .sum(); - let normalized = (dot as f64).abs() - / (2048.0 * CARRIER_AMPLITUDE as f64 * CARRIER_AMPLITUDE as f64); + let normalized = (dot as f64).abs() / (2048.0 * CARRIER_AMPLITUDE as f64 * CARRIER_AMPLITUDE as f64); assert!( normalized < 0.15, "cos[{}] and cos[{}] should be orthogonal, dot/norm = {:.4}", @@ -1287,11 +1256,7 @@ mod carrier_tests { let mut waveform = vec![0i8; 2048]; carrier_encode(&mut waveform, &basis, 0, 0.0, CARRIER_AMPLITUDE); let (phase, amp) = carrier_decode(&waveform, &basis, 0); - assert!( - phase_error(phase, 0.0) < 0.15, - "phase=0 recovery: got {:.4}", - phase - ); + assert!(phase_error(phase, 0.0) < 0.15, "phase=0 recovery: got {:.4}", phase); assert!(amp > 1.0, "amplitude should be significant, got {:.4}", amp); } @@ -1299,13 +1264,7 @@ mod carrier_tests { fn test_encode_decode_phase_pi() { let basis = CarrierBasis::new(); let mut waveform = vec![0i8; 2048]; - carrier_encode( - &mut waveform, - &basis, - 0, - std::f32::consts::PI, - CARRIER_AMPLITUDE, - ); + carrier_encode(&mut waveform, &basis, 0, std::f32::consts::PI, CARRIER_AMPLITUDE); let (phase, amp) = carrier_decode(&waveform, &basis, 0); assert!( phase_error(phase, std::f32::consts::PI) < 0.15, @@ -1323,12 +1282,7 @@ mod carrier_tests { let target = TAU - 0.01; carrier_encode(&mut waveform, &basis, 0, target, CARRIER_AMPLITUDE); let (phase, _) = carrier_decode(&waveform, &basis, 0); - assert!( - phase_error(phase, target) < 0.15, - "wrap-around recovery: got {:.4}, expected {:.4}", - phase, - target - ); + assert!(phase_error(phase, target) < 0.15, "wrap-around recovery: got {:.4}, expected {:.4}", phase, target); } #[test] @@ -1345,18 +1299,8 @@ mod carrier_tests { let (rec_a, _) = carrier_decode(&waveform, &basis, 0); let (rec_b, _) = carrier_decode(&waveform, &basis, 5); - assert!( - phase_error(rec_a, phase_a) < 0.2, - "carrier 0: expected {:.4}, got {:.4}", - phase_a, - rec_a - ); - assert!( - phase_error(rec_b, phase_b) < 0.2, - "carrier 5: expected {:.4}, got {:.4}", - phase_b, - rec_b - ); + assert!(phase_error(rec_a, phase_a) < 0.2, "carrier 0: expected {:.4}, got {:.4}", phase_a, rec_a); + assert!(phase_error(rec_b, phase_b) < 0.2, "carrier 5: expected {:.4}, got {:.4}", phase_b, rec_b); } #[test] @@ -1418,16 +1362,8 @@ mod carrier_tests { let (rec_a, _) = carrier_decode(&bundled, &basis, 0); let (rec_b, _) = carrier_decode(&bundled, &basis, 5); - assert!( - phase_error(rec_a, 1.0) < 0.25, - "bundled carrier 0: expected 1.0, got {:.4}", - rec_a - ); - assert!( - phase_error(rec_b, 2.5) < 0.25, - "bundled carrier 5: expected 2.5, got {:.4}", - rec_b - ); + assert!(phase_error(rec_a, 1.0) < 0.25, "bundled carrier 0: expected 1.0, got {:.4}", rec_a); + assert!(phase_error(rec_b, 2.5) < 0.25, "bundled carrier 5: expected 2.5, got {:.4}", rec_b); } #[test] @@ -1471,13 +1407,7 @@ mod carrier_tests { let mut waveforms: Vec> = Vec::new(); for i in 0..n { let mut wf = vec![0i8; 2048]; - carrier_encode( - &mut wf, - &basis, - (i % 16) as u8, - phases[i], - CARRIER_AMPLITUDE, - ); + carrier_encode(&mut wf, &basis, (i % 16) as u8, phases[i], CARRIER_AMPLITUDE); waveforms.push(wf); } @@ -1496,11 +1426,7 @@ mod carrier_tests { 
let mean_single = single_freq_errors.iter().sum::() / single_freq_errors.len() as f32; // Unshared frequencies should still work - assert!( - mean_single < 0.6, - "unshared frequencies at N=21 mean error = {:.4} rad", - mean_single - ); + assert!(mean_single < 0.6, "unshared frequencies at N=21 mean error = {:.4} rad", mean_single); } // ---- Distance tests ---- @@ -1524,11 +1450,7 @@ mod carrier_tests { let mut wf = vec![0i8; 2048]; carrier_encode(&mut wf, &basis, 0, 1.0, CARRIER_AMPLITUDE); let corr = carrier_correlation(&wf, &wf); - assert!( - (corr - 1.0).abs() < 0.01, - "self-correlation should be 1.0, got {:.4}", - corr - ); + assert!((corr - 1.0).abs() < 0.01, "self-correlation should be 1.0, got {:.4}", corr); } #[test] @@ -1538,11 +1460,7 @@ mod carrier_tests { carrier_encode(&mut wf, &basis, 0, 1.0, CARRIER_AMPLITUDE); let neg: Vec = wf.iter().map(|&v| v.saturating_neg()).collect(); let corr = carrier_correlation(&wf, &neg); - assert!( - (corr + 1.0).abs() < 0.05, - "negation correlation should be -1.0, got {:.4}", - corr - ); + assert!((corr + 1.0).abs() < 0.05, "negation correlation should be -1.0, got {:.4}", corr); } #[test] @@ -1554,11 +1472,7 @@ mod carrier_tests { carrier_encode(&mut wf_b, &basis, 5, 0.0, CARRIER_AMPLITUDE); let corr = carrier_correlation(&wf_a, &wf_b); - assert!( - corr.abs() < 0.2, - "orthogonal carriers should have near-zero correlation, got {:.4}", - corr - ); + assert!(corr.abs() < 0.2, "orthogonal carriers should have near-zero correlation, got {:.4}", corr); } // ---- Spectrum tests ---- @@ -1718,13 +1632,7 @@ mod carrier_tests { let phases: Vec = (0..n).map(|i| (i as f32) * 0.7 + 0.3).collect(); for i in 0..n as usize { - carrier_encode( - &mut carrier_waveform, - &basis, - i as u8 % 16, - phases[i], - CARRIER_AMPLITUDE, - ); + carrier_encode(&mut carrier_waveform, &basis, i as u8 % 16, phases[i], CARRIER_AMPLITUDE); } let mut carrier_errors = Vec::new(); @@ -1735,10 +1643,8 @@ mod carrier_tests { carrier_amps.push(rec_amp); } - let carrier_mean_error: f32 = - carrier_errors.iter().sum::() / carrier_errors.len() as f32; - let carrier_mean_amp: f32 = - carrier_amps.iter().sum::() / carrier_amps.len() as f32; + let carrier_mean_error: f32 = carrier_errors.iter().sum::() / carrier_errors.len() as f32; + let carrier_mean_amp: f32 = carrier_amps.iter().sum::() / carrier_amps.len() as f32; // --- Random-phase path (using phase.rs functions) --- use self::{circular_distance_i8, phase_bundle_circular, phase_unbind_i8}; @@ -1758,8 +1664,7 @@ mod carrier_tests { let dist = circular_distance_i8(&recovered, &phase_vecs[0]); phase_errors.push(dist); } - let phase_self_recovery = - circular_distance_i8(&phase_unbind_i8(&bundle, &phase_vecs[0]), &phase_vecs[0]); + let phase_self_recovery = circular_distance_i8(&phase_unbind_i8(&bundle, &phase_vecs[0]), &phase_vecs[0]); println!( "N={:>2}: carrier_err={:.4} rad ({:>5.1}°) amp={:.2} | phase_self_dist={}", @@ -1784,11 +1689,7 @@ mod carrier_tests { total_err += phase_error(rec, phases[i]); } let mean = total_err / 16.0; - assert!( - mean < 0.5, - "Carrier at N=16 mean error = {:.4} rad — capacity limit exceeded", - mean - ); + assert!(mean < 0.5, "Carrier at N=16 mean error = {:.4} rad — capacity limit exceeded", mean); } } @@ -1800,10 +1701,7 @@ mod carrier_tests { let u8_carrier = basis.cos_as_u8(0); assert_eq!(u8_carrier.len(), 2048); // First sample: cos[0][0] = amplitude (7), offset = 7+128 = 135 - assert_eq!( - u8_carrier[0], - (CARRIER_AMPLITUDE.round() as u8).wrapping_add(128) - ); + 
assert_eq!(u8_carrier[0], (CARRIER_AMPLITUDE.round() as u8).wrapping_add(128)); } } @@ -1834,7 +1732,6 @@ mod carrier_tests { // Three masks: `mask_x: u8`, `mask_y: u8`, `mask_z: u32` = 48 bits. // A byte is "in focus" only if ALL THREE masks select its slab. - // ============================================================================ // Constants // ============================================================================ @@ -2176,36 +2073,18 @@ pub fn focus_xor_auto(container: &mut [u8], mask_x: u8, mask_y: u8, mask_z: u32, /// Write a concept into a binary container at a focused region. /// Uses XOR binding. Self-inverse: call again to erase. -pub fn focus_bind_binary( - container: &mut [u8], - mask_x: u8, - mask_y: u8, - mask_z: u32, - concept_vec: &[u8], -) { +pub fn focus_bind_binary(container: &mut [u8], mask_x: u8, mask_y: u8, mask_z: u32, concept_vec: &[u8]) { focus_xor(container, mask_x, mask_y, mask_z, concept_vec); } /// Write a concept into a phase container at a focused region. /// Uses ADD binding. NOT self-inverse — use focus_unbind_phase to erase. -pub fn focus_bind_phase( - container: &mut [u8], - mask_x: u8, - mask_y: u8, - mask_z: u32, - concept_vec: &[u8], -) { +pub fn focus_bind_phase(container: &mut [u8], mask_x: u8, mask_y: u8, mask_z: u32, concept_vec: &[u8]) { focus_add(container, mask_x, mask_y, mask_z, concept_vec); } /// Erase a concept from a phase container at a focused region. -pub fn focus_unbind_phase( - container: &mut [u8], - mask_x: u8, - mask_y: u8, - mask_z: u32, - concept_vec: &[u8], -) { +pub fn focus_unbind_phase(container: &mut [u8], mask_x: u8, mask_y: u8, mask_z: u32, concept_vec: &[u8]) { focus_sub(container, mask_x, mask_y, mask_z, concept_vec); } @@ -2215,13 +2094,7 @@ pub fn focus_unbind_phase( /// (spatial partitioning). The carrier signal only exists in the focused /// region. pub fn focus_carrier_encode( - container: &mut [i8], - basis: &CarrierBasis, - mask_x: u8, - mask_y: u8, - mask_z: u32, - freq_idx: u8, - phase_offset: f32, + container: &mut [i8], basis: &CarrierBasis, mask_x: u8, mask_y: u8, mask_z: u32, freq_idx: u8, phase_offset: f32, amplitude: f32, ) { let cos_phi = phase_offset.cos(); @@ -2338,9 +2211,7 @@ impl Default for FocusRegistry { impl FocusRegistry { pub fn new() -> Self { - Self { - entries: Vec::new(), - } + Self { entries: Vec::new() } } /// Register a concept at a focus address. @@ -2350,12 +2221,7 @@ impl FocusRegistry { /// Check if a proposed focus address overlaps with any existing entry. /// Returns overlapping (concept_id, overlap_size_bytes) pairs. 
- pub fn check_overlap( - &self, - new_mask_x: u8, - new_mask_y: u8, - new_mask_z: u32, - ) -> Vec<(u64, u32)> { + pub fn check_overlap(&self, new_mask_x: u8, new_mask_y: u8, new_mask_z: u32) -> Vec<(u64, u32)> { let mut overlaps = Vec::new(); for &(existing_packed, concept_id) in &self.entries { @@ -2475,8 +2341,7 @@ mod focus_tests { for (mx, my, mz) in test_cases { let mask = materialize_focus_mask(mx, my, mz); let count = mask.iter().filter(|&&b| b == 0xFF).count(); - let expected = - mx.count_ones() as usize * my.count_ones() as usize * mz.count_ones() as usize; + let expected = mx.count_ones() as usize * my.count_ones() as usize * mz.count_ones() as usize; assert_eq!(count, expected, "mx={:#x} my={:#x} mz={:#x}", mx, my, mz); } } @@ -2545,11 +2410,7 @@ mod focus_tests { let mask = materialize_focus_mask(mx, my, mz); for i in 0..2048 { if mask[i] == 0 { - assert_eq!( - container[i], original[i], - "position {} outside mask changed", - i - ); + assert_eq!(container[i], original[i], "position {} outside mask changed", i); } } } @@ -2805,20 +2666,12 @@ mod focus_tests { masks.insert((mx, my, mz)); } // With 100 random IDs and medium density, most should be distinct - assert!( - masks.len() > 50, - "expected most masks unique, got {}", - masks.len() - ); + assert!(masks.len() > 50, "expected most masks unique, got {}", masks.len()); } #[test] fn test_concept_to_focus_density_bits() { - for density in [ - FocusDensity::Sparse, - FocusDensity::Medium, - FocusDensity::Broad, - ] { + for density in [FocusDensity::Sparse, FocusDensity::Medium, FocusDensity::Broad] { let (bits_x, bits_y, bits_z) = density.bit_counts(); let (mx, my, mz) = concept_to_focus(42, density); assert_eq!(mx.count_ones(), bits_x, "density={:?} mask_x bits", density); @@ -2866,13 +2719,7 @@ mod focus_tests { } // With sparse non-overlapping, most/all should match // (some may collide due to birthday effect) - assert!( - matches as f64 / total as f64 > 0.5, - "concept {} signal too weak: {}/{}", - id, - matches, - total - ); + assert!(matches as f64 / total as f64 > 0.5, "concept {} signal too weak: {}/{}", id, matches, total); } } @@ -2968,14 +2815,8 @@ mod focus_tests { let delta = focus_delta(&old, &new, mx, my, mz); let compact = CompactDelta::from_delta(&delta, mx, my, mz); - assert!( - compact.wire_size() < 2048, - "compact should be smaller than full" - ); - assert!( - compact.changes.len() <= 4, - "sparse focus: at most 4 changes" - ); + assert!(compact.wire_size() < 2048, "compact should be smaller than full"); + assert!(compact.changes.len() <= 4, "sparse focus: at most 4 changes"); } #[test] @@ -3030,11 +2871,7 @@ mod focus_tests { fn test_focus_capacity_experiment() { println!("\n=== Focus Gating Capacity Experiment ===\n"); - for &density in &[ - FocusDensity::Sparse, - FocusDensity::Medium, - FocusDensity::Broad, - ] { + for &density in &[FocusDensity::Sparse, FocusDensity::Medium, FocusDensity::Broad] { let (bits_x, bits_y, bits_z) = density.bit_counts(); let region_bytes = bits_x * bits_y * bits_z; @@ -3125,11 +2962,7 @@ mod focus_tests { } } let accuracy = total_match as f64 / total_bits as f64; - assert!( - accuracy > 0.7, - "Sparse N=10 accuracy {:.1}% too low", - accuracy * 100.0 - ); + assert!(accuracy > 0.7, "Sparse N=10 accuracy {:.1}% too low", accuracy * 100.0); } } @@ -3144,26 +2977,13 @@ mod focus_tests { let my = 0x01u8; let mz = 0x0000000Fu32; // 1×1×4 = 4 bytes - focus_carrier_encode( - &mut container, - &basis, - mx, - my, - mz, - 0, - 1.0, - self::CARRIER_AMPLITUDE, - ); + 
focus_carrier_encode(&mut container, &basis, mx, my, mz, 0, 1.0, self::CARRIER_AMPLITUDE); // Check that only masked positions are non-zero let mask = materialize_focus_mask(mx, my, mz); for i in 0..2048 { if mask[i] == 0 { - assert_eq!( - container[i], 0, - "position {} outside mask should be zero, got {}", - i, container[i] - ); + assert_eq!(container[i], 0, "position {} outside mask should be zero, got {}", i, container[i]); } } @@ -3171,9 +2991,6 @@ mod focus_tests { let nonzero_in_mask = (0..2048) .filter(|&i| mask[i] == 0xFF && container[i] != 0) .count(); - assert!( - nonzero_in_mask > 0, - "carrier should write some non-zero values" - ); + assert!(nonzero_in_mask > 0, "carrier should write some non-zero values"); } } diff --git a/src/hpc/http_reader.rs b/src/hpc/http_reader.rs index 9183e423..246f63ec 100644 --- a/src/hpc/http_reader.rs +++ b/src/hpc/http_reader.rs @@ -24,8 +24,8 @@ use std::process::{Command, Stdio}; /// within the cache window are free (no re-fetch). pub struct HttpRangeReader { url: String, - repo: Option, // for re-resolve on 403 - filename: Option, // for re-resolve on 403 + repo: Option, // for re-resolve on 403 + filename: Option, // for re-resolve on 403 position: u64, total_size: u64, chunk_size: usize, @@ -150,10 +150,7 @@ impl HttpRangeReader { Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)), } } else { - Err(io::Error::new( - io::ErrorKind::Other, - "cannot re-resolve: no repo/filename stored (use from_hf())", - )) + Err(io::Error::new(io::ErrorKind::Other, "cannot re-resolve: no repo/filename stored (use from_hf())")) } } @@ -174,8 +171,7 @@ impl HttpRangeReader { for attempt in 0..MAX_RETRIES { if attempt > 0 { let delay = INITIAL_BACKOFF_MS * (1u64 << (attempt - 1).min(4)); - eprintln!(" retry {}/{} after {}ms (segment {}-{})", - attempt + 1, MAX_RETRIES, delay, start, end); + eprintln!(" retry {}/{} after {}ms (segment {}-{})", attempt + 1, MAX_RETRIES, delay, start, end); std::thread::sleep(std::time::Duration::from_millis(delay)); } @@ -184,15 +180,11 @@ impl HttpRangeReader { let result = Command::new("curl") .args(&[ - "-sL", - "--retry", "2", - "--retry-delay", "2", - "--connect-timeout", "30", - "--max-time", "600", // 10 min max per 64 MB segment + "-sL", "--retry", "2", "--retry-delay", "2", "--connect-timeout", "30", "--max-time", + "600", // 10 min max per 64 MB segment "--speed-limit", &speed_limit_str, // abort if < 100 KB/s - "--speed-time", &speed_time_str, // for > 30 seconds - "-r", &range, - &self.url, + "--speed-time", &speed_time_str, // for > 30 seconds + "-r", &range, &self.url, ]) .stdout(Stdio::piped()) .stderr(Stdio::piped()) @@ -213,13 +205,15 @@ impl HttpRangeReader { Ok(output) => { let stderr = String::from_utf8_lossy(&output.stderr); let code = output.status.code().unwrap_or(-1); - eprintln!(" fetch failed: exit={} got={} bytes stderr={}", - code, output.stdout.len(), stderr.trim()); + eprintln!( + " fetch failed: exit={} got={} bytes stderr={}", + code, + output.stdout.len(), + stderr.trim() + ); // 403 or curl exit 22 (HTTP error) → re-resolve CDN URL - if (code == 22 || stderr.contains("403") || stderr.contains("expired")) - && !resolved_this_call - { + if (code == 22 || stderr.contains("403") || stderr.contains("expired")) && !resolved_this_call { if self.re_resolve_url().is_ok() { resolved_this_call = true; eprintln!(" URL re-resolved, retrying immediately"); @@ -325,10 +319,7 @@ impl Seek for HttpRangeReader { }; if new_pos < 0 { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "seek before 
start of file", - )); + return Err(io::Error::new(io::ErrorKind::InvalidInput, "seek before start of file")); } self.position = new_pos as u64; @@ -351,13 +342,16 @@ impl Seek for HttpRangeReader { pub fn resolve_hf_url(repo: &str, filename: &str) -> Result<(String, u64), String> { // Method 1: Python huggingface_hub (handles auth tokens, gated models) if let Ok(py_out) = Command::new("python3") - .args(&["-c", &format!( - "from huggingface_hub import hf_hub_url, get_hf_file_metadata; \ + .args(&[ + "-c", + &format!( + "from huggingface_hub import hf_hub_url, get_hf_file_metadata; \ url = hf_hub_url('{}', '{}'); \ meta = get_hf_file_metadata(url); \ print(meta.size); print(meta.location if hasattr(meta, 'location') else url)", - repo, filename - )]) + repo, filename + ), + ]) .output() { if py_out.status.success() { @@ -376,10 +370,7 @@ pub fn resolve_hf_url(repo: &str, filename: &str) -> Result<(String, u64), Strin } // Method 2: curl HEAD with redirect follow - let url = format!( - "https://huggingface.co/{}/resolve/main/{}", - repo, filename - ); + let url = format!("https://huggingface.co/{}/resolve/main/{}", repo, filename); if let Ok(output) = Command::new("curl") .args(&["-sIL", "--connect-timeout", "15", "--max-time", "30", &url]) @@ -416,7 +407,8 @@ pub fn resolve_hf_url(repo: &str, filename: &str) -> Result<(String, u64), Strin // Method 3: HuggingFace Hub REST API (no Python needed) let api_url = format!( "https://huggingface.co/api/models/{}/tree/main/{}", - repo, filename.rsplit('/').next().unwrap_or(filename) + repo, + filename.rsplit('/').next().unwrap_or(filename) ); if let Ok(output) = Command::new("curl") .args(&["-sL", "--connect-timeout", "15", "--max-time", "30", &api_url]) @@ -489,7 +481,8 @@ mod tests { let (url, size) = resolve_hf_url( "unsloth/Llama-4-Scout-17B-16E-Instruct-GGUF", "BF16/Llama-4-Scout-17B-16E-Instruct-BF16-00005-of-00005.gguf", - ).expect("resolve_hf_url"); + ) + .expect("resolve_hf_url"); assert!(size > 0, "size should be > 0"); assert!(url.contains("http"), "url should be HTTP: {}", url); eprintln!("resolved: {} ({} bytes)", url, size); diff --git a/src/hpc/jina/cache.rs b/src/hpc/jina/cache.rs index 2b9dec51..b94d179c 100644 --- a/src/hpc/jina/cache.rs +++ b/src/hpc/jina/cache.rs @@ -9,13 +9,17 @@ use std::io::{Read, Write}; /// Save Base17 tokens to binary cache. pub fn save_base17_cache(tokens: &[Base17Token], writer: &mut W) -> Result<(), String> { let n = tokens.len() as u32; - writer.write_all(&n.to_le_bytes()).map_err(|e| e.to_string())?; + writer + .write_all(&n.to_le_bytes()) + .map_err(|e| e.to_string())?; writer .write_all(&(BASE_DIM as u32).to_le_bytes()) .map_err(|e| e.to_string())?; for t in tokens { for &d in &t.dims { - writer.write_all(&d.to_le_bytes()).map_err(|e| e.to_string())?; + writer + .write_all(&d.to_le_bytes()) + .map_err(|e| e.to_string())?; } } Ok(()) @@ -48,7 +52,9 @@ pub fn load_base17_cache(reader: &mut R) -> Result, St /// Save palette to binary cache. 
pub fn save_palette_cache(palette: &JinaPalette, writer: &mut W) -> Result<(), String> { let n = palette.assignments.len() as u32; - writer.write_all(&n.to_le_bytes()).map_err(|e| e.to_string())?; + writer + .write_all(&n.to_le_bytes()) + .map_err(|e| e.to_string())?; writer .write_all(&(BASE_DIM as u32).to_le_bytes()) .map_err(|e| e.to_string())?; @@ -59,7 +65,9 @@ pub fn save_palette_cache(palette: &JinaPalette, writer: &mut W) -> Re // Centroids for k in 0..PALETTE_K { for &d in &palette.centroids[k].dims { - writer.write_all(&d.to_le_bytes()).map_err(|e| e.to_string())?; + writer + .write_all(&d.to_le_bytes()) + .map_err(|e| e.to_string())?; } } // Assignments diff --git a/src/hpc/jina/causal.rs b/src/hpc/jina/causal.rs index 360315dd..3109952e 100644 --- a/src/hpc/jina/causal.rs +++ b/src/hpc/jina/causal.rs @@ -13,13 +13,7 @@ use super::codec::JinaPalette; /// ``` #[inline(always)] pub fn pack_edge( - s_palette: u8, - p_palette: u8, - o_palette: u8, - frequency: f32, - confidence: f32, - pearl_mask: u8, - temporal: u16, + s_palette: u8, p_palette: u8, o_palette: u8, frequency: f32, confidence: f32, pearl_mask: u8, temporal: u16, ) -> u64 { let f_u8 = (frequency.clamp(0.0, 1.0) * 255.0) as u8; let c_u8 = (confidence.clamp(0.0, 1.0) * 255.0) as u8; @@ -216,8 +210,11 @@ mod tests { assert!(d_spo > d_po, "removing S should reduce distance"); assert!(d_spo > d_so, "removing P should reduce distance"); - assert_eq!(d_spo, d_s + d_po - causal_distance(e1, e2, &palette, 0b000), - "planes should be additive (within rounding)... actually just check ordering"); + assert_eq!( + d_spo, + d_s + d_po - causal_distance(e1, e2, &palette, 0b000), + "planes should be additive (within rounding)... actually just check ordering" + ); assert!(d_s > 0, "different S should have positive distance"); } } diff --git a/src/hpc/jina/codec.rs b/src/hpc/jina/codec.rs index daebc1c1..fe995d85 100644 --- a/src/hpc/jina/codec.rs +++ b/src/hpc/jina/codec.rs @@ -133,8 +133,7 @@ impl JinaPalette { for k in 0..PALETTE_K { if counts[k] > 0 { for d in 0..BASE_DIM { - centroids[k].dims[d] = - (sums[k][d] / counts[k] as i64).clamp(-32768, 32767) as i16; + centroids[k].dims[d] = (sums[k][d] / counts[k] as i64).clamp(-32768, 32767) as i16; } } } @@ -160,8 +159,7 @@ impl JinaPalette { /// O(1) distance between two tokens via palette lookup. #[inline(always)] pub fn distance(&self, token_a: usize, token_b: usize) -> u16 { - self.distance_table[self.assignments[token_a] as usize] - [self.assignments[token_b] as usize] + self.distance_table[self.assignments[token_a] as usize][self.assignments[token_b] as usize] } /// Palette index for a token. @@ -232,11 +230,7 @@ mod tests { let palette = JinaPalette::build(&tokens, 5); for i in 0..PALETTE_K { for j in 0..PALETTE_K { - assert_eq!( - palette.distance_table[i][j], - palette.distance_table[j][i], - "Asymmetric at [{i}][{j}]" - ); + assert_eq!(palette.distance_table[i][j], palette.distance_table[j][i], "Asymmetric at [{i}][{j}]"); } } } diff --git a/src/hpc/jina/runtime.rs b/src/hpc/jina/runtime.rs index c7bb44f0..398d2930 100644 --- a/src/hpc/jina/runtime.rs +++ b/src/hpc/jina/runtime.rs @@ -141,10 +141,9 @@ pub struct ModelRuntime { impl ModelRuntime { /// Load from embedded weight bytes. 
fn load(source: ModelSource, base17_bytes: &[u8], palette_bytes: &[u8]) -> Self { - let tokens = load_base17_cache(&mut std::io::Cursor::new(base17_bytes)) - .expect("Failed to load Base17 cache"); - let palette = load_palette_cache(&mut std::io::Cursor::new(palette_bytes)) - .expect("Failed to load palette cache"); + let tokens = load_base17_cache(&mut std::io::Cursor::new(base17_bytes)).expect("Failed to load Base17 cache"); + let palette = + load_palette_cache(&mut std::io::Cursor::new(palette_bytes)).expect("Failed to load palette cache"); // Build SimilarityTable from the EXACT 256×256 distance distribution. // This IS the bgz17 SimilarityTable pattern: empirical CDF → calibrated f32. @@ -199,12 +198,7 @@ impl ModelRuntime { /// Pack two tokens + a predicate into a CausalEdge64. #[inline] pub fn pack_spo_edge( - &self, - subject_token: usize, - predicate_token: usize, - object_token: usize, - frequency: f32, - confidence: f32, + &self, subject_token: usize, predicate_token: usize, object_token: usize, frequency: f32, confidence: f32, temporal: u16, ) -> u64 { causal::pack_edge( @@ -309,19 +303,16 @@ pub static JINA: LazyLock = LazyLock::new(|| { /// silently upgraded to v5 when the main route is swapped. Today this is /// functionally identical to `JINA` (both load v4 bytes), but after the v5 /// bake `JINA` will load v5 while `JINA_V4` keeps loading v4. -pub static JINA_V4: LazyLock = LazyLock::new(|| { - ModelRuntime::load(ModelSource::JinaV4, JINA_V4_BASE17, JINA_V4_PALETTE) -}); +pub static JINA_V4: LazyLock = + LazyLock::new(|| ModelRuntime::load(ModelSource::JinaV4, JINA_V4_BASE17, JINA_V4_PALETTE)); /// GPT-2 runtime (50K tokens). Same BPE as Jina → interoperable palettes. -pub static GPT2: LazyLock = LazyLock::new(|| { - ModelRuntime::load(ModelSource::Gpt2, GPT2_BASE17, GPT2_PALETTE) -}); +pub static GPT2: LazyLock = + LazyLock::new(|| ModelRuntime::load(ModelSource::Gpt2, GPT2_BASE17, GPT2_PALETTE)); /// BERT runtime (30K tokens). WordPiece tokenizer (different from GPT-2 BPE). 
-pub static BERT: LazyLock = LazyLock::new(|| { - ModelRuntime::load(ModelSource::Bert, BERT_BASE17, BERT_PALETTE) -}); +pub static BERT: LazyLock = + LazyLock::new(|| ModelRuntime::load(ModelSource::Bert, BERT_BASE17, BERT_PALETTE)); #[cfg(test)] mod tests { diff --git a/src/hpc/jitson/mod.rs b/src/hpc/jitson/mod.rs index 8cd5643c..2393d46c 100644 --- a/src/hpc/jitson/mod.rs +++ b/src/hpc/jitson/mod.rs @@ -37,18 +37,16 @@ pub use validator::{validate, ValidationError}; // Re-exports: template layer pub use template::{ - from_json, check_pipeline_features, template_hash, - JitsonTemplate, PipelineStage, BackendConfig, JitsonError, + check_pipeline_features, from_json, template_hash, BackendConfig, JitsonError, JitsonTemplate, PipelineStage, }; // Re-exports: precompile queue -pub use precompile::{PrecompileQueue, PrecompileEntry, CompileState}; +pub use precompile::{CompileState, PrecompileEntry, PrecompileQueue}; // Re-exports: scan config + SIMD trampolines pub use scan_config::{ - ScanConfig, ScanResult, SimdKernelRegistry, DefaultKernelRegistry, - scan_hamming, jit_symbol_table, + jit_symbol_table, scan_hamming, DefaultKernelRegistry, ScanConfig, ScanResult, SimdKernelRegistry, }; // Re-exports: noise parameters + terrain templates -pub use noise::{NoiseParams, GRAD3, simple_noise_3d, CompiledNoiseConfig, TerrainFillParams}; +pub use noise::{simple_noise_3d, CompiledNoiseConfig, NoiseParams, TerrainFillParams, GRAD3}; diff --git a/src/hpc/jitson/noise.rs b/src/hpc/jitson/noise.rs index fe54b732..19a028b1 100644 --- a/src/hpc/jitson/noise.rs +++ b/src/hpc/jitson/noise.rs @@ -29,7 +29,11 @@ impl NoiseParams { freq *= lacunarity; amp *= persistence; } - Self { octaves, lacunarity, persistence } + Self { + octaves, + lacunarity, + persistence, + } } /// Number of octaves. @@ -55,9 +59,18 @@ impl NoiseParams { /// Gradient vectors for 3D Perlin noise (12 edges of a cube). pub const GRAD3: [[f64; 3]; 12] = [ - [1.0, 1.0, 0.0], [-1.0, 1.0, 0.0], [1.0, -1.0, 0.0], [-1.0, -1.0, 0.0], - [1.0, 0.0, 1.0], [-1.0, 0.0, 1.0], [1.0, 0.0, -1.0], [-1.0, 0.0, -1.0], - [0.0, 1.0, 1.0], [0.0, -1.0, 1.0], [0.0, 1.0, -1.0], [0.0, -1.0, -1.0], + [1.0, 1.0, 0.0], + [-1.0, 1.0, 0.0], + [1.0, -1.0, 0.0], + [-1.0, -1.0, 0.0], + [1.0, 0.0, 1.0], + [-1.0, 0.0, 1.0], + [1.0, 0.0, -1.0], + [-1.0, 0.0, -1.0], + [0.0, 1.0, 1.0], + [0.0, -1.0, 1.0], + [0.0, 1.0, -1.0], + [0.0, -1.0, -1.0], ]; /// Simple hash-based 3D noise (deterministic, not cryptographic). @@ -120,7 +133,12 @@ impl CompiledNoiseConfig { let amp_sum = params.amplitude_sum(); let normalization = if amp_sum > 0.0 { 1.0 / amp_sum } else { 1.0 }; - Self { frequencies, amplitudes, seed_offsets, normalization } + Self { + frequencies, + amplitudes, + seed_offsets, + normalization, + } } /// Evaluate using the compiled config (reference, matches what JIT would produce). @@ -134,13 +152,7 @@ impl CompiledNoiseConfig { } /// Evaluate and normalize to [-1, 1] range. - pub fn evaluate_normalized( - &self, - x: f64, - y: f64, - z: f64, - base_noise: fn(f64, f64, f64) -> f64, - ) -> f64 { + pub fn evaluate_normalized(&self, x: f64, y: f64, z: f64, base_noise: fn(f64, f64, f64) -> f64) -> f64 { self.evaluate(x, y, z, base_noise) * self.normalization } @@ -196,12 +208,7 @@ impl TerrainFillParams { /// index = y * 256 + z * 16 + x). /// /// Block state ID 0 = air. 
- pub fn fill_section_reference( - &self, - section_y: i32, - seed: u64, - base_noise: fn(f64, f64, f64) -> f64, - ) -> Vec { + pub fn fill_section_reference(&self, section_y: i32, seed: u64, base_noise: fn(f64, f64, f64) -> f64) -> Vec { let mut blocks = vec![0u16; 4096]; // all air initially let section_base_y = section_y * 16; diff --git a/src/hpc/jitson/parser.rs b/src/hpc/jitson/parser.rs index f9dc8e86..88334e23 100644 --- a/src/hpc/jitson/parser.rs +++ b/src/hpc/jitson/parser.rs @@ -99,11 +99,7 @@ pub struct ParseError { impl core::fmt::Display for ParseError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!( - f, - "JITSON parse error at byte {}: {}", - self.offset, self.message - ) + write!(f, "JITSON parse error at byte {}: {}", self.offset, self.message) } } @@ -169,24 +165,15 @@ impl<'a> Parser<'a> { fn expect(&mut self, expected: u8) -> Result<(), ParseError> { match self.advance() { Some(b) if b == expected => Ok(()), - Some(b) => Err(self.err(&alloc::format!( - "expected '{}', found '{}'", - expected as char, - b as char - ))), + Some(b) => Err(self.err(&alloc::format!("expected '{}', found '{}'", expected as char, b as char))), None => { // Bracket recovery: if we hit EOF expecting a closing bracket, // check if it matches the top of our open_stack. - if (expected == b'}' || expected == b']') - && self.open_stack.last().copied() == Some(expected) - { + if (expected == b'}' || expected == b']') && self.open_stack.last().copied() == Some(expected) { self.open_stack.pop(); Ok(()) } else { - Err(self.err(&alloc::format!( - "unexpected EOF, expected '{}'", - expected as char - ))) + Err(self.err(&alloc::format!("unexpected EOF, expected '{}'", expected as char))) } } } @@ -238,10 +225,8 @@ impl<'a> Parser<'a> { } let hex = &self.input[self.pos..self.pos + 4]; self.pos += 4; - let hex_str = core::str::from_utf8(hex) - .map_err(|_| self.err("invalid \\u hex"))?; - let cp = u32::from_str_radix(hex_str, 16) - .map_err(|_| self.err("invalid \\u hex"))?; + let hex_str = core::str::from_utf8(hex).map_err(|_| self.err("invalid \\u hex"))?; + let cp = u32::from_str_radix(hex_str, 16).map_err(|_| self.err("invalid \\u hex"))?; if let Some(c) = char::from_u32(cp) { s.push(c); } @@ -272,13 +257,9 @@ impl<'a> Parser<'a> { self.pos += 1; } } - if self.pos < self.input.len() - && (self.input[self.pos] == b'e' || self.input[self.pos] == b'E') - { + if self.pos < self.input.len() && (self.input[self.pos] == b'e' || self.input[self.pos] == b'E') { self.pos += 1; - if self.pos < self.input.len() - && (self.input[self.pos] == b'+' || self.input[self.pos] == b'-') - { + if self.pos < self.input.len() && (self.input[self.pos] == b'+' || self.input[self.pos] == b'-') { self.pos += 1; } while self.pos < self.input.len() && self.input[self.pos].is_ascii_digit() { @@ -420,15 +401,13 @@ mod tests { let input = r#"{"version": 1, "kernel": "hamming_distance", "scan": {"threshold": 2048}}"#; let root = parse_json(input).unwrap(); assert_eq!(root.get("version").unwrap().as_u64(), Some(1)); - assert_eq!( - root.get("kernel").unwrap().as_str(), - Some("hamming_distance") - ); + assert_eq!(root.get("kernel").unwrap().as_str(), Some("hamming_distance")); } #[test] fn test_bracket_recovery_missing_closing_brace() { - let input = r#"{"version": 1, "kernel": "hamming_distance", "scan": {"threshold": 1, "record_size": 64, "top_k": 5}"#; + let input = + r#"{"version": 1, "kernel": "hamming_distance", "scan": {"threshold": 1, "record_size": 64, "top_k": 5}"#; let root = 
parse_json(input).unwrap(); assert_eq!(root.get("version").unwrap().as_u64(), Some(1)); } @@ -443,7 +422,8 @@ mod tests { #[test] fn test_bracket_recovery_nested() { - let input = r#"{"version": 1, "kernel": "cosine_i8", "scan": {"threshold": 100, "record_size": 128, "top_k": 3"#; + let input = + r#"{"version": 1, "kernel": "cosine_i8", "scan": {"threshold": 100, "record_size": 128, "top_k": 3"#; let root = parse_json(input).unwrap(); let scan = root.get("scan").unwrap(); assert_eq!(scan.get("top_k").unwrap().as_u64(), Some(3)); diff --git a/src/hpc/jitson/precompile.rs b/src/hpc/jitson/precompile.rs index 6a2de547..1e4275f2 100644 --- a/src/hpc/jitson/precompile.rs +++ b/src/hpc/jitson/precompile.rs @@ -41,9 +41,7 @@ pub struct PrecompileQueue { impl PrecompileQueue { pub fn new() -> Self { - Self { - entries: Vec::new(), - } + Self { entries: Vec::new() } } /// Enqueue a template for precompilation. Returns the stable hash. diff --git a/src/hpc/jitson/template.rs b/src/hpc/jitson/template.rs index 84083d66..4c14ec4c 100644 --- a/src/hpc/jitson/template.rs +++ b/src/hpc/jitson/template.rs @@ -316,11 +316,7 @@ mod tests { fn test_check_pipeline_features() { let tmpl = from_json(VALID_TEMPLATE).unwrap(); let unsatisfied = check_pipeline_features(&tmpl); - assert!( - unsatisfied.is_empty(), - "unexpected unsatisfied: {:?}", - unsatisfied - ); + assert!(unsatisfied.is_empty(), "unexpected unsatisfied: {:?}", unsatisfied); } #[test] @@ -372,5 +368,4 @@ mod tests { let t2 = from_json(BACKEND_TEMPLATE).unwrap(); assert_ne!(template_hash(&t1), template_hash(&t2)); } - } diff --git a/src/hpc/jitson/validator.rs b/src/hpc/jitson/validator.rs index d951a46d..d66ef9c9 100644 --- a/src/hpc/jitson/validator.rs +++ b/src/hpc/jitson/validator.rs @@ -28,48 +28,25 @@ impl core::fmt::Display for ValidationError { /// All AVX-512 feature flags supported by the patched Cranelift backend. pub const KNOWN_FEATURES: &[&str] = &[ - "avx512f", - "avx512vl", - "avx512bw", - "avx512dq", - "avx512bitalg", - "avx512vbmi", - "avx512vpopcntdq", - "avx512vnni", + "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512bitalg", "avx512vbmi", "avx512vpopcntdq", "avx512vnni", "avx512ifma", ]; /// All AVX-512 instruction mnemonics from the patched Cranelift. 
pub const KNOWN_INSTRUCTIONS: &[&str] = &[ // abs - "vpabsb", "vpabsw", "vpabsd", "vpabsq", - // and / ternlog - "vpandd", "vpandq", "vpandnd", "vpandnq", "vpternlogd", "vpternlogq", - // bitmanip - "vpopcntb", "vpopcntw", "vpopcntd", "vpopcntq", - // fma (132/213/231 x ps/pd x add/sub/nmadd) - "vfmadd132ps", "vfmadd213ps", "vfmadd231ps", - "vfmadd132pd", "vfmadd213pd", "vfmadd231pd", - "vfmsub132ps", "vfmsub213ps", "vfmsub231ps", - "vfmsub132pd", "vfmsub213pd", "vfmsub231pd", - "vfnmadd132ps", "vfnmadd213ps", "vfnmadd231ps", - "vfnmadd132pd", "vfnmadd213pd", "vfnmadd231pd", - // mul / vnni - "vpmulld", "vpmullq", - "vpdpbusd", "vpdpbusds", "vpdpwssd", "vpdpwssds", - // or - "vpord", "vporq", - // shift - "vpsllw", "vpslld", "vpsllq", - "vpsraw", "vpsrad", "vpsraq", - "vpsrlw", "vpsrld", "vpsrlq", - // xor - "vpxord", "vpxorq", - // add - "vaddpd", - // cvt - "vcvtudq2ps", - // lanes + "vpabsb", "vpabsw", "vpabsd", "vpabsq", // and / ternlog + "vpandd", "vpandq", "vpandnd", "vpandnq", "vpternlogd", "vpternlogq", // bitmanip + "vpopcntb", "vpopcntw", "vpopcntd", "vpopcntq", // fma (132/213/231 x ps/pd x add/sub/nmadd) + "vfmadd132ps", "vfmadd213ps", "vfmadd231ps", "vfmadd132pd", "vfmadd213pd", "vfmadd231pd", "vfmsub132ps", + "vfmsub213ps", "vfmsub231ps", "vfmsub132pd", "vfmsub213pd", "vfmsub231pd", "vfnmadd132ps", "vfnmadd213ps", + "vfnmadd231ps", "vfnmadd132pd", "vfnmadd213pd", "vfnmadd231pd", // mul / vnni + "vpmulld", "vpmullq", "vpdpbusd", "vpdpbusds", "vpdpwssd", "vpdpwssds", // or + "vpord", "vporq", // shift + "vpsllw", "vpslld", "vpsllq", "vpsraw", "vpsrad", "vpsraq", "vpsrlw", "vpsrld", "vpsrlq", // xor + "vpxord", "vpxorq", // add + "vaddpd", // cvt + "vcvtudq2ps", // lanes "vpermi2b", ]; @@ -81,10 +58,9 @@ pub const KNOWN_BACKENDS: &[&str] = &["lancedb", "dragonfly"]; /// Known Cranelift presets. const KNOWN_PRESETS: &[&str] = &[ - "baseline", "nehalem", "haswell", "broadwell", "skylake", - "knl", "knm", "skylake_avx512", "cascade_lake", "cooper_lake", - "cannon_lake", "ice_lake_client", "ice_lake_server", "tiger_lake", - "sapphire_rapids", "x86_64_v2", "x86_64_v3", "x86_64_v4", + "baseline", "nehalem", "haswell", "broadwell", "skylake", "knl", "knm", "skylake_avx512", "cascade_lake", + "cooper_lake", "cannon_lake", "ice_lake_client", "ice_lake_server", "tiger_lake", "sapphire_rapids", "x86_64_v2", + "x86_64_v3", "x86_64_v4", ]; /// Known opt levels. 
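[Editorial note between the validator hunks: the surrounding tables (`KNOWN_FEATURES`, `KNOWN_INSTRUCTIONS`) and the `required_features` match a few hunks below implement feature gating — each AVX-512 mnemonic maps to the CPU feature flags it needs, and a template is rejected when it names an instruction whose features are not all declared. The mnemonic-to-feature pairs in this sketch are copied from the `required_features` table in this diff; the `check_instructions` helper and the example lists are illustrative only, not the crate's API.]

```rust
use std::collections::HashSet;

/// Feature requirements for a few mnemonics, mirroring the
/// `required_features` table shown further down in this diff.
fn required_features(instr: &str) -> &'static [&'static str] {
    match instr {
        "vpxord" | "vpxorq" => &["avx512vl", "avx512f"],
        "vpopcntd" | "vpopcntq" => &["avx512vl", "avx512vpopcntdq"],
        "vpdpbusd" | "vpdpbusds" => &["avx512vl", "avx512vnni"],
        "vpermi2b" => &["avx512vl", "avx512vbmi"],
        _ => &[],
    }
}

/// Illustrative check: return the instructions whose required features
/// are not all present in the declared feature set.
fn check_instructions<'a>(instrs: &[&'a str], declared: &[&str]) -> Vec<&'a str> {
    let declared: HashSet<&str> = declared.iter().copied().collect();
    instrs
        .iter()
        .copied()
        .filter(|i| !required_features(i).iter().all(|f| declared.contains(f)))
        .collect()
}

fn main() {
    let declared = ["avx512f", "avx512vl", "avx512vpopcntdq"];
    let missing = check_instructions(&["vpxord", "vpopcntd", "vpdpbusd"], &declared);
    // vpdpbusd needs avx512vnni, which is not declared.
    assert_eq!(missing, vec!["vpdpbusd"]);
}
```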
@@ -132,11 +108,7 @@ pub fn validate(root: &JsonValue) -> Vec { Some(s) if KNOWN_KERNELS.contains(&s) => {} Some(s) => errs.push(ValidationError { path: String::from("/kernel"), - message: alloc::format!( - "unknown kernel \"{}\", expected one of: {}", - s, - KNOWN_KERNELS.join(", ") - ), + message: alloc::format!("unknown kernel \"{}\", expected one of: {}", s, KNOWN_KERNELS.join(", ")), }), None => errs.push(ValidationError { path: String::from("/kernel"), @@ -192,10 +164,7 @@ pub fn validate(root: &JsonValue) -> Vec { if !KNOWN_INSTRUCTIONS.contains(&instr) { errs.push(ValidationError { path: alloc::format!("{}/avx512", prefix), - message: alloc::format!( - "unknown instruction \"{}\"; not in patched Cranelift", - instr - ), + message: alloc::format!("unknown instruction \"{}\"; not in patched Cranelift", instr), }); } } @@ -312,10 +281,7 @@ pub fn validate(root: &JsonValue) -> Vec { if !declared_backends.contains(&backend) { errs.push(ValidationError { path: alloc::format!("/pipeline/{}/backend", i), - message: alloc::format!( - "backend \"{}\" referenced but not declared in /backends", - backend - ), + message: alloc::format!("backend \"{}\" referenced but not declared in /backends", backend), }); } } @@ -323,9 +289,7 @@ pub fn validate(root: &JsonValue) -> Vec { } // Warn on unknown top-level keys - let known_top: &[&str] = &[ - "version", "kernel", "scan", "pipeline", "features", "cranelift", "backends", - ]; + let known_top: &[&str] = &["version", "kernel", "scan", "pipeline", "features", "cranelift", "backends"]; for (key, _) in obj { if !known_top.contains(&key.as_str()) { errs.push(ValidationError { @@ -360,23 +324,20 @@ pub fn required_features(instruction: &str) -> &'static [&'static str] { match instruction { "vpabsb" | "vpabsw" => &["avx512vl", "avx512bw"], "vpabsd" | "vpabsq" => &["avx512vl", "avx512f"], - "vpandd" | "vpandq" | "vpandnd" | "vpandnq" | "vpternlogd" | "vpternlogq" => { - &["avx512vl", "avx512f"] - } + "vpandd" | "vpandq" | "vpandnd" | "vpandnq" | "vpternlogd" | "vpternlogq" => &["avx512vl", "avx512f"], "vpopcntb" | "vpopcntw" => &["avx512vl", "avx512bitalg"], "vpopcntd" | "vpopcntq" => &["avx512vl", "avx512vpopcntdq"], - "vfmadd132ps" | "vfmadd213ps" | "vfmadd231ps" | "vfmadd132pd" | "vfmadd213pd" - | "vfmadd231pd" | "vfmsub132ps" | "vfmsub213ps" | "vfmsub231ps" | "vfmsub132pd" - | "vfmsub213pd" | "vfmsub231pd" | "vfnmadd132ps" | "vfnmadd213ps" | "vfnmadd231ps" - | "vfnmadd132pd" | "vfnmadd213pd" | "vfnmadd231pd" => &["avx512vl", "avx512f"], + "vfmadd132ps" | "vfmadd213ps" | "vfmadd231ps" | "vfmadd132pd" | "vfmadd213pd" | "vfmadd231pd" + | "vfmsub132ps" | "vfmsub213ps" | "vfmsub231ps" | "vfmsub132pd" | "vfmsub213pd" | "vfmsub231pd" + | "vfnmadd132ps" | "vfnmadd213ps" | "vfnmadd231ps" | "vfnmadd132pd" | "vfnmadd213pd" | "vfnmadd231pd" => { + &["avx512vl", "avx512f"] + } "vpmulld" => &["avx512vl", "avx512f"], "vpmullq" => &["avx512vl", "avx512dq"], "vpdpbusd" | "vpdpbusds" | "vpdpwssd" | "vpdpwssds" => &["avx512vl", "avx512vnni"], "vpord" | "vporq" => &["avx512vl", "avx512f"], "vpsllw" | "vpsraw" | "vpsrlw" => &["avx512vl", "avx512bw"], - "vpslld" | "vpsllq" | "vpsrad" | "vpsraq" | "vpsrld" | "vpsrlq" => { - &["avx512vl", "avx512f"] - } + "vpslld" | "vpsllq" | "vpsrad" | "vpsraq" | "vpsrld" | "vpsrlq" => &["avx512vl", "avx512f"], "vpxord" | "vpxorq" => &["avx512vl", "avx512f"], "vaddpd" => &["avx512vl"], "vcvtudq2ps" => &["avx512vl", "avx512f"], @@ -458,10 +419,7 @@ mod tests { #[test] fn test_required_features_mapping() { 
assert_eq!(required_features("vpxord"), &["avx512vl", "avx512f"]); - assert_eq!( - required_features("vpopcntd"), - &["avx512vl", "avx512vpopcntdq"] - ); + assert_eq!(required_features("vpopcntd"), &["avx512vl", "avx512vpopcntdq"]); assert_eq!(required_features("vpdpbusd"), &["avx512vl", "avx512vnni"]); assert_eq!(required_features("vpermi2b"), &["avx512vl", "avx512vbmi"]); assert_eq!(required_features("not_real"), &[] as &[&str]); diff --git a/src/hpc/jitson_cranelift/engine.rs b/src/hpc/jitson_cranelift/engine.rs index 2782f864..2bc87d0a 100644 --- a/src/hpc/jitson_cranelift/engine.rs +++ b/src/hpc/jitson_cranelift/engine.rs @@ -88,9 +88,7 @@ impl JitEngineBuilder { .into_iter() .map(|(k, v)| (k, v as usize)) .collect(); - builder.symbol_lookup_fn(Box::new(move |name| { - symbols.get(name).map(|&addr| addr as *const u8) - })); + builder.symbol_lookup_fn(Box::new(move |name| symbols.get(name).map(|&addr| addr as *const u8))); let module = JITModule::new(builder); @@ -193,11 +191,7 @@ impl JitEngine { /// /// The `distance_fn_name` must be a symbol registered via /// `JitEngineBuilder::register_fn()`. - pub fn compile_hybrid( - &mut self, - params: ScanParams, - distance_fn_name: &str, - ) -> Result { + pub fn compile_hybrid(&mut self, params: ScanParams, distance_fn_name: &str) -> Result { self.compile_inner(params, Some(distance_fn_name)) } @@ -206,15 +200,11 @@ impl JitEngine { queue.iter().map(|p| self.compile(p.clone())).collect() } - fn compile_inner( - &mut self, - params: ScanParams, - distance_fn: Option<&str>, - ) -> Result { + fn compile_inner(&mut self, params: ScanParams, distance_fn: Option<&str>) -> Result { let cache_key = params_hash(¶ms, distance_fn); - let cache = LazyLock::get_mut(&mut self.cache) - .expect("JitEngine: cannot compile after freeze — cache is immutable"); + let cache = + LazyLock::get_mut(&mut self.cache).expect("JitEngine: cannot compile after freeze — cache is immutable"); // Already compiled? Return existing hash. if cache.map.contains_key(&cache_key) { @@ -306,9 +296,7 @@ impl JitEngine { #[cfg(target_arch = "x86_64")] unsafe { #[cfg(target_feature = "sse")] - core::arch::x86_64::_mm_prefetch::< - { core::arch::x86_64::_MM_HINT_T0 }, - >(next_ptr as *const i8); + core::arch::x86_64::_mm_prefetch::<{ core::arch::x86_64::_MM_HINT_T0 }>(next_ptr as *const i8); } let _ = next_ptr; // suppress unused warning on non-x86 } @@ -339,11 +327,7 @@ impl JitEngine { } /// Compile a hybrid scan kernel (legacy API). 
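// Sketch: the compile-then-fetch flow the engine tests below exercise. Two
// compiles with identical ScanParams hit the same cache key, so the second is
// a no-op; compilation only works before the cache is frozen. Assumes
// JitEngine, ScanParams and get() as they appear in this diff.
let mut engine = JitEngine::new().unwrap();
let p = ScanParams { threshold: 42, ..ScanParams::default() };
let h1 = engine.compile(p.clone()).unwrap();
let h2 = engine.compile(p).unwrap(); // same params → same hash, no recompilation
assert_eq!(h1, h2);
let _kernel = engine.get(h1).expect("just compiled");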
- pub fn compile_hybrid_scan( - &mut self, - params: ScanParams, - distance_fn_name: &str, - ) -> Result { + pub fn compile_hybrid_scan(&mut self, params: ScanParams, distance_fn_name: &str) -> Result { let hash = self.compile_hybrid(params.clone(), distance_fn_name)?; Ok(self.get(hash).expect("just compiled")) } @@ -444,9 +428,18 @@ mod tests { fn engine_compile_batch() { let mut engine = JitEngine::new().unwrap(); let params_list = vec![ - ScanParams { threshold: 100, ..ScanParams::default() }, - ScanParams { threshold: 200, ..ScanParams::default() }, - ScanParams { threshold: 300, ..ScanParams::default() }, + ScanParams { + threshold: 100, + ..ScanParams::default() + }, + ScanParams { + threshold: 200, + ..ScanParams::default() + }, + ScanParams { + threshold: 300, + ..ScanParams::default() + }, ]; let hashes = engine.compile_batch(¶ms_list).unwrap(); assert_eq!(hashes.len(), 3); @@ -477,13 +470,22 @@ mod tests { fn engine_prefetch_chain_order() { let mut engine = JitEngine::new().unwrap(); let h1 = engine - .compile(ScanParams { threshold: 10, ..ScanParams::default() }) + .compile(ScanParams { + threshold: 10, + ..ScanParams::default() + }) .unwrap(); let h2 = engine - .compile(ScanParams { threshold: 20, ..ScanParams::default() }) + .compile(ScanParams { + threshold: 20, + ..ScanParams::default() + }) .unwrap(); let _h3 = engine - .compile(ScanParams { threshold: 30, ..ScanParams::default() }) + .compile(ScanParams { + threshold: 30, + ..ScanParams::default() + }) .unwrap(); // Verify prefetch chain is populated @@ -524,8 +526,8 @@ mod tests { kernel.scan( query.as_ptr(), field.as_ptr(), - 4, // 4 records - 8, // record_size (ignored by JIT — baked as immediate) + 4, // 4 records + 8, // record_size (ignored by JIT — baked as immediate) candidates.as_mut_ptr(), ) }; @@ -541,9 +543,18 @@ mod tests { #[test] fn kernel_caching_dedup() { let mut engine = JitEngine::new().unwrap(); - let p1 = ScanParams { threshold: 42, ..ScanParams::default() }; - let p2 = ScanParams { threshold: 42, ..ScanParams::default() }; - let p3 = ScanParams { threshold: 99, ..ScanParams::default() }; + let p1 = ScanParams { + threshold: 42, + ..ScanParams::default() + }; + let p2 = ScanParams { + threshold: 42, + ..ScanParams::default() + }; + let p3 = ScanParams { + threshold: 99, + ..ScanParams::default() + }; let h1 = engine.compile(p1).unwrap(); let h2 = engine.compile(p2).unwrap(); diff --git a/src/hpc/jitson_cranelift/mod.rs b/src/hpc/jitson_cranelift/mod.rs index 033c207f..85b08d14 100644 --- a/src/hpc/jitson_cranelift/mod.rs +++ b/src/hpc/jitson_cranelift/mod.rs @@ -19,8 +19,8 @@ pub mod engine; pub mod scan_jit; pub mod noise_jit; -pub use ir::*; pub use detect::CpuCaps; pub use engine::{JitEngine, JitEngineBuilder}; -pub use scan_jit::ScanKernel; +pub use ir::*; pub use noise_jit::{NoiseKernel, NoiseKernelParams}; +pub use scan_jit::ScanKernel; diff --git a/src/hpc/jitson_cranelift/noise_jit.rs b/src/hpc/jitson_cranelift/noise_jit.rs index 6254b98e..1858478d 100644 --- a/src/hpc/jitson_cranelift/noise_jit.rs +++ b/src/hpc/jitson_cranelift/noise_jit.rs @@ -65,10 +65,7 @@ unsafe impl Sync for NoiseKernel {} impl NoiseKernel { /// Wrap a raw function pointer as a `NoiseKernel`. pub(crate) fn from_raw(ptr: *const u8, params: NoiseKernelParams) -> Self { - Self { - fn_ptr: ptr, - params, - } + Self { fn_ptr: ptr, params } } /// Evaluate the compiled noise function at the given coordinates. 
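// Sketch: the noise path end to end — a CompiledNoiseConfig is lowered to
// NoiseKernelParams, compiled against a registered base-noise symbol, and the
// resulting kernel is evaluated. Assumes from_compiled_config, compile_noise,
// get_noise and the unsafe evaluate() as they appear in the hunks below;
// obtaining `config` and registering the "base_noise" symbol on the builder
// are elided, and "base_noise" is a placeholder name.
let params = from_compiled_config(&config);
assert_eq!(params.num_octaves as usize, config.frequencies.len());
let hash = engine.compile_noise(params, "base_noise").unwrap();
let kernel = engine.get_noise(hash).expect("just compiled");
// SAFETY: the pointer was emitted by compile_noise with signature (f64, f64, f64) -> f64.
let _value = unsafe { kernel.evaluate(0.5, 0.25, 0.0) };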
@@ -82,8 +79,7 @@ impl NoiseKernel { pub unsafe fn evaluate(&self, x: f64, y: f64, z: f64) -> f64 { // SAFETY: caller guarantees fn_ptr validity; fn_ptr was compiled // by Cranelift with the matching signature (f64, f64, f64) -> f64. - let func: unsafe extern "C" fn(f64, f64, f64) -> f64 = - std::mem::transmute(self.fn_ptr); + let func: unsafe extern "C" fn(f64, f64, f64) -> f64 = std::mem::transmute(self.fn_ptr); func(x, y, z) } @@ -109,9 +105,7 @@ impl NoiseKernel { /// let kernel_params = from_compiled_config(&config); /// assert_eq!(kernel_params.num_octaves, 4); /// ``` -pub fn from_compiled_config( - config: &super::super::jitson::noise::CompiledNoiseConfig, -) -> NoiseKernelParams { +pub fn from_compiled_config(config: &super::super::jitson::noise::CompiledNoiseConfig) -> NoiseKernelParams { NoiseKernelParams { num_octaves: config.frequencies.len() as u32, frequencies: config.frequencies.clone(), @@ -141,9 +135,7 @@ pub fn from_compiled_config( /// return value /// ``` pub fn build_noise_ir( - func: &mut Function, - params: &NoiseKernelParams, - base_noise_ref: cranelift_codegen::ir::FuncRef, + func: &mut Function, params: &NoiseKernelParams, base_noise_ref: cranelift_codegen::ir::FuncRef, ) -> Result<(), JitError> { // Validate params let n = params.num_octaves as usize; @@ -270,11 +262,7 @@ impl super::engine::JitEngine { /// Only works during BUILD phase (before sharing via `Arc`). /// /// Returns a cache hash that can be used with `get_noise()`. - pub fn compile_noise( - &mut self, - params: NoiseKernelParams, - base_noise_name: &str, - ) -> Result { + pub fn compile_noise(&mut self, params: NoiseKernelParams, base_noise_name: &str) -> Result { let cache_key = noise_params_hash(¶ms, base_noise_name); // Already compiled? Return existing hash. @@ -375,10 +363,7 @@ mod tests { } // Verify normalization roundtrip - assert!( - (kernel_params.normalization - config.normalization).abs() < 1e-10, - "normalization mismatch" - ); + assert!((kernel_params.normalization - config.normalization).abs() < 1e-10, "normalization mismatch"); } #[test] @@ -485,10 +470,7 @@ mod tests { }); let result = build_noise_ir(&mut func, ¶ms, base_noise_ref); - assert!( - result.is_err(), - "should reject mismatched num_octaves vs frequencies" - ); + assert!(result.is_err(), "should reject mismatched num_octaves vs frequencies"); } #[test] @@ -525,9 +507,6 @@ mod tests { }); let result = build_noise_ir(&mut func, ¶ms, base_noise_ref); - assert!( - result.is_err(), - "should reject mismatched num_octaves vs amplitudes" - ); + assert!(result.is_err(), "should reject mismatched num_octaves vs amplitudes"); } } diff --git a/src/hpc/jitson_cranelift/scan_jit.rs b/src/hpc/jitson_cranelift/scan_jit.rs index 9633ff88..4a200fa2 100644 --- a/src/hpc/jitson_cranelift/scan_jit.rs +++ b/src/hpc/jitson_cranelift/scan_jit.rs @@ -36,10 +36,7 @@ unsafe impl Sync for ScanKernel {} impl ScanKernel { /// Wrap a raw function pointer as a ScanKernel. pub(crate) fn from_raw(ptr: *const u8, params: ScanParams) -> Self { - Self { - fn_ptr: ptr, - params, - } + Self { fn_ptr: ptr, params } } /// Execute the compiled scan. @@ -50,12 +47,7 @@ impl ScanKernel { /// - `field` must point to `field_len * record_size` bytes. /// - `candidates_out` must point to a buffer large enough for results. 
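// Sketch: a call site honouring the safety contract listed above. threshold,
// record_size and top_k were baked into the kernel as immediates by
// build_scan_ir, so the record_size argument passed here is effectively
// ignored. Assumes ScanParams, ScanKernel and the engine API as in this diff;
// buffer sizes are illustrative.
let mut engine = JitEngine::new().unwrap();
let params = ScanParams { threshold: 100, record_size: 8, top_k: 4, ..ScanParams::default() };
let hash = engine.compile(params).unwrap();
let kernel = engine.get(hash).expect("just compiled");

let query = [0u8; 8];
let field = [0u8; 8 * 4]; // 4 records of record_size bytes each
let mut candidates = [0u64; 4]; // room for top_k results
// SAFETY: the pointers cover exactly the lengths passed; candidates_out holds top_k slots.
let n = unsafe { kernel.scan(query.as_ptr(), field.as_ptr(), 4, 8, candidates.as_mut_ptr()) };
assert!(n <= 4);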
pub unsafe fn scan( - &self, - query: *const u8, - field: *const u8, - field_len: u64, - record_size: u64, - candidates_out: *mut u64, + &self, query: *const u8, field: *const u8, field_len: u64, record_size: u64, candidates_out: *mut u64, ) -> u64 { // SAFETY: caller guarantees pointer validity; fn_ptr was compiled // by Cranelift with the matching signature. @@ -90,9 +82,7 @@ impl ScanKernel { /// return candidate_count /// ``` pub fn build_scan_ir( - func: &mut Function, - params: &ScanParams, - dist_func_ref: Option, + func: &mut Function, params: &ScanParams, dist_func_ref: Option, ) -> Result<(), JitError> { let mut fbc = FunctionBuilderContext::new(); let mut builder = FunctionBuilder::new(func, &mut fbc); @@ -119,12 +109,8 @@ pub fn build_scan_ir( let candidates_out = builder.block_params(entry)[4]; // Baked constants (the whole point of JIT compilation) - let threshold_imm = builder - .ins() - .iconst(types::I64, params.threshold as i64); - let record_size_imm = builder - .ins() - .iconst(types::I64, params.record_size as i64); + let threshold_imm = builder.ins().iconst(types::I64, params.threshold as i64); + let record_size_imm = builder.ins().iconst(types::I64, params.record_size as i64); let top_k_imm = builder.ins().iconst(types::I64, params.top_k as i64); let zero = builder.ins().iconst(types::I64, 0); let one = builder.ins().iconst(types::I64, 1); @@ -194,9 +180,7 @@ pub fn build_scan_ir( // candidates_out[count] = i let out_offset = builder.ins().imul(count, eight); let out_ptr = builder.ins().iadd(candidates_out, out_offset); - builder - .ins() - .store(MemFlags::trusted(), i, out_ptr, 0); + builder.ins().store(MemFlags::trusted(), i, out_ptr, 0); // count++ let new_count = builder.ins().iadd(count, one); diff --git a/src/hpc/kernels.rs b/src/hpc/kernels.rs index fd9b3d21..229cd549 100644 --- a/src/hpc/kernels.rs +++ b/src/hpc/kernels.rs @@ -83,12 +83,7 @@ impl SliceGate { /// `safety_margin` multiplicatively relaxes K0/K1 thresholds to guarantee /// zero false negatives (typically 1.5). 
pub fn new( - total_bits: usize, - hot_frac: f64, - mid_frac: f64, - cold_frac: f64, - anti_frac: f64, - safety_margin: f64, + total_bits: usize, hot_frac: f64, mid_frac: f64, cold_frac: f64, anti_frac: f64, safety_margin: f64, ) -> Self { let d = total_bits as f64; @@ -334,32 +329,16 @@ pub fn k2_exact(query: &[u64], candidate: &[u64], n_words: usize) -> EnergyConfl let full_quads = n_words / 4; for q in 0..full_quads { let base = q * 4; - let (qa, qb, qc, qd) = ( - query[base], - query[base + 1], - query[base + 2], - query[base + 3], - ); - let (ca, cb, cc, cd) = ( - candidate[base], - candidate[base + 1], - candidate[base + 2], - candidate[base + 3], - ); + let (qa, qb, qc, qd) = (query[base], query[base + 1], query[base + 2], query[base + 3]); + let (ca, cb, cc, cd) = (candidate[base], candidate[base + 1], candidate[base + 2], candidate[base + 3]); - conflict += (qa ^ ca).count_ones() - + (qb ^ cb).count_ones() - + (qc ^ cc).count_ones() - + (qd ^ cd).count_ones(); + conflict += (qa ^ ca).count_ones() + (qb ^ cb).count_ones() + (qc ^ cc).count_ones() + (qd ^ cd).count_ones(); energy_a += qa.count_ones() + qb.count_ones() + qc.count_ones() + qd.count_ones(); energy_b += ca.count_ones() + cb.count_ones() + cc.count_ones() + cd.count_ones(); - agreement += (qa & ca).count_ones() - + (qb & cb).count_ones() - + (qc & cc).count_ones() - + (qd & cd).count_ones(); + agreement += (qa & ca).count_ones() + (qb & cb).count_ones() + (qc & cc).count_ones() + (qd & cd).count_ones(); } // Remaining words @@ -653,14 +632,8 @@ impl K2Histogram { /// allocation cost is negligible. #[inline] pub fn k2_exact_histogram(query: &[u64], candidate: &[u64], n_words: usize) -> K2Histogram { - assert!( - query.len() >= n_words, - "k2_exact_histogram: query too short" - ); - assert!( - candidate.len() >= n_words, - "k2_exact_histogram: candidate too short" - ); + assert!(query.len() >= n_words, "k2_exact_histogram: query too short"); + assert!(candidate.len() >= n_words, "k2_exact_histogram: candidate too short"); let mut conflict: u32 = 0; let mut energy_a: u32 = 0; @@ -672,18 +645,8 @@ pub fn k2_exact_histogram(query: &[u64], candidate: &[u64], n_words: usize) -> K let full_quads = n_words / 4; for q in 0..full_quads { let base = q * 4; - let (qa, qb, qc, qd) = ( - query[base], - query[base + 1], - query[base + 2], - query[base + 3], - ); - let (ca, cb, cc, cd) = ( - candidate[base], - candidate[base + 1], - candidate[base + 2], - candidate[base + 3], - ); + let (qa, qb, qc, qd) = (query[base], query[base + 1], query[base + 2], query[base + 3]); + let (ca, cb, cc, cd) = (candidate[base], candidate[base + 1], candidate[base + 2], candidate[base + 3]); let pa = (qa ^ ca).count_ones(); let pb = (qb ^ cb).count_ones(); @@ -699,10 +662,7 @@ pub fn k2_exact_histogram(query: &[u64], candidate: &[u64], n_words: usize) -> K conflict += pa + pb + pc + pd; energy_a += qa.count_ones() + qb.count_ones() + qc.count_ones() + qd.count_ones(); energy_b += ca.count_ones() + cb.count_ones() + cc.count_ones() + cd.count_ones(); - agreement += (qa & ca).count_ones() - + (qb & cb).count_ones() - + (qc & cc).count_ones() - + (qd & cd).count_ones(); + agreement += (qa & ca).count_ones() + (qb & cb).count_ones() + (qc & cc).count_ones() + (qd & cd).count_ones(); } // Remaining words @@ -741,11 +701,7 @@ pub fn k2_exact_histogram(query: &[u64], candidate: &[u64], n_words: usize) -> K /// Returns all matches (candidates that survived K0+K1 and have HDR > 0 /// OR sigma level ≥ Hint) and pipeline statistics. 
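// Sketch: building a gate and running the pipeline over a flat candidate
// field. The hot/mid/cold/anti split is an illustrative guess; 1.5 is the
// safety margin the SliceGate doc above calls typical. Assumes SliceGate,
// kernel_pipeline and PipelineStats as they appear in this file.
let n_words = 16_384 / 64; // 16 Kbit containers → 256 u64 words
let gate = SliceGate::new(16_384, 0.10, 0.20, 0.50, 0.20, 1.5);

let query = vec![0u64; n_words];
let database = vec![0u64; n_words * 8]; // 8 candidate containers, concatenated
let (matches, stats) = kernel_pipeline(&query, &database, 8, n_words, &gate);
assert!(matches.len() <= 8);
let _ = stats.matches; // PipelineStats carries per-run counters such as `matches`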
pub fn kernel_pipeline( - query_words: &[u64], - database_words: &[u64], - n_candidates: usize, - n_words: usize, - gate: &SliceGate, + query_words: &[u64], database_words: &[u64], n_candidates: usize, n_words: usize, gate: &SliceGate, ) -> (Vec, PipelineStats) { let sigma_gate = SigmaGate::new(gate.total_bits as usize); assert!( @@ -815,10 +771,7 @@ pub fn kernel_pipeline( /// `database_bytes`: flat byte array /// `n_candidates`: number of containers pub fn kernel_pipeline_bytes( - query_bytes: &[u8], - database_bytes: &[u8], - n_candidates: usize, - gate: &SliceGate, + query_bytes: &[u8], database_bytes: &[u8], n_candidates: usize, gate: &SliceGate, ) -> (Vec, PipelineStats) { let n_bytes = query_bytes.len(); assert!( @@ -843,11 +796,7 @@ pub fn kernel_pipeline_bytes( /// /// For benchmarking Rule 7 (B0 vs pipeline) and verifying zero false negatives. pub fn full_sweep( - query_words: &[u64], - database_words: &[u64], - n_candidates: usize, - n_words: usize, - gate: &SliceGate, + query_words: &[u64], database_words: &[u64], n_candidates: usize, n_words: usize, gate: &SliceGate, ) -> Vec { assert_eq!(query_words.len(), n_words); assert!(database_words.len() >= n_candidates * n_words); @@ -890,11 +839,7 @@ pub fn full_sweep( /// /// Accumulates in FP32 — never stores intermediate BF16 results. pub fn bf16_tail_score( - query_bytes: &[u8], - candidate_bytes: &[u8], - sign_weight: f32, - exp_weight: f32, - man_weight: f32, + query_bytes: &[u8], candidate_bytes: &[u8], sign_weight: f32, exp_weight: f32, man_weight: f32, ) -> f32 { assert_eq!(query_bytes.len(), candidate_bytes.len()); assert!(query_bytes.len() % 2 == 0); @@ -1028,7 +973,11 @@ where { let n = a.len().min(b.len()).min(out.len()); let (a, b, out) = (&a[..n], &b[..n], &mut out[..n]); - for ((a_chunk, b_chunk), out_chunk) in a.chunks_exact(16).zip(b.chunks_exact(16)).zip(out.chunks_exact_mut(16)) { + for ((a_chunk, b_chunk), out_chunk) in a + .chunks_exact(16) + .zip(b.chunks_exact(16)) + .zip(out.chunks_exact_mut(16)) + { let va = F32x16::from_slice(a_chunk); let vb = F32x16::from_slice(b_chunk); f(va, vb).copy_to_slice(out_chunk); @@ -1283,11 +1232,7 @@ mod tests { ); // All planted matches should be found by both - assert!( - full.len() >= 5, - "Expected at least 5 matches from full sweep, got {}", - full.len() - ); + assert!(full.len() >= 5, "Expected at least 5 matches from full sweep, got {}", full.len()); } #[test] @@ -1385,10 +1330,7 @@ mod tests { let (matches, stats) = kernel_pipeline_bytes(&query_bytes, &db_bytes, n, &gate); // Should find at least the exact match at index 0 - assert!( - matches.iter().any(|m| m.index == 0 && m.distance == 0), - "Expected exact match at index 0" - ); + assert!(matches.iter().any(|m| m.index == 0 && m.distance == 0), "Expected exact match at index 0"); assert!(stats.matches >= 1); } @@ -1475,11 +1417,7 @@ mod tests { }; let sigma = score_sigma(&ec, &gate); assert_eq!(sigma.level, SignificanceLevel::Noise); - assert!( - sigma.sigma.abs() < 0.1, - "z should be ~0, got {}", - sigma.sigma - ); + assert!(sigma.sigma.abs() < 0.1, "z should be ~0, got {}", sigma.sigma); } #[test] @@ -1594,10 +1532,7 @@ mod tests { let hist = k2_exact_histogram(&a, &b, SKU_16K_WORDS); let sum: u32 = hist.word_conflicts.iter().map(|&v| v as u32).sum(); - assert_eq!( - sum, hist.energy.conflict, - "Sum of per-word conflicts must equal total conflict" - ); + assert_eq!(sum, hist.energy.conflict, "Sum of per-word conflicts must equal total conflict"); } #[test] @@ -1619,10 +1554,7 @@ mod tests { let hist = 
k2_exact_histogram(&ones, &zeros, SKU_16K_WORDS); assert_eq!(hist.energy.conflict, SKU_16K_BITS as u32); - assert!( - hist.word_conflicts.iter().all(|&v| v == 64), - "Each word should have 64 bits conflict" - ); + assert!(hist.word_conflicts.iter().all(|&v| v == 64), "Each word should have 64 bits conflict"); } #[test] @@ -1635,10 +1567,7 @@ mod tests { let hist = k2_exact_histogram(&a, &b, SKU_16K_WORDS); assert_eq!(hist.max_word_conflict(), 64); assert_eq!(hist.hottest_word(), 100); - assert!( - hist.variance() > 0.5, - "Localized difference should have high variance" - ); + assert!(hist.variance() > 0.5, "Localized difference should have high variance"); } #[test] diff --git a/src/hpc/layered_distance.rs b/src/hpc/layered_distance.rs index 96b6d192..5ac1d967 100644 --- a/src/hpc/layered_distance.rs +++ b/src/hpc/layered_distance.rs @@ -35,9 +35,7 @@ pub fn read_palette_edge(container: &[u64; 256]) -> PaletteEdge { /// Write palette edge into container W125. pub fn write_palette_edge(container: &mut [u64; 256], pe: PaletteEdge) { - let packed = pe.s_idx as u64 - | ((pe.p_idx as u64) << 8) - | ((pe.o_idx as u64) << 16); + let packed = pe.s_idx as u64 | ((pe.p_idx as u64) << 8) | ((pe.o_idx as u64) << 16); // Preserve upper bits container[W_PALETTE_WORD] = (container[W_PALETTE_WORD] & !0xFF_FFFF) | packed; } @@ -53,27 +51,18 @@ pub fn read_truth(container: &[u64; 256]) -> (f32, f32) { /// Write truth value (frequency, confidence) into container W4-W5. pub fn write_truth(container: &mut [u64; 256], frequency: f32, confidence: f32) { - container[W_FREQUENCY] = (container[W_FREQUENCY] & !0xFFFF_FFFF) - | frequency.to_bits() as u64; - container[W_CONFIDENCE] = (container[W_CONFIDENCE] & !0xFFFF_FFFF) - | confidence.to_bits() as u64; + container[W_FREQUENCY] = (container[W_FREQUENCY] & !0xFFFF_FFFF) | frequency.to_bits() as u64; + container[W_CONFIDENCE] = (container[W_CONFIDENCE] & !0xFFFF_FFFF) | confidence.to_bits() as u64; } /// Layered distance: O(1) palette lookup between two containers. /// /// Reads palette edges from W125 of each container, then looks up the /// precomputed SPO distance in the distance matrices. -pub fn palette_distance( - dm: &SpoDistanceMatrices, - a: &[u64; 256], - b: &[u64; 256], -) -> u32 { +pub fn palette_distance(dm: &SpoDistanceMatrices, a: &[u64; 256], b: &[u64; 256]) -> u32 { let pe_a = read_palette_edge(a); let pe_b = read_palette_edge(b); - dm.spo_distance( - pe_a.s_idx, pe_a.p_idx, pe_a.o_idx, - pe_b.s_idx, pe_b.p_idx, pe_b.o_idx, - ) + dm.spo_distance(pe_a.s_idx, pe_a.p_idx, pe_a.o_idx, pe_b.s_idx, pe_b.p_idx, pe_b.o_idx) } /// TruthGate: filter by minimum expectation. 
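// Sketch: the W125 palette-edge roundtrip that backs the O(1) layered
// distance. Assumes PaletteEdge, write_palette_edge, read_palette_edge and
// palette_distance as in this diff; building the SpoDistanceMatrices is
// elided (the tests below show the Palette/Base17 setup).
let mut a = [0u64; 256];
let mut b = [0u64; 256];
let pe = PaletteEdge { s_idx: 42, p_idx: 128, o_idx: 255 };
write_palette_edge(&mut a, pe);
write_palette_edge(&mut b, pe);
assert_eq!(read_palette_edge(&a), read_palette_edge(&b));
// With matrices `dm` in hand, identical edges look up to distance 0:
// assert_eq!(palette_distance(&dm, &a, &b), 0);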
@@ -128,7 +117,11 @@ mod tests { #[test] fn test_read_write_palette_edge_roundtrip() { let mut container = [0u64; 256]; - let pe = PaletteEdge { s_idx: 42, p_idx: 128, o_idx: 255 }; + let pe = PaletteEdge { + s_idx: 42, + p_idx: 128, + o_idx: 255, + }; write_palette_edge(&mut container, pe); let read = read_palette_edge(&container); assert_eq!(pe, read); @@ -137,7 +130,11 @@ mod tests { #[test] fn test_read_write_palette_edge_zero() { let mut container = [0u64; 256]; - let pe = PaletteEdge { s_idx: 0, p_idx: 0, o_idx: 0 }; + let pe = PaletteEdge { + s_idx: 0, + p_idx: 0, + o_idx: 0, + }; write_palette_edge(&mut container, pe); let read = read_palette_edge(&container); assert_eq!(pe, read); @@ -163,8 +160,8 @@ mod tests { #[test] fn test_palette_distance_self_zero() { - use super::super::palette_distance::{Palette, SpoDistanceMatrices}; use super::super::bgz17_bridge::Base17; + use super::super::palette_distance::{Palette, SpoDistanceMatrices}; let entries: Vec = (0..16) .map(|i| { @@ -184,8 +181,8 @@ mod tests { #[test] fn test_palette_distance_symmetric() { - use super::super::palette_distance::{Palette, SpoDistanceMatrices}; use super::super::bgz17_bridge::Base17; + use super::super::palette_distance::{Palette, SpoDistanceMatrices}; let entries: Vec = (0..16) .map(|i| { @@ -254,7 +251,11 @@ mod tests { fn test_write_palette_edge_preserves_upper_bits() { let mut container = [0u64; 256]; container[W_PALETTE_WORD] = 0xFFFF_FFFF_FF00_0000; - let pe = PaletteEdge { s_idx: 1, p_idx: 2, o_idx: 3 }; + let pe = PaletteEdge { + s_idx: 1, + p_idx: 2, + o_idx: 3, + }; write_palette_edge(&mut container, pe); // Upper bits should be preserved assert_eq!(container[W_PALETTE_WORD] & 0xFFFF_FFFF_FF00_0000, 0xFFFF_FFFF_FF00_0000); diff --git a/src/hpc/merkle_tree.rs b/src/hpc/merkle_tree.rs index fd2dd7f5..a1f77e44 100644 --- a/src/hpc/merkle_tree.rs +++ b/src/hpc/merkle_tree.rs @@ -19,14 +19,14 @@ const BITS_WORDS: usize = 128; /// Branch region definitions: (start_word, end_word_exclusive) within the /// 256-word u64 metadata container, except branch 7 which covers content. const BRANCH_REGIONS: [(usize, usize); 8] = [ - (0, 16), // [0] identity - (4, 8), // [1] nars (overlaps identity — NARS truth words) - (16, 32), // [2] edges - (32, 40), // [3] rl - (40, 48), // [4] bloom - (56, 64), // [5] qualia - (96, 112), // [6] adjacency - (0, 0), // [7] content — handled specially + (0, 16), // [0] identity + (4, 8), // [1] nars (overlaps identity — NARS truth words) + (16, 32), // [2] edges + (32, 40), // [3] rl + (40, 48), // [4] bloom + (56, 64), // [5] qualia + (96, 112), // [6] adjacency + (0, 0), // [7] content — handled specially ]; /// Type of change detected between two Merkle trees. @@ -89,9 +89,7 @@ fn truncate_hash(hash: &blake3::Hash) -> MerkleRoot { fn hash_words(words: &[u64]) -> MerkleRoot { // SAFETY: u64 slice reinterpretation as u8 is safe — u8 has no alignment // requirement, and the byte count is exact (words.len() * 8). - let bytes = unsafe { - core::slice::from_raw_parts(words.as_ptr() as *const u8, words.len() * 8) - }; + let bytes = unsafe { core::slice::from_raw_parts(words.as_ptr() as *const u8, words.len() * 8) }; truncate_hash(&blake3::hash(bytes)) } @@ -119,9 +117,7 @@ impl MerkleTree { for container in content { // SAFETY: u64 array reinterpretation as u8 is safe — u8 has no // alignment requirement, and the byte count is exact (256 * 8). 
- let bytes = unsafe { - core::slice::from_raw_parts(container.as_ptr() as *const u8, 256 * 8) - }; + let bytes = unsafe { core::slice::from_raw_parts(container.as_ptr() as *const u8, 256 * 8) }; content_hasher.update(bytes); } branches[7] = truncate_hash(&content_hasher.finalize()); @@ -192,9 +188,7 @@ impl MerkleTree { /// Pack root, branches, and leaves into a flat 8Kbit (128 x u64) array. fn pack_bits( - root: &MerkleRoot, - branches: &[MerkleRoot; NUM_BRANCHES], - leaves: &[MerkleRoot; NUM_LEAVES], + root: &MerkleRoot, branches: &[MerkleRoot; NUM_BRANCHES], leaves: &[MerkleRoot; NUM_LEAVES], ) -> [u64; BITS_WORDS] { let mut bits = [0u64; BITS_WORDS]; @@ -239,9 +233,7 @@ impl MerkleTree { fn bits_as_bytes(&self) -> &[u8] { // SAFETY: [u64; 128] is contiguous in memory. u8 has no alignment // requirement stricter than u64. Length 128 * 8 = 1024 is exact. - unsafe { - core::slice::from_raw_parts(self.bits.as_ptr() as *const u8, BITS_WORDS * 8) - } + unsafe { core::slice::from_raw_parts(self.bits.as_ptr() as *const u8, BITS_WORDS * 8) } } /// Hamming distance between two Merkle trees over the full 8Kbit vector. @@ -343,10 +335,7 @@ impl MerkleTree { /// Returns a sparsity score 0..=8 (0 = identical, 8 = all branches differ). #[inline] pub fn diff_sparsity(&self, other: &MerkleTree) -> u8 { - self.diff_branches(other) - .iter() - .filter(|&&d| d) - .count() as u8 + self.diff_branches(other).iter().filter(|&&d| d).count() as u8 } } @@ -370,7 +359,9 @@ mod tests { let mut meta = [0u64; 256]; let mut state = seed; for word in meta.iter_mut() { - state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407); + state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); *word = state; } meta @@ -382,7 +373,9 @@ mod tests { let mut state = seed; for container in containers.iter_mut() { for word in container.iter_mut() { - state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407); + state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); *word = state; } } diff --git a/src/hpc/mod.rs b/src/hpc/mod.rs index 35b2038b..05dbf184 100644 --- a/src/hpc/mod.rs +++ b/src/hpc/mod.rs @@ -1,8 +1,4 @@ -#![allow( - clippy::all, - unused_imports, - dead_code -)] +#![allow(clippy::all, unused_imports, dead_code)] //! HPC extensions for ndarray — ported from rustynum. //! //! This module provides high-performance computing extensions: @@ -237,15 +233,15 @@ pub mod audio; mod e2e_tests { //! End-to-end pipeline test: Fingerprint → Node → Seal → Cascade → CLAM → Causality → BNN + use super::bf16_truth::PackedQualia; + use super::blackboard::Blackboard; + use super::bnn::bnn_dot; + use super::cascade::{Band, Cascade}; + use super::causality::{causality_decompose, CausalityDirection}; + use super::clam::{knn_brute, ClamTree}; use super::fingerprint::Fingerprint; use super::node::{Node, SPO, S__, _P_, __O}; use super::seal::Seal; - use super::cascade::{Cascade, Band}; - use super::clam::{ClamTree, knn_brute}; - use super::bf16_truth::PackedQualia; - use super::causality::{causality_decompose, CausalityDirection}; - use super::bnn::bnn_dot; - use super::blackboard::Blackboard; #[test] fn pipeline_fingerprint_to_node_to_seal() { @@ -256,7 +252,9 @@ mod e2e_tests { // 2. Measure distance (SPO full) let d = a.distance(&mut b, SPO); match d { - super::plane::Distance::Measured { disagreement, overlap, .. } => { + super::plane::Distance::Measured { + disagreement, overlap, .. 
+ } => { assert!(overlap > 0, "random nodes should have overlap"); assert!(disagreement > 0, "different seeds should disagree"); } @@ -331,9 +329,9 @@ mod e2e_tests { fn pipeline_causality_decomposition() { let mut a = PackedQualia::zero(); let b = PackedQualia::zero(); - a.resonance[4] = 10; // warmth: positive → Forward - a.resonance[6] = -5; // social: negative → Backward - a.resonance[8] = 3; // sacredness: positive → Forward + a.resonance[4] = 10; // warmth: positive → Forward + a.resonance[6] = -5; // social: negative → Backward + a.resonance[8] = 3; // sacredness: positive → Forward let dec = causality_decompose(&a, &b, None); assert_eq!(dec.warmth_dir, CausalityDirection::Forward); diff --git a/src/hpc/models/api_types.rs b/src/hpc/models/api_types.rs index edb0f1d1..6949cb25 100644 --- a/src/hpc/models/api_types.rs +++ b/src/hpc/models/api_types.rs @@ -63,7 +63,12 @@ pub struct ApiError { impl ApiError { pub fn invalid_request(msg: impl Into) -> Self { - Self { message: msg.into(), r#type: "invalid_request_error".into(), param: None, code: None } + Self { + message: msg.into(), + r#type: "invalid_request_error".into(), + param: None, + code: None, + } } pub fn model_not_found(model: &str) -> Self { Self { @@ -100,7 +105,12 @@ pub struct Model { impl Model { pub fn new(id: impl Into, owned_by: impl Into, created: u64) -> Self { - Self { id: id.into(), object: "model", created, owned_by: owned_by.into() } + Self { + id: id.into(), + object: "model", + created, + owned_by: owned_by.into(), + } } } @@ -113,7 +123,10 @@ pub struct ModelList { impl ModelList { pub fn new(models: Vec) -> Self { - Self { object: "list", data: models } + Self { + object: "list", + data: models, + } } } @@ -213,7 +226,15 @@ pub struct CompletionResponse { impl CompletionResponse { pub fn new(id: String, model: String, choices: Vec, usage: Usage, created: u64) -> Self { - Self { id, object: "text_completion", created, model, choices, usage, system_fingerprint: None } + Self { + id, + object: "text_completion", + created, + model, + choices, + usage, + system_fingerprint: None, + } } } @@ -265,13 +286,31 @@ pub struct ChatMessage { impl ChatMessage { pub fn system(content: impl Into) -> Self { - Self { role: ChatRole::System, content: Some(content.into()), name: None, tool_calls: None, tool_call_id: None } + Self { + role: ChatRole::System, + content: Some(content.into()), + name: None, + tool_calls: None, + tool_call_id: None, + } } pub fn user(content: impl Into) -> Self { - Self { role: ChatRole::User, content: Some(content.into()), name: None, tool_calls: None, tool_call_id: None } + Self { + role: ChatRole::User, + content: Some(content.into()), + name: None, + tool_calls: None, + tool_call_id: None, + } } pub fn assistant(content: impl Into) -> Self { - Self { role: ChatRole::Assistant, content: Some(content.into()), name: None, tool_calls: None, tool_call_id: None } + Self { + role: ChatRole::Assistant, + content: Some(content.into()), + name: None, + tool_calls: None, + tool_call_id: None, + } } } @@ -385,7 +424,15 @@ pub struct ChatCompletionResponse { impl ChatCompletionResponse { pub fn new(id: String, model: String, choices: Vec, usage: Usage, created: u64) -> Self { - Self { id, object: "chat.completion", created, model, choices, usage, system_fingerprint: None } + Self { + id, + object: "chat.completion", + created, + model, + choices, + usage, + system_fingerprint: None, + } } } @@ -467,7 +514,11 @@ pub struct EmbeddingData { impl EmbeddingData { pub fn new(index: usize, embedding: Vec) -> Self 
{ - Self { object: "embedding", index, embedding } + Self { + object: "embedding", + index, + embedding, + } } } @@ -482,7 +533,12 @@ pub struct EmbeddingResponse { impl EmbeddingResponse { pub fn new(model: String, data: Vec, usage: Usage) -> Self { - Self { object: "list", model, data, usage } + Self { + object: "list", + model, + data, + usage, + } } } @@ -677,15 +733,21 @@ mod tests { #[test] fn test_streaming_chunk_object() { let chunk = ChatCompletionChunk { - id: "x".into(), object: "chat.completion.chunk", created: 0, - model: "m".into(), choices: vec![], system_fingerprint: None, + id: "x".into(), + object: "chat.completion.chunk", + created: 0, + model: "m".into(), + choices: vec![], + system_fingerprint: None, }; assert_eq!(chunk.object, "chat.completion.chunk"); } #[test] fn test_error_response() { - let err = ErrorResponse { error: ApiError::invalid_request("test") }; + let err = ErrorResponse { + error: ApiError::invalid_request("test"), + }; assert_eq!(err.error.r#type, "invalid_request_error"); } } diff --git a/src/hpc/models/layers.rs b/src/hpc/models/layers.rs index 9a726780..c7f28e1a 100644 --- a/src/hpc/models/layers.rs +++ b/src/hpc/models/layers.rs @@ -426,7 +426,11 @@ mod tests { rope_apply(&mut q, &mut k, 4, 42, 10000.0); let norm_after: f32 = q.iter().map(|x| x * x).sum::().sqrt(); // RoPE is a rotation — should preserve L2 norm - assert!((norm_before - norm_after).abs() < 0.01, - "RoPE should preserve norm: {} vs {}", norm_before, norm_after); + assert!( + (norm_before - norm_after).abs() < 0.01, + "RoPE should preserve norm: {} vs {}", + norm_before, + norm_after + ); } } diff --git a/src/hpc/models/router.rs b/src/hpc/models/router.rs index 677f5ee7..8a03fc73 100644 --- a/src/hpc/models/router.rs +++ b/src/hpc/models/router.rs @@ -39,7 +39,11 @@ pub struct ModelRouter { impl ModelRouter { /// Create an empty router (no models loaded). pub fn new() -> Self { - Self { gpt2: None, openchat: None, request_counter: 0 } + Self { + gpt2: None, + openchat: None, + request_counter: 0, + } } // ── Model registration ───────────────────────────────────────────── @@ -57,8 +61,12 @@ impl ModelRouter { /// Check which models are loaded. pub fn loaded_models(&self) -> Vec<&'static str> { let mut models = Vec::new(); - if self.gpt2.is_some() { models.push("gpt2"); } - if self.openchat.is_some() { models.push("openchat_3.5"); } + if self.gpt2.is_some() { + models.push("gpt2"); + } + if self.openchat.is_some() { + models.push("openchat_3.5"); + } models } @@ -97,7 +105,9 @@ impl ModelRouter { /// /// Routes to GPT-2. Returns error if GPT-2 is not loaded. 
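// Sketch: the routing surface exercised by the tests below. With no engine
// registered, every request fails with a model-not-found ApiError. Assumes
// ModelRouter, CompletionRequest, ChatCompletionRequest and ChatMessage as in
// this diff.
let mut router = ModelRouter::new();
assert!(router.loaded_models().is_empty());

let req = CompletionRequest { model: "gpt2".into(), ..Default::default() };
assert!(router.complete(&req).is_err());

let chat = ChatCompletionRequest {
    model: "openchat_3.5".into(),
    messages: vec![ChatMessage::system("Be helpful"), ChatMessage::user("Hello")],
    ..Default::default()
};
assert!(router.chat_complete(&chat).is_err());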
pub fn complete(&mut self, req: &CompletionRequest) -> Result { - let engine = self.gpt2.as_mut() + let engine = self + .gpt2 + .as_mut() .ok_or_else(|| ApiError::model_not_found(&req.model))?; Ok(engine.complete(req)) } @@ -112,13 +122,17 @@ impl ModelRouter { pub fn chat_complete(&mut self, req: &ChatCompletionRequest) -> Result { match req.model.as_str() { "openchat_3.5" | "openchat" => { - let engine = self.openchat.as_mut() + let engine = self + .openchat + .as_mut() .ok_or_else(|| ApiError::model_not_found(&req.model))?; Ok(engine.chat_complete(req)) } "gpt2" => { // Adapter: convert chat messages to a single text prompt for GPT-2 - let engine = self.gpt2.as_mut() + let engine = self + .gpt2 + .as_mut() .ok_or_else(|| ApiError::model_not_found("gpt2"))?; let completion_req = chat_to_completion(req); let completion_resp = engine.complete(&completion_req); @@ -136,7 +150,9 @@ impl ModelRouter { pub fn embed(&self, req: &EmbeddingRequest) -> Result { match req.model.as_str() { "gpt2" | "text-embedding-gpt2" => { - let engine = self.gpt2.as_ref() + let engine = self + .gpt2 + .as_ref() .ok_or_else(|| ApiError::model_not_found(&req.model))?; Ok(engine.embed(req)) } @@ -206,22 +222,18 @@ fn chat_to_completion(req: &ChatCompletionRequest) -> CompletionRequest { /// Convert a completion response to a chat response (for GPT-2 chat adapter). fn completion_to_chat(resp: CompletionResponse) -> ChatCompletionResponse { - let choices: Vec = resp.choices.into_iter().map(|c| { - ChatChoice { + let choices: Vec = resp + .choices + .into_iter() + .map(|c| ChatChoice { index: c.index, message: ChatMessage::assistant(c.text), finish_reason: c.finish_reason, logprobs: None, - } - }).collect(); - - ChatCompletionResponse::new( - resp.id.replace("cmpl-", "chatcmpl-"), - resp.model, - choices, - resp.usage, - resp.created, - ) + }) + .collect(); + + ChatCompletionResponse::new(resp.id.replace("cmpl-", "chatcmpl-"), resp.model, choices, resp.usage, resp.created) } #[cfg(test)] @@ -252,7 +264,10 @@ mod tests { #[test] fn test_complete_no_model() { let mut router = ModelRouter::new(); - let req = CompletionRequest { model: "gpt2".into(), ..Default::default() }; + let req = CompletionRequest { + model: "gpt2".into(), + ..Default::default() + }; let err = router.complete(&req); assert!(err.is_err()); } @@ -260,7 +275,10 @@ mod tests { #[test] fn test_chat_complete_no_model() { let mut router = ModelRouter::new(); - let req = ChatCompletionRequest { model: "openchat_3.5".into(), ..Default::default() }; + let req = ChatCompletionRequest { + model: "openchat_3.5".into(), + ..Default::default() + }; let err = router.chat_complete(&req); assert!(err.is_err()); } @@ -268,7 +286,10 @@ mod tests { #[test] fn test_embed_no_model() { let router = ModelRouter::new(); - let req = EmbeddingRequest { model: "gpt2".into(), ..Default::default() }; + let req = EmbeddingRequest { + model: "gpt2".into(), + ..Default::default() + }; let err = router.embed(&req); assert!(err.is_err()); } @@ -277,10 +298,7 @@ mod tests { fn test_chat_to_completion_adapter() { let req = ChatCompletionRequest { model: "gpt2".into(), - messages: vec![ - ChatMessage::system("Be helpful"), - ChatMessage::user("Hello"), - ], + messages: vec![ChatMessage::system("Be helpful"), ChatMessage::user("Hello")], max_tokens: Some(100), temperature: Some(0.5), ..Default::default() @@ -304,7 +322,11 @@ mod tests { logprobs: None, finish_reason: Some(FinishReason::Stop), }], - Usage { prompt_tokens: 5, completion_tokens: 2, total_tokens: 7 }, + Usage { + 
prompt_tokens: 5, + completion_tokens: 2, + total_tokens: 7, + }, 0, ); let chat = completion_to_chat(resp); diff --git a/src/hpc/models/safetensors.rs b/src/hpc/models/safetensors.rs index 46a30d5f..8f9f70cc 100644 --- a/src/hpc/models/safetensors.rs +++ b/src/hpc/models/safetensors.rs @@ -39,27 +39,31 @@ impl SafeTensorsFile { return Err("file too small for safetensors header".into()); } - let header_size = u64::from_le_bytes([ - data[0], data[1], data[2], data[3], - data[4], data[5], data[6], data[7], - ]) as usize; + let header_size = + u64::from_le_bytes([data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7]]) as usize; if 8 + header_size > data.len() { return Err(format!("header_size {} exceeds file len {}", header_size, data.len())); } - let header_json = std::str::from_utf8(&data[8..8 + header_size]) - .map_err(|e| format!("invalid UTF-8 in header: {}", e))?; + let header_json = + std::str::from_utf8(&data[8..8 + header_size]).map_err(|e| format!("invalid UTF-8 in header: {}", e))?; let data_start = 8 + header_size; let tensors = parse_header(header_json)?; - Ok(Self { data, data_start, tensors }) + Ok(Self { + data, + data_start, + tensors, + }) } /// Read a tensor as Vec (little-endian F32). pub fn read_f32(&self, name: &str) -> Result, String> { - let meta = self.tensors.get(name) + let meta = self + .tensors + .get(name) .ok_or_else(|| format!("missing tensor: {}", name))?; let start = self.data_start + meta.offset; let end = start + meta.size; @@ -74,7 +78,9 @@ impl SafeTensorsFile { /// Read a tensor as Vec stored as raw u16 (for F16 tensors). pub fn read_f16_raw(&self, name: &str) -> Result, String> { - let meta = self.tensors.get(name) + let meta = self + .tensors + .get(name) .ok_or_else(|| format!("missing tensor: {}", name))?; let start = self.data_start + meta.offset; let end = start + meta.size; @@ -145,14 +151,18 @@ fn parse_header(json: &str) -> Result, String> { let arr_start = search_start + bracket_start + 1; if let Some(bracket_end) = json[arr_start..].find(']') { let arr = &json[arr_start..arr_start + bracket_end]; - let nums: Vec = arr.split(',') + let nums: Vec = arr + .split(',') .filter_map(|s| s.trim().parse().ok()) .collect(); if nums.len() == 2 { - tensors.insert(key.to_string(), TensorMeta { - offset: nums[0], - size: nums[1] - nums[0], - }); + tensors.insert( + key.to_string(), + TensorMeta { + offset: nums[0], + size: nums[1] - nums[0], + }, + ); } } } @@ -183,7 +193,8 @@ mod tests { #[test] fn test_parse_header_with_metadata() { - let json = r#"{"__metadata__": {"format": "pt"}, "w": {"dtype": "F32", "shape": [3], "data_offsets": [0, 12]}}"#; + let json = + r#"{"__metadata__": {"format": "pt"}, "w": {"dtype": "F32", "shape": [3], "data_offsets": [0, 12]}}"#; let tensors = parse_header(json).unwrap(); assert_eq!(tensors.len(), 1); assert!(tensors.contains_key("w")); diff --git a/src/hpc/nars.rs b/src/hpc/nars.rs index 9c6e35c8..56cae79f 100644 --- a/src/hpc/nars.rs +++ b/src/hpc/nars.rs @@ -462,10 +462,7 @@ pub struct Contradiction { /// #11 Detect contradictions: high structural similarity + opposing truth values. /// Science: Wang (2006) revision, Priest (2002) paraconsistent logic, CHAODA. 
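// Sketch: the truth-value arithmetic detect_contradiction builds on.
// expectation(f, c) = c * (f - 0.5) + 0.5, and revising two agreeing beliefs
// keeps the frequency while raising confidence. Assumes NarsTruth and
// nars_revision as they appear in this file.
let tv = NarsTruth::new(1.0, 0.9);
assert!((tv.expectation() - 0.95).abs() < 1e-4); // 0.9 * 0.5 + 0.5

let a = NarsTruth::new(0.9, 0.5);
let b = NarsTruth::new(0.9, 0.5);
let r = nars_revision(a, b);
assert!((r.frequency - 0.9).abs() < 0.05);
assert!(r.confidence > a.confidence); // more evidence → higher confidence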
pub fn detect_contradiction( - truth_a: &NarsTruth, - truth_b: &NarsTruth, - structural_similarity: f32, - threshold: f32, + truth_a: &NarsTruth, truth_b: &NarsTruth, structural_similarity: f32, threshold: f32, ) -> Option { let truth_conflict = (truth_a.frequency - truth_b.frequency).abs(); if structural_similarity > 0.7 && truth_conflict > threshold { @@ -607,16 +604,8 @@ mod tests { fn test_from_evidence_roundtrip() { let tv = NarsTruth::from_evidence(9.0, 1.0); let ev = tv.to_evidence(); - assert!( - (ev.positive - 9.0).abs() < 0.5, - "positive: {} expected ~9.0", - ev.positive - ); - assert!( - (ev.negative - 1.0).abs() < 0.5, - "negative: {} expected ~1.0", - ev.negative - ); + assert!((ev.positive - 9.0).abs() < 0.5, "positive: {} expected ~9.0", ev.positive); + assert!((ev.negative - 1.0).abs() < 0.5, "negative: {} expected ~1.0", ev.negative); } #[test] @@ -624,11 +613,7 @@ mod tests { let a = NarsTruth::new(0.9, 0.5); let b = NarsTruth::new(0.9, 0.5); let r = nars_revision(a, b); - assert!( - (r.frequency - 0.9).abs() < 0.05, - "frequency: {} expected ~0.9", - r.frequency - ); + assert!((r.frequency - 0.9).abs() < 0.05, "frequency: {} expected ~0.9", r.frequency); assert!( r.confidence > a.confidence, "revised confidence {} should exceed input {}", @@ -810,29 +795,17 @@ mod tests { #[test] fn test_ignorance_expectation() { let tv = NarsTruth::ignorance(); - assert!( - (tv.expectation() - 0.5).abs() < 1e-6, - "expectation={}", - tv.expectation() - ); + assert!((tv.expectation() - 0.5).abs() < 1e-6, "expectation={}", tv.expectation()); } #[test] fn test_expectation_formula() { let tv = NarsTruth::new(1.0, 0.9); // c * (f - 0.5) + 0.5 = 0.9 * 0.5 + 0.5 = 0.95 - assert!( - (tv.expectation() - 0.95).abs() < 1e-4, - "expectation={}", - tv.expectation() - ); + assert!((tv.expectation() - 0.95).abs() < 1e-4, "expectation={}", tv.expectation()); let tv2 = NarsTruth::new(0.0, 0.8); // 0.8 * (0.0 - 0.5) + 0.5 = -0.4 + 0.5 = 0.1 - assert!( - (tv2.expectation() - 0.1).abs() < 1e-4, - "expectation={}", - tv2.expectation() - ); + assert!((tv2.expectation() - 0.1).abs() < 1e-4, "expectation={}", tv2.expectation()); } } diff --git a/src/hpc/node.rs b/src/hpc/node.rs index ca168514..dbd0bdb5 100644 --- a/src/hpc/node.rs +++ b/src/hpc/node.rs @@ -11,14 +11,46 @@ pub struct Mask { pub o: bool, } -pub const SPO: Mask = Mask { s: true, p: true, o: true }; -pub const SP_: Mask = Mask { s: true, p: true, o: false }; -pub const S_O: Mask = Mask { s: true, p: false, o: true }; -pub const _PO: Mask = Mask { s: false, p: true, o: true }; -pub const S__: Mask = Mask { s: true, p: false, o: false }; -pub const _P_: Mask = Mask { s: false, p: true, o: false }; -pub const __O: Mask = Mask { s: false, p: false, o: true }; -pub const ___: Mask = Mask { s: false, p: false, o: false }; +pub const SPO: Mask = Mask { + s: true, + p: true, + o: true, +}; +pub const SP_: Mask = Mask { + s: true, + p: true, + o: false, +}; +pub const S_O: Mask = Mask { + s: true, + p: false, + o: true, +}; +pub const _PO: Mask = Mask { + s: false, + p: true, + o: true, +}; +pub const S__: Mask = Mask { + s: true, + p: false, + o: false, +}; +pub const _P_: Mask = Mask { + s: false, + p: true, + o: false, +}; +pub const __O: Mask = Mask { + s: false, + p: false, + o: true, +}; +pub const ___: Mask = Mask { + s: false, + p: false, + o: false, +}; impl Mask { #[inline] @@ -96,7 +128,11 @@ impl Node { ($self_plane:expr, $other_plane:expr, $active:expr) => { if $active { match $self_plane.distance(&mut $other_plane) { - Distance::Measured { 
disagreement, overlap, penalty } => { + Distance::Measured { + disagreement, + overlap, + penalty, + } => { total_disagreement += disagreement; total_overlap += overlap; total_penalty += penalty; @@ -220,10 +256,7 @@ mod tests { let d_spo = a.distance(&mut b, SPO); let d_s = a.distance(&mut b, S__); match (d_spo, d_s) { - ( - Distance::Measured { overlap: o_spo, .. }, - Distance::Measured { overlap: o_s, .. }, - ) => { + (Distance::Measured { overlap: o_spo, .. }, Distance::Measured { overlap: o_s, .. }) => { assert!(o_spo >= o_s); } _ => panic!("expected Measured for random nodes"), @@ -273,8 +306,16 @@ mod tests { // SPO is the last element (index 6) match (projections[6], direct_spo) { ( - Distance::Measured { disagreement: d1, overlap: o1, penalty: p1 }, - Distance::Measured { disagreement: d2, overlap: o2, penalty: p2 }, + Distance::Measured { + disagreement: d1, + overlap: o1, + penalty: p1, + }, + Distance::Measured { + disagreement: d2, + overlap: o2, + penalty: p2, + }, ) => { assert_eq!(d1, d2); assert_eq!(o1, o2); @@ -293,10 +334,9 @@ mod tests { // S__ is index 0 let d_s = a.distance(&mut b, S__); match (projections[0], d_s) { - ( - Distance::Measured { disagreement: d1, .. }, - Distance::Measured { disagreement: d2, .. }, - ) => assert_eq!(d1, d2), + (Distance::Measured { disagreement: d1, .. }, Distance::Measured { disagreement: d2, .. }) => { + assert_eq!(d1, d2) + } _ => panic!("expected Measured"), } } diff --git a/src/hpc/ocr_felt.rs b/src/hpc/ocr_felt.rs index 1742da6b..28109943 100644 --- a/src/hpc/ocr_felt.rs +++ b/src/hpc/ocr_felt.rs @@ -9,7 +9,7 @@ //! For production: use ocrs+rten (AdaWorldAPI/ocrs, AdaWorldAPI/rten). //! This module is the felt-distance fast path and preprocessing accelerator. -use super::ocr_simd::{BinaryImage, GrayImage, foreground_count}; +use super::ocr_simd::{foreground_count, BinaryImage, GrayImage}; /// Euler-Mascheroni constant (Rust 1.94+). 
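// Sketch: masked plane distances on Node, mirroring the projection tests
// above. SPO measures all three planes while S__ measures only the subject
// plane, so the SPO overlap can never be smaller. Assumes Node, the SPO/S__
// masks and Distance::Measured as in this diff; constructing the two
// populated nodes `a` and `b` is elided.
let d_spo = a.distance(&mut b, SPO);
let d_s = a.distance(&mut b, S__);
match (d_spo, d_s) {
    (Distance::Measured { overlap: o_spo, .. }, Distance::Measured { overlap: o_s, .. }) => {
        assert!(o_spo >= o_s)
    }
    _ => panic!("expected Measured for populated nodes"),
}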
const EULER_GAMMA: f64 = std::f64::consts::EULER_GAMMA; @@ -91,7 +91,9 @@ impl CharCodebook { let mut best_dist = u32::MAX; let mut second_dist = u32::MAX; for &(c, ref entry) in &self.entries { - if c == '\0' { continue; } + if c == '\0' { + continue; + } let d = glyph.l1(entry); if d < best_dist { second_dist = best_dist; @@ -129,16 +131,21 @@ impl GlyphPalette { let mut max_dist = 0u32; // First pass: find max distance for i in 0..256 { - for j in i+1..256 { + for j in i + 1..256 { let d = codebook.entries[i].1.l1(&codebook.entries[j].1); - if d > max_dist { max_dist = d; } + if d > max_dist { + max_dist = d; + } } } let scale = if max_dist > 0 { 255.0 / max_dist as f64 } else { 1.0 }; // Second pass: fill table for i in 0..256 { for j in 0..256 { - if i == j { distances[i][j] = 0; continue; } + if i == j { + distances[i][j] = 0; + continue; + } let d = codebook.entries[i].1.l1(&codebook.entries[j].1); distances[i][j] = (d as f64 * scale).round().min(255.0) as u8; } @@ -154,7 +161,10 @@ impl GlyphPalette { let mut best_dist = u32::MAX; for (i, &(_, ref entry)) in codebook.entries.iter().enumerate() { let d = glyph.l1(entry); - if d < best_dist { best_dist = d; best_idx = i as u8; } + if d < best_dist { + best_dist = d; + best_idx = i as u8; + } } best_idx } @@ -189,7 +199,9 @@ impl PolarProfile { for y in 0..height { for x in 0..width { let pixel = pixels.get(y * width + x).copied().unwrap_or(0); - if pixel == 0 { continue; } + if pixel == 0 { + continue; + } let dx = x as f32 - cx; let dy = y as f32 - cy; let r = (dx * dx + dy * dy).sqrt() / max_r; // 0..1 @@ -216,7 +228,9 @@ impl PolarProfile { // Rotate by shifting in groups of 4 bits (4 radial bins per angle) let rotated = rotate_polar(other.bits, shift); let d = (self.bits ^ rotated).count_ones(); - if d < min_d { min_d = d; } + if d < min_d { + min_d = d; + } } min_d } @@ -224,7 +238,9 @@ impl PolarProfile { /// Rotate polar profile by `shift` angular bins (each bin = 4 bits). 
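// Sketch (self-contained): the 4-bit-group rotation described above — a polar
// signature packs 16 angular bins x 4 radial bits into one u64, so shifting by
// one angular bin is a rotate-right by 4 bits. Names here are local to the
// sketch; the rotation-invariant distance simply minimises XOR-popcount over
// all 16 such shifts.
fn rotate_signature(bits: u64, angular_shift: usize) -> u64 {
    let shift_bits = (angular_shift % 16) * 4;
    if shift_bits == 0 {
        return bits; // no shift, or a full turn: unchanged
    }
    (bits >> shift_bits) | (bits << (64 - shift_bits))
}

let sig = 0x1234_5678_9ABC_DEF0u64;
assert_eq!(rotate_signature(sig, 16), sig); // 16 bins = full turn
assert_eq!((sig ^ rotate_signature(sig, 0)).count_ones(), 0);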
fn rotate_polar(bits: u64, shift: usize) -> u64 { - if shift == 0 { return bits; } + if shift == 0 { + return bits; + } let shift_bits = (shift % 16) * 4; (bits >> shift_bits) | (bits << (64 - shift_bits)) } @@ -260,11 +276,19 @@ pub fn fast_skew_check(bin: &BinaryImage) -> SkewResult { if normalized > SKEW_FLOOR { // Straight enough — skip full search - SkewResult { angle: 0.0, confidence: normalized as f32, searched: false } + SkewResult { + angle: 0.0, + confidence: normalized as f32, + searched: false, + } } else { // Need full search — use ocr_simd::estimate_skew let angle = super::ocr_simd::estimate_skew(bin); - SkewResult { angle, confidence: normalized as f32, searched: true } + SkewResult { + angle, + confidence: normalized as f32, + searched: true, + } } } @@ -281,9 +305,17 @@ pub struct SkewResult { fn compute_variance(data: &[u32]) -> f64 { let n = data.len() as f64; - if n < 2.0 { return 0.0; } + if n < 2.0 { + return 0.0; + } let mean = data.iter().map(|&v| v as f64).sum::() / n; - data.iter().map(|&v| { let d = v as f64 - mean; d * d }).sum::() / n + data.iter() + .map(|&v| { + let d = v as f64 - mean; + d * d + }) + .sum::() + / n } // ═══════════════════════════════════════════════════════════════════════════ @@ -302,7 +334,9 @@ pub fn detect_paragraphs_by_indent(bin: &BinaryImage) -> Vec { for y in 0..h { 'find_first: for xw in 0..words_per_row { let idx = y * words_per_row + xw; - if idx >= bin.bits.len() { break; } + if idx >= bin.bits.len() { + break; + } let word = bin.bits[idx]; if word != 0 { first_pixel[y] = xw * 64 + word.trailing_zeros() as usize; @@ -313,7 +347,9 @@ pub fn detect_paragraphs_by_indent(bin: &BinaryImage) -> Vec { // Find median left margin (typical line start) let mut margins: Vec = first_pixel.iter().filter(|&&p| p < w).copied().collect(); - if margins.is_empty() { return vec![]; } + if margins.is_empty() { + return vec![]; + } margins.sort_unstable(); let median_margin = margins[margins.len() / 2]; @@ -349,14 +385,15 @@ fn render_synthetic_glyph(c: char) -> Vec { for x in 1..7 { state = state.wrapping_mul(31).wrapping_add(code); let threshold = match c { - 'A'..='Z' => 90, // uppercase: more ink - 'a'..='z' => 70, // lowercase: less ink - '0'..='9' => 80, // digits: moderate - '.'|','|';'|':'|'!' => 30, // punctuation: minimal + 'A'..='Z' => 90, // uppercase: more ink + 'a'..='z' => 70, // lowercase: less ink + '0'..='9' => 80, // digits: moderate + '.' | ',' | ';' | ':' | '!' 
=> 30, // punctuation: minimal _ => 50, }; // Position-aware: more ink in center, less at edges - let center_val = ((x as i32 - 3).abs() * 8 + (y as i32 - 5).abs() * 4) as u32; let center_bonus = 40u32.saturating_sub(center_val); + let center_val = ((x as i32 - 3).abs() * 8 + (y as i32 - 5).abs() * 4) as u32; + let center_bonus = 40u32.saturating_sub(center_val); if (state % 200) < (threshold + center_bonus.min(40)) { patch[y * 8 + x] = 255; } @@ -382,7 +419,10 @@ mod tests { let z = GlyphBase17::from_patch(&render_synthetic_glyph('Z'), 8, 12); let b = GlyphBase17::from_patch(&render_synthetic_glyph('B'), 8, 12); // A should be closer to B than to Z - let d_ab = a.l1(&b); let d_az = a.l1(&z); eprintln!("Base17 L1: A-B={}, A-Z={}", d_ab, d_az); assert!(d_ab < 100000, "distances should be finite"); + let d_ab = a.l1(&b); + let d_az = a.l1(&z); + eprintln!("Base17 L1: A-B={}, A-Z={}", d_ab, d_az); + assert!(d_ab < 100000, "distances should be finite"); } #[test] @@ -429,8 +469,7 @@ mod tests { #[test] fn test_euler_gamma_skew_floor() { // γ/(γ+1) ≈ 0.366 - assert!((SKEW_FLOOR - 0.366).abs() < 0.01, - "Euler-gamma floor should be ~0.366, got {}", SKEW_FLOOR); + assert!((SKEW_FLOOR - 0.366).abs() < 0.01, "Euler-gamma floor should be ~0.366, got {}", SKEW_FLOOR); } #[test] @@ -443,7 +482,11 @@ mod tests { bits[y * words_per_row + xw] = u64::MAX; // full row of foreground } } - let bin = BinaryImage { bits, width: 200, height: 200 }; + let bin = BinaryImage { + bits, + width: 200, + height: 200, + }; let result = fast_skew_check(&bin); // Straight horizontal lines should skip full search eprintln!("Skew: angle={:.2}°, conf={:.3}, searched={}", result.angle, result.confidence, result.searched); @@ -470,7 +513,11 @@ mod tests { bits[y * words_per_row] = 0xFFFF_FFFF_FFF0_0000; } - let bin = BinaryImage { bits, width: w, height: h }; + let bin = BinaryImage { + bits, + width: w, + height: h, + }; let paragraphs = detect_paragraphs_by_indent(&bin); eprintln!("Paragraphs detected at rows: {:?}", paragraphs); assert!(paragraphs.len() >= 2, "should detect at least 2 paragraph starts, got {}", paragraphs.len()); @@ -493,8 +540,10 @@ mod tests { let palette_idx = palette.quantize(&base17, &codebook); let (recognized, dist, conf) = codebook.recognize(&base17); - eprintln!(" '{}': Base17 → '{}' (d={}, conf={:.2}) | Palette={} | Polar={:016b}", - c, recognized, dist, conf, palette_idx, polar.bits); + eprintln!( + " '{}': Base17 → '{}' (d={}, conf={:.2}) | Palette={} | Polar={:016b}", + c, recognized, dist, conf, palette_idx, polar.bits + ); } // Cross-distances diff --git a/src/hpc/ocr_simd.rs b/src/hpc/ocr_simd.rs index 753b6967..167e7e3a 100644 --- a/src/hpc/ocr_simd.rs +++ b/src/hpc/ocr_simd.rs @@ -75,10 +75,14 @@ pub fn otsu_threshold(img: &GrayImage) -> u8 { for (t, &count) in histogram.iter().enumerate() { weight_bg += count as f64; - if weight_bg == 0.0 { continue; } + if weight_bg == 0.0 { + continue; + } let weight_fg = total - weight_bg; - if weight_fg == 0.0 { break; } + if weight_fg == 0.0 { + break; + } sum_bg += t as f64 * count as f64; let mean_bg = sum_bg / weight_bg; @@ -130,7 +134,11 @@ pub fn binarize(img: &GrayImage, threshold: u8) -> BinaryImage { bits[word_idx] = mask; } - BinaryImage { bits, width: img.width, height: img.height } + BinaryImage { + bits, + width: img.width, + height: img.height, + } } /// Binarize with automatic Otsu threshold. 
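// Sketch: the Otsu → binarize path on a synthetic bimodal page, as in the
// tests below. Assumes GrayImage, otsu_threshold, binarize and
// foreground_count as they appear in this file; pixel values are illustrative.
let mut data = vec![30u8; 64 * 64]; // dark half (ink)
data.extend(vec![220u8; 64 * 64]); // light half (paper)
let img = GrayImage { data: &data, width: 128, height: 64 };

let t = otsu_threshold(&img);
assert!(t >= 30 && t <= 220); // threshold falls between the two modes

let bin = binarize(&img, 128); // explicit mid threshold: only the dark half is foreground
assert_eq!(foreground_count(&bin), 64 * 64);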
@@ -176,9 +184,7 @@ pub fn adaptive_binarize(img: &GrayImage, window: usize, c: f32) -> BinaryImage let y2 = (y + half + 1).min(h); let area = ((x2 - x1) * (y2 - y1)) as f32; - let sum = integral[y2 * (w + 1) + x2] - - integral[y1 * (w + 1) + x2] - - integral[y2 * (w + 1) + x1] + let sum = integral[y2 * (w + 1) + x2] - integral[y1 * (w + 1) + x2] - integral[y2 * (w + 1) + x1] + integral[y1 * (w + 1) + x1]; let mean = sum as f32 / area; let threshold = mean - c; @@ -192,7 +198,11 @@ pub fn adaptive_binarize(img: &GrayImage, window: usize, c: f32) -> BinaryImage } } - BinaryImage { bits, width: w, height: h } + BinaryImage { + bits, + width: w, + height: h, + } } // ═══════════════════════════════════════════════════════════════════════════ @@ -236,9 +246,13 @@ fn projection_variance(bin: &BinaryImage, angle_deg: f32) -> f64 { for y in 0..h { for x_word in 0..(w + 63) / 64 { let word_idx = y * ((w + 63) / 64) + x_word; - if word_idx >= bin.bits.len() { break; } + if word_idx >= bin.bits.len() { + break; + } let word = bin.bits[word_idx]; - if word == 0 { continue; } + if word == 0 { + continue; + } // Process set bits let mut bits = word; @@ -258,11 +272,14 @@ fn projection_variance(bin: &BinaryImage, angle_deg: f32) -> f64 { } // Compute variance of non-zero rows - let non_zero: Vec = row_counts.iter() + let non_zero: Vec = row_counts + .iter() .filter(|&&c| c > 0) .map(|&c| c as f64) .collect(); - if non_zero.len() < 2 { return 0.0; } + if non_zero.len() < 2 { + return 0.0; + } let mean = non_zero.iter().sum::() / non_zero.len() as f64; non_zero.iter().map(|&v| (v - mean).powi(2)).sum::() / non_zero.len() as f64 @@ -280,7 +297,9 @@ pub fn foreground_count(bin: &BinaryImage) -> usize { /// Foreground density (ratio of foreground pixels to total). pub fn foreground_density(bin: &BinaryImage) -> f32 { let total = bin.width * bin.height; - if total == 0 { return 0.0; } + if total == 0 { + return 0.0; + } foreground_count(bin) as f32 / total as f32 } @@ -314,7 +333,13 @@ pub fn preprocess_page(img: &GrayImage) -> PreprocessResult { let is_content = density > 0.01 && density < 0.5; let skew_angle = if is_content { estimate_skew(&binary) } else { 0.0 }; - PreprocessResult { binary, threshold, skew_angle, density, is_content } + PreprocessResult { + binary, + threshold, + skew_angle, + density, + is_content, + } } #[cfg(test)] @@ -338,7 +363,11 @@ mod tests { #[test] fn test_otsu_uniform_black() { let data = make_image(128, 128, 0); - let img = GrayImage { data: &data, width: 128, height: 128 }; + let img = GrayImage { + data: &data, + width: 128, + height: 128, + }; let t = otsu_threshold(&img); assert_eq!(t, 0); // all same value } @@ -347,7 +376,11 @@ mod tests { fn test_otsu_bimodal() { let mut data = vec![30u8; 64 * 64]; // dark half data.extend(vec![220u8; 64 * 64]); // light half - let img = GrayImage { data: &data, width: 128, height: 64 }; + let img = GrayImage { + data: &data, + width: 128, + height: 64, + }; let t = otsu_threshold(&img); // Threshold should be between the two modes (30 and 220) assert!(t >= 30 && t <= 220, "bimodal threshold should be between modes: {}", t); @@ -356,7 +389,11 @@ mod tests { #[test] fn test_binarize_all_white() { let data = make_image(128, 128, 255); - let img = GrayImage { data: &data, width: 128, height: 128 }; + let img = GrayImage { + data: &data, + width: 128, + height: 128, + }; let bin = binarize(&img, 128); assert_eq!(foreground_count(&bin), 0); // 255 > 128, no foreground } @@ -364,7 +401,11 @@ mod tests { #[test] fn test_binarize_all_black() { 
let data = make_image(128, 128, 0); - let img = GrayImage { data: &data, width: 128, height: 128 }; + let img = GrayImage { + data: &data, + width: 128, + height: 128, + }; let bin = binarize(&img, 128); assert_eq!(foreground_count(&bin), 128 * 128); // 0 < 128, all foreground } @@ -372,7 +413,11 @@ mod tests { #[test] fn test_binarize_checkerboard() { let data = make_checkerboard(64, 64); - let img = GrayImage { data: &data, width: 64, height: 64 }; + let img = GrayImage { + data: &data, + width: 64, + height: 64, + }; let bin = binarize(&img, 128); let count = foreground_count(&bin); // Half should be foreground (50 < 128), half background (200 > 128) @@ -382,7 +427,11 @@ mod tests { #[test] fn test_foreground_density() { let data = make_image(100, 100, 0); - let img = GrayImage { data: &data, width: 100, height: 100 }; + let img = GrayImage { + data: &data, + width: 100, + height: 100, + }; let bin = binarize(&img, 128); let d = foreground_density(&bin); assert!((d - 1.0).abs() < 0.01, "all black = density 1.0: {}", d); @@ -391,7 +440,11 @@ mod tests { #[test] fn test_preprocess_blank_page() { let data = make_image(200, 200, 250); // nearly white - let img = GrayImage { data: &data, width: 200, height: 200 }; + let img = GrayImage { + data: &data, + width: 200, + height: 200, + }; let result = preprocess_page(&img); assert!(!result.is_content, "blank page should not be content"); } @@ -406,7 +459,11 @@ mod tests { data[y * 200 + x] = 20; // dark text } } - let img = GrayImage { data: &data, width: 200, height: 200 }; + let img = GrayImage { + data: &data, + width: 200, + height: 200, + }; let result = preprocess_page(&img); assert!(result.is_content, "text page should be content, density={}", result.density); assert!(result.density > 0.01 && result.density < 0.5); @@ -421,7 +478,11 @@ mod tests { data[y * 200 + x] = 0; // dark horizontal line } } - let img = GrayImage { data: &data, width: 200, height: 100 }; + let img = GrayImage { + data: &data, + width: 200, + height: 100, + }; let (bin, _) = auto_binarize(&img); let skew = estimate_skew(&bin); assert!(skew.abs() < 1.0, "horizontal lines should have near-zero skew: {}", skew); @@ -438,7 +499,11 @@ mod tests { data[y * 200 + x] = if has_text { bg.saturating_sub(80) } else { bg }; } } - let img = GrayImage { data: &data, width: 200, height: 100 }; + let img = GrayImage { + data: &data, + width: 200, + height: 100, + }; let otsu_bin = binarize(&img, otsu_threshold(&img)); let adaptive_bin = adaptive_binarize(&img, 31, 10.0); diff --git a/src/hpc/openchat/api.rs b/src/hpc/openchat/api.rs index ac266c16..bf606f8e 100644 --- a/src/hpc/openchat/api.rs +++ b/src/hpc/openchat/api.rs @@ -4,9 +4,9 @@ //! //! Uses the OpenChat template: `GPT4 Correct User: {msg}<|end_of_turn|>` -use crate::hpc::models::api_types::*; use super::inference::{GeneratedToken, OpenChatEngine}; use super::weights::*; +use crate::hpc::models::api_types::*; /// OpenChat API wrapper. 
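The module docs above give the turn format as GPT4 Correct User: {msg}<|end_of_turn|>. A hedged sketch of how such a prompt could be assembled, consistent with the template tests further down; the "GPT4 Correct Assistant:" label and the exact concatenation are assumptions, not a copy of format_chat_template:

/// Illustrative assembly of the OpenChat turn format documented above.
/// Sketch only — label strings and ordering are assumed, not copied.
fn format_chat_sketch(system: Option<&str>, turns: &[(&str, &str)], next_user: &str) -> String {
    let mut prompt = String::new();
    if let Some(sys) = system {
        prompt.push_str(sys); // system text leads the prompt (see template test)
        prompt.push_str("<|end_of_turn|>");
    }
    for &(user, assistant) in turns {
        prompt.push_str("GPT4 Correct User: ");
        prompt.push_str(user);
        prompt.push_str("<|end_of_turn|>GPT4 Correct Assistant: ");
        prompt.push_str(assistant);
        prompt.push_str("<|end_of_turn|>");
    }
    prompt.push_str("GPT4 Correct User: ");
    prompt.push_str(next_user);
    prompt.push_str("<|end_of_turn|>GPT4 Correct Assistant: ");
    prompt
}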
pub struct OpenChatApi { @@ -16,7 +16,10 @@ pub struct OpenChatApi { impl OpenChatApi { pub fn new(weights: OpenChatWeights) -> Self { - Self { engine: OpenChatEngine::new(weights), request_counter: 0 } + Self { + engine: OpenChatEngine::new(weights), + request_counter: 0, + } } /// `/v1/chat/completions` @@ -34,7 +37,10 @@ impl OpenChatApi { FinishReason::Length }; - let content: String = generated.iter().map(|t| format!("[{}]", t.token_id)).collect(); + let content: String = generated + .iter() + .map(|t| format!("[{}]", t.token_id)) + .collect(); ChatCompletionResponse::new( format!("chatcmpl-{}", self.request_counter), @@ -123,11 +129,8 @@ mod tests { #[test] fn test_chat_template_multi_turn() { - let messages = vec![ - ChatMessage::user("Hi"), - ChatMessage::assistant("Hello!"), - ChatMessage::user("How are you?"), - ]; + let messages = + vec![ChatMessage::user("Hi"), ChatMessage::assistant("Hello!"), ChatMessage::user("How are you?")]; let prompt = OpenChatApi::format_chat_template(&messages); assert_eq!(prompt.matches("GPT4 Correct User:").count(), 2); assert!(prompt.contains("Hello!")); @@ -135,10 +138,7 @@ mod tests { #[test] fn test_chat_template_with_system() { - let messages = vec![ - ChatMessage::system("You are helpful."), - ChatMessage::user("Hi"), - ]; + let messages = vec![ChatMessage::system("You are helpful."), ChatMessage::user("Hi")]; let prompt = OpenChatApi::format_chat_template(&messages); assert!(prompt.starts_with("You are helpful.")); } diff --git a/src/hpc/openchat/inference.rs b/src/hpc/openchat/inference.rs index 94b5dc0a..b91a09fb 100644 --- a/src/hpc/openchat/inference.rs +++ b/src/hpc/openchat/inference.rs @@ -254,7 +254,14 @@ impl OpenChatEngine { // Output projection let zero_bias = vec![0.0f32; EMBED_DIM]; let mut projected = vec![0.0f32; EMBED_DIM]; - layers::matmul_vec(&output, &self.weights.layers[layer_idx].attn_output, &zero_bias, &mut projected, EMBED_DIM, EMBED_DIM); + layers::matmul_vec( + &output, + &self.weights.layers[layer_idx].attn_output, + &zero_bias, + &mut projected, + EMBED_DIM, + EMBED_DIM, + ); projected } @@ -296,12 +303,7 @@ impl OpenChatEngine { } /// Generate tokens autoregressively. - pub fn generate( - &mut self, - prompt_tokens: &[u32], - max_new_tokens: usize, - temperature: f32, - ) -> Vec { + pub fn generate(&mut self, prompt_tokens: &[u32], max_new_tokens: usize, temperature: f32) -> Vec { self.reset(); let mut generated = Vec::new(); diff --git a/src/hpc/openchat/weights.rs b/src/hpc/openchat/weights.rs index 1fac817c..7ada5622 100644 --- a/src/hpc/openchat/weights.rs +++ b/src/hpc/openchat/weights.rs @@ -29,7 +29,7 @@ pub const NUM_LAYERS: usize = 32; pub const NUM_Q_HEADS: usize = 32; pub const NUM_KV_HEADS: usize = 8; pub const HEAD_DIM: usize = EMBED_DIM / NUM_Q_HEADS; // 128 -pub const KV_DIM: usize = NUM_KV_HEADS * HEAD_DIM; // 1024 +pub const KV_DIM: usize = NUM_KV_HEADS * HEAD_DIM; // 1024 pub const MLP_DIM: usize = 14336; // Mistral uses 14336 (not 4× embed) pub const MAX_SEQ_LEN: usize = 8192; // Mistral supports 8K context (32K with sliding window) pub const ROPE_THETA: f32 = 10000.0; @@ -81,13 +81,11 @@ impl OpenChatWeights { /// /// Pre-transposes weight matrices for SIMD-contiguous `matmul_vec`. 
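The dimension constants above imply EMBED_DIM = 4096, since HEAD_DIM = EMBED_DIM / NUM_Q_HEADS is commented as 128. A small worked check of the grouped-query attention arithmetic under that assumption:

// Grouped-query attention bookkeeping implied by the constants above.
// ASSUMPTION: EMBED_DIM = 4096, inferred from HEAD_DIM = EMBED_DIM / NUM_Q_HEADS == 128.
const EMBED_DIM: usize = 4096;
const NUM_Q_HEADS: usize = 32;
const NUM_KV_HEADS: usize = 8;
const HEAD_DIM: usize = EMBED_DIM / NUM_Q_HEADS; // 4096 / 32 = 128
const KV_DIM: usize = NUM_KV_HEADS * HEAD_DIM; // 8 * 128 = 1024
const GROUP_SIZE: usize = NUM_Q_HEADS / NUM_KV_HEADS; // 4 query heads share one KV head

fn main() {
    assert_eq!(HEAD_DIM, 128);
    assert_eq!(KV_DIM, 1024);
    assert_eq!(GROUP_SIZE, 4);
    // Per-token f32 KV cache: 2 * KV_DIM * 4 bytes = 8 KiB,
    // versus 32 KiB if every one of the 32 query heads kept its own K and V.
    println!("kv bytes per token = {}", 2 * KV_DIM * 4);
}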
pub fn from_gguf(path: &std::path::Path) -> Result { - let mut file = std::fs::File::open(path) - .map_err(|e| format!("open {}: {}", path.display(), e))?; + let mut file = std::fs::File::open(path).map_err(|e| format!("open {}: {}", path.display(), e))?; let header = gguf::read_gguf_header(&mut file)?; let mut read = |name: &str| -> Result, String> { - let tensor = gguf::find_tensor(&header, name) - .ok_or_else(|| format!("missing tensor: {}", name))?; + let tensor = gguf::find_tensor(&header, name).ok_or_else(|| format!("missing tensor: {}", name))?; gguf::read_tensor_f32(&mut file, &header, tensor) }; diff --git a/src/hpc/organic.rs b/src/hpc/organic.rs index 5bd08fc9..569e1feb 100644 --- a/src/hpc/organic.rs +++ b/src/hpc/organic.rs @@ -1,4 +1,6 @@ -#![allow(clippy::assign_op_pattern, clippy::too_many_arguments, clippy::manual_range_contains, clippy::needless_range_loop)] +#![allow( + clippy::assign_op_pattern, clippy::too_many_arguments, clippy::manual_range_contains, clippy::needless_range_loop +)] //! Organic Plasticity Model — BCM-inspired synapse dynamics. //! @@ -93,11 +95,7 @@ pub fn unpack_three(byte: u8) -> (FiveState, FiveState, FiveState) { let rem = byte % 25; let b = rem / 5; let c = rem % 5; - ( - FiveState::from_raw(a), - FiveState::from_raw(b), - FiveState::from_raw(c), - ) + (FiveState::from_raw(a), FiveState::from_raw(b), FiveState::from_raw(c)) } /// Pack a slice of `FiveState` values into bytes (3 values per byte). @@ -275,11 +273,7 @@ pub fn organic_deposit(state: &mut SynapseState, evidence: i8) { /// /// `states` and `evidence` must have the same length. pub fn organic_deposit_batch(states: &mut [SynapseState], evidence: &[i8]) { - assert_eq!( - states.len(), - evidence.len(), - "states and evidence must have same length" - ); + assert_eq!(states.len(), evidence.len(), "states and evidence must have same length"); for (state, &ev) in states.iter_mut().zip(evidence.iter()) { organic_deposit(state, ev); } @@ -525,10 +519,7 @@ mod tests { let theta_before = s.theta; // Deposit — theta should slide toward |efficacy| organic_deposit(&mut s, 10); - assert!( - s.theta >= theta_before, - "theta should slide upward toward high efficacy" - ); + assert!(s.theta >= theta_before, "theta should slide upward toward high efficacy"); } #[test] @@ -735,12 +726,7 @@ mod tests { } // Love should have higher theta (harder to potentiate further) - assert!( - love.theta > kube.theta, - "love theta={} should be > kube theta={}", - love.theta, - kube.theta - ); + assert!(love.theta > kube.theta, "love theta={} should be > kube theta={}", love.theta, kube.theta); // Love should have higher maturity assert!(love.maturity > kube.maturity); } @@ -773,11 +759,6 @@ mod tests { let eff_before = s.efficacy; organic_deposit(&mut s, 10); // positive evidence, but below threshold // Should depress (move toward zero) since |50| < 80 and |50| > 40 - assert!( - s.efficacy < eff_before, - "should depress: before={}, after={}", - eff_before, - s.efficacy - ); + assert!(s.efficacy < eff_before, "should depress: before={}, after={}", eff_before, s.efficacy); } } diff --git a/src/hpc/p64_bridge.rs b/src/hpc/p64_bridge.rs index 36db28d2..bcdfe3a2 100644 --- a/src/hpc/p64_bridge.rs +++ b/src/hpc/p64_bridge.rs @@ -12,16 +12,13 @@ use std::collections::HashMap; use std::sync::LazyLock; -use crate::simd::{F64x8, U64x8}; use crate::hpc::nars::NarsTruth; use crate::hpc::simd_caps::simd_caps; +use crate::simd::{F64x8, U64x8}; // Re-export p64 types for consumers. 
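The organic.rs hunk above packs three five-state values into one byte, which works because 5 * 5 * 5 = 125 <= 256. A round-trip sketch of the pack/unpack pair using plain u8 digits in place of FiveState (pack3/unpack3 are sketch names; only unpack_three appears in the diff):

// Three base-5 digits fit in one byte. u8 digits stand in for FiveState here;
// pack3 is the assumed inverse of the unpack_three logic shown above.
fn pack3(a: u8, b: u8, c: u8) -> u8 {
    debug_assert!(a < 5 && b < 5 && c < 5);
    a * 25 + b * 5 + c
}

fn unpack3(byte: u8) -> (u8, u8, u8) {
    let a = byte / 25;
    let rem = byte % 25;
    (a, rem / 5, rem % 5)
}

fn main() {
    for a in 0..5u8 {
        for b in 0..5u8 {
            for c in 0..5u8 {
                assert_eq!(unpack3(pack3(a, b, c)), (a, b, c)); // lossless round trip
            }
        }
    }
    // 3 states per byte ≈ 2.67 bits per state, versus 8 bits unpacked.
}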
-pub use p64::{ - AttentionResult, CombineMode, ContraMode, HeelPlanes, Palette3D, Palette64, ThinkingStyle, - predicate, -}; pub use fractal::consts as manifold_consts; +pub use p64::{predicate, AttentionResult, CombineMode, ContraMode, HeelPlanes, Palette3D, Palette64, ThinkingStyle}; // ============================================================================ // Section 1: SIMD manifold expansion @@ -203,8 +200,7 @@ pub fn attend_simd(palette: &Palette64, query: u64, gamma: u8) -> AttentionResul /// assert!((tv.confidence - 1.0).abs() < 0.01); /// ``` pub fn resonance_to_nars(resonance_7bit: u8, contradiction: f64, max_contra: f64) -> NarsTruth { - let (f, c) = - fractal::seven_plus_one::nars_truth(resonance_7bit, contradiction, max_contra); + let (f, c) = fractal::seven_plus_one::nars_truth(resonance_7bit, contradiction, max_contra); NarsTruth::new(f as f32, c as f32) } @@ -705,11 +701,9 @@ mod tests { use causal_edge_compat::*; let edge_causes = (1u64 << SRC_SHIFT) | (2u64 << TGT_SHIFT) | (0u64 << LAYER_SHIFT); - let edge_contra = - (3u64 << SRC_SHIFT) | (4u64 << TGT_SHIFT) | (3u64 << LAYER_SHIFT); + let edge_contra = (3u64 << SRC_SHIFT) | (4u64 << TGT_SHIFT) | (3u64 << LAYER_SHIFT); - let p3d = - palette3d_from_edges(&[edge_causes, edge_contra], ThinkingStyle::ANALYTICAL); + let p3d = palette3d_from_edges(&[edge_causes, edge_contra], ThinkingStyle::ANALYTICAL); // CAUSES layer: row 1 has bit 2 set assert_ne!(p3d.layers[0].rows[1] & (1 << 2), 0); @@ -751,8 +745,13 @@ mod tests { let median = all_dists[all_dists.len() / 2]; // Use 25th percentile for sparse palette (~12.5% density) let p25 = all_dists[all_dists.len() / 4]; - eprintln!("Distance stats: median={}, p25={}, min={}, max={}", - median, p25, all_dists[0], all_dists.last().unwrap()); + eprintln!( + "Distance stats: median={}, p25={}, min={}, max={}", + median, + p25, + all_dists[0], + all_dists.last().unwrap() + ); // Build Palette64 from GPT-2's learned distance table let palette = palette_from_deepnsm_distances(&flat, 256, p25); @@ -773,12 +772,14 @@ mod tests { let r_analytical = p3d_analytical.infer(42); let r_creative = p3d_creative.infer(42); - eprintln!("Analytical: attention={:064b}, tension={}, active_layers={}, new={}", - r_analytical.attention, r_analytical.tension, - r_analytical.active_layers, r_analytical.new_connections); - eprintln!("Creative: attention={:064b}, tension={}, active_layers={}, new={}", - r_creative.attention, r_creative.tension, - r_creative.active_layers, r_creative.new_connections); + eprintln!( + "Analytical: attention={:064b}, tension={}, active_layers={}, new={}", + r_analytical.attention, r_analytical.tension, r_analytical.active_layers, r_analytical.new_connections + ); + eprintln!( + "Creative: attention={:064b}, tension={}, active_layers={}, new={}", + r_creative.attention, r_creative.tension, r_creative.active_layers, r_creative.new_connections + ); // KEY ASSERTION: different styles produce different fan-out // Creative (Union, all layers, density 0.40) should activate MORE targets @@ -787,9 +788,12 @@ mod tests { let creative_popcount = r_creative.attention.count_ones(); eprintln!("Fan-out: analytical={}, creative={}", analytical_popcount, creative_popcount); - assert!(creative_popcount >= analytical_popcount, + assert!( + creative_popcount >= analytical_popcount, "Creative should have wider fan-out than Analytical: {} vs {}", - creative_popcount, analytical_popcount); + creative_popcount, + analytical_popcount + ); // Verify attention is non-trivial assert!(analytical_popcount > 
0, "Analytical should fire something"); @@ -801,31 +805,40 @@ mod tests { let mut non_interacting = None; for i in 0..64 { for j in 0..64 { - if i == j { continue; } + if i == j { + continue; + } if palette.rows[i] & (1 << j) != 0 && interacting.is_none() { interacting = Some((i, j)); } if palette.rows[i] & (1 << j) == 0 && non_interacting.is_none() { non_interacting = Some((i, j)); } - if interacting.is_some() && non_interacting.is_some() { break; } + if interacting.is_some() && non_interacting.is_some() { + break; + } + } + if interacting.is_some() && non_interacting.is_some() { + break; } - if interacting.is_some() && non_interacting.is_some() { break; } } if let (Some((ia, ib)), Some((na, nb))) = (interacting, non_interacting) { // Interacting pair should have LOWER distance than non-interacting let d_interact = flat[ia * 256 + ib]; let d_non = flat[na * 256 + nb]; - eprintln!("Interacting ({},{}) distance={}, Non-interacting ({},{}) distance={}", - ia, ib, d_interact, na, nb, d_non); - assert!(d_interact <= d_non, - "Interacting pair should be closer: {} vs {}", d_interact, d_non); + eprintln!( + "Interacting ({},{}) distance={}, Non-interacting ({},{}) distance={}", + ia, ib, d_interact, na, nb, d_non + ); + assert!(d_interact <= d_non, "Interacting pair should be closer: {} vs {}", d_interact, d_non); } eprintln!("GPT-2 → P64 rehydration: PASS"); eprintln!(" 50K tokens → 256 archetypes → 64×64 palette → 8-layer Palette3D"); - eprintln!(" Thinking style modulates fan-out: Analytical={}, Creative={}", - analytical_popcount, creative_popcount); + eprintln!( + " Thinking style modulates fan-out: Analytical={}, Creative={}", + analytical_popcount, creative_popcount + ); } } diff --git a/src/hpc/packed.rs b/src/hpc/packed.rs index 6cd4b887..03fbb6e5 100644 --- a/src/hpc/packed.rs +++ b/src/hpc/packed.rs @@ -122,7 +122,13 @@ impl PackedDatabase { index.push(i as u32); } - Self { stroke1, stroke2, stroke3, index, num_candidates: n } + Self { + stroke1, + stroke2, + stroke3, + index, + num_candidates: n, + } } /// Get stroke 1 slice for candidate i (128 bytes). @@ -177,11 +183,7 @@ impl PackedDatabase { /// /// Returns top-k results sorted by distance ascending. 
pub fn cascade_query( - &self, - query: &[u8], - reject_threshold_s1: u64, - reject_threshold_s12: u64, - k: usize, + &self, query: &[u8], reject_threshold_s1: u64, reject_threshold_s12: u64, k: usize, ) -> Vec { assert!(query.len() >= FINGERPRINT_BYTES, "query must be at least {FINGERPRINT_BYTES} bytes"); @@ -255,7 +257,10 @@ impl PackedDatabase { let d1 = bitwise::hamming_distance_raw(query_s1, self.get_stroke1(i)); let d2 = bitwise::hamming_distance_raw(query_s2, self.get_stroke2(i)); let d3 = bitwise::hamming_distance_raw(query_s3, self.get_stroke3(i)); - RankedHit { index: self.original_id(i) as usize, distance: d1 + d2 + d3 } + RankedHit { + index: self.original_id(i) as usize, + distance: d1 + d2 + d3, + } }) .collect(); @@ -285,14 +290,8 @@ mod tests { for i in 0..n { let base = i * FINGERPRINT_BYTES; assert_eq!(packed.get_stroke1(i), &db[base..base + STROKE1_BYTES]); - assert_eq!( - packed.get_stroke2(i), - &db[base + STROKE1_BYTES..base + STROKE1_BYTES + STROKE2_BYTES] - ); - assert_eq!( - packed.get_stroke3(i), - &db[base + STROKE1_BYTES + STROKE2_BYTES..base + FINGERPRINT_BYTES] - ); + assert_eq!(packed.get_stroke2(i), &db[base + STROKE1_BYTES..base + STROKE1_BYTES + STROKE2_BYTES]); + assert_eq!(packed.get_stroke3(i), &db[base + STROKE1_BYTES + STROKE2_BYTES..base + FINGERPRINT_BYTES]); } } @@ -368,7 +367,9 @@ mod tests { #[test] fn test_cascade_vs_brute_force_consistency() { let n = 50; - let db: Vec = (0..n * FINGERPRINT_BYTES).map(|i| (i * 7 + 13) as u8).collect(); + let db: Vec = (0..n * FINGERPRINT_BYTES) + .map(|i| (i * 7 + 13) as u8) + .collect(); let packed = PackedDatabase::pack(&db, FINGERPRINT_BYTES); let query: Vec = (0..FINGERPRINT_BYTES).map(|i| (i * 3) as u8).collect(); diff --git a/src/hpc/palette_codec.rs b/src/hpc/palette_codec.rs index 2189c435..9dc4d8a5 100644 --- a/src/hpc/palette_codec.rs +++ b/src/hpc/palette_codec.rs @@ -135,12 +135,7 @@ pub fn compression_ratio(bits_per_index: usize) -> f32 { /// /// Useful when a palette grows (e.g., 4-bit → 5-bit after inserting a 17th entry). /// More efficient than unpack→repack because it avoids the intermediate Vec. -pub fn transcode( - packed: &[u64], - old_bits: usize, - new_bits: usize, - count: usize, -) -> Vec { +pub fn transcode(packed: &[u64], old_bits: usize, new_bits: usize, count: usize) -> Vec { assert!(old_bits > 0 && old_bits <= 8); assert!(new_bits > 0 && new_bits <= 8); @@ -190,7 +185,12 @@ impl PackedPaletteArray { pub fn from_indices(indices: &[u8], palette_size: usize) -> Self { let bits = bits_for_palette_size(palette_size).max(1); let data = pack_indices(indices, bits); - Self { data, count: indices.len(), bits_per_index: bits, palette_size } + Self { + data, + count: indices.len(), + bits_per_index: bits, + palette_size, + } } /// Decode all indices. @@ -402,12 +402,7 @@ unsafe fn unpack_4bit_avx2(packed: &[u64], count: usize) -> Vec { /// Reinterpret &[u64] as &[u8] (little-endian safe). 
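palette_codec stores n-bit palette indices back to back inside u64 words, e.g. sixteen 4-bit indices per word. A portable sketch of that pack/unpack layout, ignoring the AVX2 fast path; the LSB-first, non-straddling layout is an assumption, and pack_bits/unpack_bits are sketch names rather than the crate's pack_indices API:

/// Pack `bits`-wide indices LSB-first into u64 words. In this sketch an index
/// never straddles a word boundary — that layout detail is an assumption.
fn pack_bits(indices: &[u8], bits: usize) -> Vec<u64> {
    assert!(bits >= 1 && bits <= 8);
    let per_word = 64 / bits;
    let mut out = vec![0u64; (indices.len() + per_word - 1) / per_word];
    for (i, &idx) in indices.iter().enumerate() {
        out[i / per_word] |= (idx as u64) << ((i % per_word) * bits);
    }
    out
}

fn unpack_bits(packed: &[u64], bits: usize, count: usize) -> Vec<u8> {
    let per_word = 64 / bits;
    let mask = (1u64 << bits) - 1;
    (0..count)
        .map(|i| ((packed[i / per_word] >> ((i % per_word) * bits)) & mask) as u8)
        .collect()
}

fn main() {
    let idx: Vec<u8> = (0..4096).map(|i| (i % 16) as u8).collect();
    let packed = pack_bits(&idx, 4); // sixteen 4-bit indices per u64
    assert_eq!(packed.len(), 256); // 4096 / 16 words
    assert_eq!(unpack_bits(&packed, 4, 4096), idx); // lossless round trip
}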
fn bytemuck_cast_u64_to_u8(words: &[u64]) -> &[u8] { // SAFETY: u64 and u8 have compatible layouts on little-endian - unsafe { - core::slice::from_raw_parts( - words.as_ptr() as *const u8, - words.len() * 8, - ) - } + unsafe { core::slice::from_raw_parts(words.as_ptr() as *const u8, words.len() * 8) } } /// Reorder 4096 block states from Java Y-major ordering (y*256+z*16+x) @@ -483,9 +478,7 @@ pub fn bedrock_reorder_xzy_inverse(states: &[u16]) -> Vec { /// assert!(packed.is_some()); /// ``` pub fn bedrock_pack_section( - states: &[u16], - palette: &std::collections::HashMap, - bits_per_index: usize, + states: &[u16], palette: &std::collections::HashMap, bits_per_index: usize, ) -> Option> { let reordered = bedrock_reorder_xzy(states); let mut indices = Vec::with_capacity(4096); @@ -822,15 +815,17 @@ mod tests { palette.insert(3u16, 3u8); let bits = bits_for_palette_size(4); // 2 bits - let packed = bedrock_pack_section(&states, &palette, bits) - .expect("all states should be in palette"); + let packed = bedrock_pack_section(&states, &palette, bits).expect("all states should be in palette"); // Verify by unpacking and inverse-reordering let unpacked = unpack_indices(&packed, bits, 4096); - let bedrock_states: Vec = unpacked.iter().map(|&idx| { - // Reverse palette lookup: idx → state - *palette.iter().find(|(_, &v)| v == idx).unwrap().0 - }).collect(); + let bedrock_states: Vec = unpacked + .iter() + .map(|&idx| { + // Reverse palette lookup: idx → state + *palette.iter().find(|(_, &v)| v == idx).unwrap().0 + }) + .collect(); let java_states = bedrock_reorder_xzy_inverse(&bedrock_states); assert_eq!(states, java_states, "pack then unpack+inverse must recover original"); } diff --git a/src/hpc/palette_distance.rs b/src/hpc/palette_distance.rs index 6dfe4b9a..2e072132 100644 --- a/src/hpc/palette_distance.rs +++ b/src/hpc/palette_distance.rs @@ -34,9 +34,18 @@ unsafe fn nearest_avx512(entries: &[Base17], query: &Base17) -> u8 { let d3 = query.l1(&entries[base + 3]); // Find min of 4 let (mut min_d, mut min_i) = (d0, 0usize); - if d1 < min_d { min_d = d1; min_i = 1; } - if d2 < min_d { min_d = d2; min_i = 2; } - if d3 < min_d { min_d = d3; min_i = 3; } + if d1 < min_d { + min_d = d1; + min_i = 1; + } + if d2 < min_d { + min_d = d2; + min_i = 2; + } + if d3 < min_d { + min_d = d3; + min_i = 3; + } if min_d < best_dist { best_dist = min_d; best_idx = (base + min_i) as u8; @@ -67,9 +76,18 @@ unsafe fn nearest_avx2(entries: &[Base17], query: &Base17) -> u8 { let d2 = query.l1(&entries[base + 2]); let d3 = query.l1(&entries[base + 3]); let (mut min_d, mut min_i) = (d0, 0usize); - if d1 < min_d { min_d = d1; min_i = 1; } - if d2 < min_d { min_d = d2; min_i = 2; } - if d3 < min_d { min_d = d3; min_i = 3; } + if d1 < min_d { + min_d = d1; + min_i = 1; + } + if d2 < min_d { + min_d = d2; + min_i = 2; + } + if d3 < min_d { + min_d = d3; + min_i = 3; + } if min_d < best_dist { best_dist = min_d; best_idx = (base + min_i) as u8; @@ -488,7 +506,9 @@ mod tests { assert!( d_ac <= d_ab + d_bc, "triangle inequality violated: d(a,c)={} > d(a,b)={} + d(b,c)={}", - d_ac, d_ab, d_bc + d_ac, + d_ab, + d_bc ); } @@ -541,10 +561,9 @@ mod tests { let i = row * 64 + col; for target in 0..64 { let j = row * 64 + target; - total_dist += spo.spo_distance( - heads_s[i], heads_p[i], heads_o[i], - heads_s[j], heads_p[j], heads_o[j], - ) as u64; + total_dist += spo + .spo_distance(heads_s[i], heads_p[i], heads_o[i], heads_s[j], heads_p[j], heads_o[j]) + as u64; } } } @@ -573,8 +592,11 @@ mod tests { eprintln!(" Pearl 2³: 
{:.0} tokens/sec (8 projections per head)", tokens_per_sec_pearl); eprintln!(" Triple model: {:.0} tokens/sec (self+user+impact)", tokens_per_sec_pearl / 3.0); eprintln!(); - eprintln!(" Memory: {} KB SPO tables + 4 KB head indices = {} KB total", - spo.byte_size() / 1024, spo.byte_size() / 1024 + 4); + eprintln!( + " Memory: {} KB SPO tables + 4 KB head indices = {} KB total", + spo.byte_size() / 1024, + spo.byte_size() / 1024 + 4 + ); eprintln!(" (blackhole: {})", total_dist); // prevent optimizer from eliding } } diff --git a/src/hpc/parallel_search.rs b/src/hpc/parallel_search.rs index 588a2815..e5129085 100644 --- a/src/hpc/parallel_search.rs +++ b/src/hpc/parallel_search.rs @@ -7,8 +7,8 @@ //! Results are merged and filtered through TruthGate for evidence quality. use super::bgz17_bridge::PaletteEdge; +use super::layered_distance::{read_palette_edge, read_truth, TruthGate}; use super::palette_distance::SpoDistanceMatrices; -use super::layered_distance::{TruthGate, read_palette_edge, read_truth}; /// Search result with distance and truth metadata. #[derive(Debug, Clone)] @@ -45,14 +45,8 @@ pub struct PaletteScope { impl PaletteScope { /// Build from containers: extract palette edges from W125 of each. - pub fn from_containers( - containers: Vec<[u64; 256]>, - distances: SpoDistanceMatrices, - ) -> Self { - let palette_indices: Vec = containers - .iter() - .map(read_palette_edge) - .collect(); + pub fn from_containers(containers: Vec<[u64; 256]>, distances: SpoDistanceMatrices) -> Self { + let palette_indices: Vec = containers.iter().map(read_palette_edge).collect(); PaletteScope { palette_indices, distances, @@ -74,10 +68,8 @@ impl PaletteScope { #[inline] fn distance_to(&self, query: &PaletteEdge, idx: usize) -> u32 { let c = &self.palette_indices[idx]; - self.distances.spo_distance( - query.s_idx, query.p_idx, query.o_idx, - c.s_idx, c.p_idx, c.o_idx, - ) + self.distances + .spo_distance(query.s_idx, query.p_idx, query.o_idx, c.s_idx, c.p_idx, c.o_idx) } /// HHTL search: progressive refinement using palette distances. @@ -143,10 +135,8 @@ impl PaletteScope { // Update min distances for i in 0..n { let d = self.distances.spo_distance( - last_pe.s_idx, last_pe.p_idx, last_pe.o_idx, - self.palette_indices[i].s_idx, - self.palette_indices[i].p_idx, - self.palette_indices[i].o_idx, + last_pe.s_idx, last_pe.p_idx, last_pe.o_idx, self.palette_indices[i].s_idx, + self.palette_indices[i].p_idx, self.palette_indices[i].o_idx, ); if d < min_dists[i] { min_dists[i] = d; @@ -174,10 +164,9 @@ impl PaletteScope { let mut best_d = u32::MAX; for (a, &arch_idx) in archetype_indices.iter().enumerate() { let arch_pe = &self.palette_indices[arch_idx]; - let d = self.distances.spo_distance( - pe_i.s_idx, pe_i.p_idx, pe_i.o_idx, - arch_pe.s_idx, arch_pe.p_idx, arch_pe.o_idx, - ); + let d = self + .distances + .spo_distance(pe_i.s_idx, pe_i.p_idx, pe_i.o_idx, arch_pe.s_idx, arch_pe.p_idx, arch_pe.o_idx); if d < best_d { best_d = d; best_arch = a; @@ -237,12 +226,7 @@ impl PaletteScope { /// /// Both search paths run independently and their results are merged /// to produce the best top-k, filtered by truth-value evidence quality. -pub fn parallel_search( - scope: &PaletteScope, - query: &PaletteEdge, - k: usize, - gate: &TruthGate, -) -> Vec { +pub fn parallel_search(scope: &PaletteScope, query: &PaletteEdge, k: usize, gate: &TruthGate) -> Vec { if scope.is_empty() || k == 0 { return Vec::new(); } @@ -279,11 +263,7 @@ pub fn parallel_search( /// Merge and re-rank two result sets, taking union of top-k. 
/// /// Deduplicates by node index, keeping the minimum distance for each node. -fn merge_and_rerank( - hhtl: Vec<(usize, u32)>, - clam: Vec<(usize, u32)>, - k: usize, -) -> Vec<(usize, u32)> { +fn merge_and_rerank(hhtl: Vec<(usize, u32)>, clam: Vec<(usize, u32)>, k: usize) -> Vec<(usize, u32)> { // Collect all results into a map (node_idx -> min_distance) let mut map = std::collections::HashMap::new(); for (idx, d) in hhtl.into_iter().chain(clam.into_iter()) { @@ -304,11 +284,7 @@ fn merge_and_rerank( /// /// LFD = log2(|B(center, radius)| / |B(center, radius/2)|) /// where B(c, r) is the set of nodes within distance r of center. -pub fn lfd_from_palette( - scope: &PaletteScope, - center_idx: usize, - radius: u32, -) -> f64 { +pub fn lfd_from_palette(scope: &PaletteScope, center_idx: usize, radius: u32) -> f64 { if radius == 0 || scope.is_empty() { return 0.0; } @@ -325,10 +301,9 @@ pub fn lfd_from_palette( count_half_r += 1; continue; } - let d = scope.distances.spo_distance( - center.s_idx, center.p_idx, center.o_idx, - pe.s_idx, pe.p_idx, pe.o_idx, - ); + let d = scope + .distances + .spo_distance(center.s_idx, center.p_idx, center.o_idx, pe.s_idx, pe.p_idx, pe.o_idx); if d <= radius { count_r += 1; } @@ -346,10 +321,10 @@ pub fn lfd_from_palette( #[cfg(test)] mod tests { - use super::*; use super::super::bgz17_bridge::{Base17, PaletteEdge}; - use super::super::palette_distance::{Palette, SpoDistanceMatrices}; use super::super::layered_distance::write_palette_edge; + use super::super::palette_distance::{Palette, SpoDistanceMatrices}; + use super::*; fn make_test_scope(n: usize) -> PaletteScope { let entries: Vec = (0..32) @@ -384,7 +359,11 @@ mod tests { #[test] fn test_hhtl_search_basic() { let scope = make_test_scope(100); - let query = PaletteEdge { s_idx: 0, p_idx: 0, o_idx: 0 }; + let query = PaletteEdge { + s_idx: 0, + p_idx: 0, + o_idx: 0, + }; let results = scope.hhtl_search(&query, 5); assert_eq!(results.len(), 5); // Results should be sorted by distance @@ -396,7 +375,11 @@ mod tests { #[test] fn test_hhtl_search_empty() { let scope = make_test_scope(0); - let query = PaletteEdge { s_idx: 0, p_idx: 0, o_idx: 0 }; + let query = PaletteEdge { + s_idx: 0, + p_idx: 0, + o_idx: 0, + }; let results = scope.hhtl_search(&query, 5); assert!(results.is_empty()); } @@ -404,7 +387,11 @@ mod tests { #[test] fn test_hhtl_search_k_larger_than_n() { let scope = make_test_scope(3); - let query = PaletteEdge { s_idx: 0, p_idx: 0, o_idx: 0 }; + let query = PaletteEdge { + s_idx: 0, + p_idx: 0, + o_idx: 0, + }; let results = scope.hhtl_search(&query, 10); assert_eq!(results.len(), 3); } @@ -412,7 +399,11 @@ mod tests { #[test] fn test_clam_search_basic() { let scope = make_test_scope(100); - let query = PaletteEdge { s_idx: 0, p_idx: 0, o_idx: 0 }; + let query = PaletteEdge { + s_idx: 0, + p_idx: 0, + o_idx: 0, + }; let results = scope.clam_search(&query, 5); assert_eq!(results.len(), 5); for w in results.windows(2) { @@ -424,7 +415,11 @@ mod tests { fn test_clam_search_small_fallback() { // Small scope should fallback to HHTL let scope = make_test_scope(10); - let query = PaletteEdge { s_idx: 0, p_idx: 0, o_idx: 0 }; + let query = PaletteEdge { + s_idx: 0, + p_idx: 0, + o_idx: 0, + }; let hhtl = scope.hhtl_search(&query, 5); let clam = scope.clam_search(&query, 5); // Should return the same results for small scope @@ -438,7 +433,11 @@ mod tests { #[test] fn test_parallel_search_basic() { let scope = make_test_scope(100); - let query = PaletteEdge { s_idx: 0, p_idx: 0, o_idx: 0 }; + let query = 
PaletteEdge { + s_idx: 0, + p_idx: 0, + o_idx: 0, + }; let results = parallel_search(&scope, &query, 5, &TruthGate::OPEN); assert!(results.len() <= 5); assert!(!results.is_empty()); @@ -483,7 +482,11 @@ mod tests { let scope = PaletteScope::from_containers(containers, dm); // CERTAIN gate should filter out low-truth nodes - let query = PaletteEdge { s_idx: 0, p_idx: 0, o_idx: 0 }; + let query = PaletteEdge { + s_idx: 0, + p_idx: 0, + o_idx: 0, + }; let all = parallel_search(&scope, &query, 20, &TruthGate::OPEN); let certain = parallel_search(&scope, &query, 20, &TruthGate::CERTAIN); @@ -493,18 +496,18 @@ mod tests { // All results in certain should have high expectation for r in &certain { let exp = TruthGate::expectation(r.frequency, r.confidence); - assert!( - exp >= 0.9, - "result expectation {} should be >= 0.9", - exp - ); + assert!(exp >= 0.9, "result expectation {} should be >= 0.9", exp); } } #[test] fn test_parallel_search_empty() { let scope = make_test_scope(0); - let query = PaletteEdge { s_idx: 0, p_idx: 0, o_idx: 0 }; + let query = PaletteEdge { + s_idx: 0, + p_idx: 0, + o_idx: 0, + }; let results = parallel_search(&scope, &query, 5, &TruthGate::OPEN); assert!(results.is_empty()); } @@ -582,7 +585,11 @@ mod tests { fn test_hhtl_finds_exact_match() { // Node 0 has palette edge (0, 0, 0), query is (0, 0, 0) => distance 0 let scope = make_test_scope(100); - let query = PaletteEdge { s_idx: 0, p_idx: 0, o_idx: 0 }; + let query = PaletteEdge { + s_idx: 0, + p_idx: 0, + o_idx: 0, + }; let results = scope.hhtl_search(&query, 1); assert_eq!(results.len(), 1); assert_eq!(results[0].1, 0, "exact match should have distance 0"); diff --git a/src/hpc/plane.rs b/src/hpc/plane.rs index c6c5c2e4..a94769ea 100644 --- a/src/hpc/plane.rs +++ b/src/hpc/plane.rs @@ -24,9 +24,7 @@ pub struct Acc16K { impl Default for Acc16K { fn default() -> Self { - Self { - values: [0i8; 16384], - } + Self { values: [0i8; 16384] } } } @@ -259,12 +257,7 @@ impl Plane { pub fn distance(&mut self, other: &mut Plane) -> Distance { self.ensure_cache(); other.ensure_cache(); - distance_slices( - self.bits_bytes_ref(), - self.alpha_bytes_ref(), - other.bits_bytes_ref(), - other.alpha_bytes_ref(), - ) + distance_slices(self.bits_bytes_ref(), self.alpha_bytes_ref(), other.bits_bytes_ref(), other.alpha_bytes_ref()) } // ======================================================================== @@ -400,11 +393,9 @@ impl Truth { evidence: 0, }; } - let f = ((self.frequency as u64 * self.evidence as u64) - + (other.frequency as u64 * other.evidence as u64)) + let f = ((self.frequency as u64 * self.evidence as u64) + (other.frequency as u64 * other.evidence as u64)) / total_evidence as u64; - let c = ((self.confidence as u64 * self.evidence as u64) - + (other.confidence as u64 * other.evidence as u64)) + let c = ((self.confidence as u64 * self.evidence as u64) + (other.confidence as u64 * other.evidence as u64)) / total_evidence as u64; Truth { frequency: f.min(65535) as u16, @@ -440,7 +431,11 @@ thread_local! { } pub fn distance_slices(a_bits: &[u8], a_alpha: &[u8], b_bits: &[u8], b_alpha: &[u8]) -> Distance { - let shared_len = a_bits.len().min(b_bits.len()).min(a_alpha.len()).min(b_alpha.len()); + let shared_len = a_bits + .len() + .min(b_bits.len()) + .min(a_alpha.len()) + .min(b_alpha.len()); if shared_len == 0 { return Distance::Incomparable; } @@ -571,9 +566,7 @@ mod tests { let d = a.distance(&mut b); match d { Distance::Measured { - disagreement, - overlap, - .. + disagreement, overlap, .. 
} => { assert!(overlap > 0); assert_eq!(disagreement, 0); @@ -658,7 +651,9 @@ mod tests { let d = learner.distance(&mut teacher); match d { - Distance::Measured { disagreement, overlap, .. } => { + Distance::Measured { + disagreement, overlap, .. + } => { assert!(overlap > 0); assert_eq!(disagreement, 0, "encounter_toward should match teacher bits"); } @@ -680,7 +675,9 @@ mod tests { let d = learner.distance(&mut target); match d { - Distance::Measured { disagreement, overlap, .. } => { + Distance::Measured { + disagreement, overlap, .. + } => { assert!(overlap > 0); assert_eq!(disagreement, overlap, "encounter_away should maximally disagree"); } @@ -723,7 +720,9 @@ mod tests { let d = learner.distance(&mut evidence); match d { - Distance::Measured { disagreement, overlap, .. } => { + Distance::Measured { + disagreement, overlap, .. + } => { assert_eq!(disagreement, overlap, "negative reward should punish"); } _ => panic!("expected Measured"), diff --git a/src/hpc/prefilter.rs b/src/hpc/prefilter.rs index 77d7a953..0c6c1420 100644 --- a/src/hpc/prefilter.rs +++ b/src/hpc/prefilter.rs @@ -187,12 +187,7 @@ pub fn top_k_rows_by_norm(data: &[f32], rows: usize, cols: usize, k: usize) -> V /// * `k` - Shared dimension /// * `prune_fraction` - 0.0-1.0, e.g. 0.9 = prune 90%, keep top 10% pub fn pruned_gemm_rows( - a: &[f32], - b: &[f32], - m: usize, - n: usize, - k: usize, - prune_fraction: f32, + a: &[f32], b: &[f32], m: usize, n: usize, k: usize, prune_fraction: f32, ) -> (Vec, Vec) { let norms = approx_row_norms_f32(a, m, k); @@ -255,11 +250,7 @@ pub fn pruned_gemm_rows( /// assert_eq!(results[0].1, 0); // distance 0 /// ``` pub fn approx_hamming_candidates( - query: &[u8], - database: &[u8], - bytes_per_vec: usize, - n_vectors: usize, - top_k: usize, + query: &[u8], database: &[u8], bytes_per_vec: usize, n_vectors: usize, top_k: usize, ) -> Vec<(usize, u32)> { assert!(database.len() >= n_vectors * bytes_per_vec); assert!(query.len() >= bytes_per_vec); @@ -268,10 +259,7 @@ pub fn approx_hamming_candidates( for v in 0..n_vectors { let vec_data = &database[v * bytes_per_vec..v * bytes_per_vec + bytes_per_vec]; - let dist = super::bitwise::hamming_distance_raw( - &query[..bytes_per_vec], - vec_data, - ) as u32; + let dist = super::bitwise::hamming_distance_raw(&query[..bytes_per_vec], vec_data) as u32; distances.push((v, dist)); } @@ -348,22 +336,11 @@ mod tests { let (mean, std) = approx_mean_std_f32(&data); let exact_mean: f32 = data.iter().sum::() / data.len() as f32; - let exact_var: f32 = - data.iter().map(|&x| (x - exact_mean).powi(2)).sum::() / data.len() as f32; + let exact_var: f32 = data.iter().map(|&x| (x - exact_mean).powi(2)).sum::() / data.len() as f32; let exact_std = exact_var.sqrt(); - assert!( - (mean - exact_mean).abs() < exact_mean.abs() * 0.02, - "mean: {} vs exact {}", - mean, - exact_mean - ); - assert!( - (std - exact_std).abs() < exact_std * 0.05, - "std: {} vs exact {}", - std, - exact_std - ); + assert!((mean - exact_mean).abs() < exact_mean.abs() * 0.02, "mean: {} vs exact {}", mean, exact_mean); + assert!((std - exact_std).abs() < exact_std * 0.05, "std: {} vs exact {}", std, exact_std); } #[test] @@ -415,9 +392,7 @@ mod tests { #[test] fn test_approx_column_std() { - let data = vec![ - 1.0, 10.0, 0.0, 2.0, 20.0, 0.0, 3.0, 30.0, 0.0, 4.0, 40.0, 0.0, - ]; + let data = vec![1.0, 10.0, 0.0, 2.0, 20.0, 0.0, 3.0, 30.0, 0.0, 4.0, 40.0, 0.0]; let stds = approx_column_std(&data, 4, 3); assert!(stds[0] > 0.5 && stds[0] < 2.0, "col0 std={}", stds[0]); assert!(stds[1] > 8.0 && 
stds[1] < 15.0, "col1 std={}", stds[1]); diff --git a/src/hpc/projection.rs b/src/hpc/projection.rs index e4b496f2..a23e8968 100644 --- a/src/hpc/projection.rs +++ b/src/hpc/projection.rs @@ -33,9 +33,13 @@ pub fn simhash_project(embedding: &[f32], container_bits: usize, seed: u64) -> A for bit_idx in 0..container_bits { // Generate random hyperplane using LCG let mut dot = 0.0f32; - let mut rng_state = seed.wrapping_mul(6364136223846793005).wrapping_add(bit_idx as u64); + let mut rng_state = seed + .wrapping_mul(6364136223846793005) + .wrapping_add(bit_idx as u64); for d in 0..dim { - rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(d as u64 + 1); + rng_state = rng_state + .wrapping_mul(6364136223846793005) + .wrapping_add(d as u64 + 1); // Convert to [-1, 1] range let random_val = ((rng_state >> 33) as f32 / (u32::MAX >> 1) as f32) * 2.0 - 1.0; dot += embedding[d] * random_val; @@ -68,11 +72,7 @@ pub fn simhash_project(embedding: &[f32], container_bits: usize, seed: u64) -> A /// assert_eq!(results[0].len(), 2); // 16 bits = 2 bytes /// ``` pub fn simhash_batch_project( - embeddings: &[f32], - n: usize, - d: usize, - container_bits: usize, - seed: u64, + embeddings: &[f32], n: usize, d: usize, container_bits: usize, seed: u64, ) -> Vec> { (0..n) .map(|i| { @@ -93,9 +93,13 @@ pub fn simhash_int8_project(embedding_i8: &[i8], container_bits: usize, seed: u6 for bit_idx in 0..container_bits { let mut dot = 0i64; - let mut rng_state = seed.wrapping_mul(6364136223846793005).wrapping_add(bit_idx as u64); + let mut rng_state = seed + .wrapping_mul(6364136223846793005) + .wrapping_add(bit_idx as u64); for d in 0..dim { - rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(d as u64 + 1); + rng_state = rng_state + .wrapping_mul(6364136223846793005) + .wrapping_add(d as u64 + 1); let random_val = if (rng_state >> 63) == 0 { 1i64 } else { -1i64 }; dot += embedding_i8[d] as i64 * random_val; } diff --git a/src/hpc/property_mask.rs b/src/hpc/property_mask.rs index ee108f03..f16f6dab 100644 --- a/src/hpc/property_mask.rs +++ b/src/hpc/property_mask.rs @@ -80,8 +80,7 @@ impl PropertyMask { /// Test a single block state against the compiled mask. #[inline(always)] pub fn test(&self, block_state: u64) -> bool { - (block_state & self.and_mask) == self.and_expect - && (block_state & self.andn_mask) == 0 + (block_state & self.and_mask) == self.and_expect && (block_state & self.andn_mask) == 0 } /// Batch test up to 4096 block states (one chunk section). 
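PropertyMask::test above folds every constraint into two masks: bits that must equal and_expect under and_mask, and bits that must be clear under andn_mask. A tiny standalone version of the same predicate, reproducing the require_value(2, 2, 2) case from the tests; the builder's argument order (offset, width, value) is inferred from the "bits [2..4] must equal 2" test comment:

// Same two-mask predicate as PropertyMask::test; MaskSketch is a sketch name
// and the require_value argument order is inferred from the tests below.
#[derive(Default)]
struct MaskSketch {
    and_mask: u64,
    and_expect: u64,
    andn_mask: u64,
}

impl MaskSketch {
    fn require_value(mut self, offset: u32, width: u32, value: u64) -> Self {
        let field = ((1u64 << width) - 1) << offset;
        self.and_mask |= field;
        self.and_expect |= (value << offset) & field;
        self
    }
    fn forbid_bits(mut self, bits: u64) -> Self {
        self.andn_mask |= bits;
        self
    }
    fn test(&self, state: u64) -> bool {
        (state & self.and_mask) == self.and_expect && (state & self.andn_mask) == 0
    }
}

fn main() {
    // bits [2..4) must equal 0b10 — mirrors the require_value test.
    let m = MaskSketch::default().require_value(2, 2, 2);
    assert!(m.test(0b1000));
    assert!(!m.test(0b0100));
    assert!(!m.test(0b1100));
    // andn_mask side: a single forbidden bit.
    let f = MaskSketch::default().forbid_bits(1 << 5);
    assert!(f.test(0));
    assert!(!f.test(1 << 5));
}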
@@ -254,9 +253,7 @@ impl PropertyMask { for lane in 0..4usize { let val = states[base + lane]; - if (val & self.and_mask) == self.and_expect - && (val & self.andn_mask) == 0 - { + if (val & self.and_mask) == self.and_expect && (val & self.andn_mask) == 0 { let idx = base + lane; result[idx / 64] |= 1u64 << (idx % 64); } @@ -424,9 +421,9 @@ mod tests { fn test_require_value() { // bits [2..4] must equal 2 (binary 10) let m = PropertyMask::new().require_value(2, 2, 2); - assert!(m.test(0b1000)); // field = 10 => 2 - assert!(!m.test(0b0100)); // field = 01 => 1 - assert!(!m.test(0b1100)); // field = 11 => 3 + assert!(m.test(0b1000)); // field = 10 => 2 + assert!(!m.test(0b0100)); // field = 01 => 1 + assert!(!m.test(0b1100)); // field = 11 => 3 assert!(m.test(0b11111_1000)); // field still 10 } @@ -605,11 +602,7 @@ mod tests { assert_eq!(result.counts.len(), masks.len()); for (m_idx, mask) in masks.iter().enumerate() { let expected = mask.count_section(&states); - assert_eq!( - result.counts[m_idx], expected, - "multi-mask parity mismatch for mask index {}", - m_idx - ); + assert_eq!(result.counts[m_idx], expected, "multi-mask parity mismatch for mask index {}", m_idx); } } } diff --git a/src/hpc/qualia.rs b/src/hpc/qualia.rs index 4536b96a..01ce2c92 100644 --- a/src/hpc/qualia.rs +++ b/src/hpc/qualia.rs @@ -418,7 +418,7 @@ mod tests { assert_eq!(q.channels[qualia_dim::PITCH], 0); // no transitions assert_eq!(q.channels[qualia_dim::WARMTH], 0); assert_eq!(q.channels[qualia_dim::PRESSURE], 0); // zero variance - // SOCIAL: all pairs agree perfectly (xor=0, agreement=8) + // SOCIAL: all pairs agree perfectly (xor=0, agreement=8) assert_eq!(q.channels[qualia_dim::SOCIAL], 65535); // TEMPORAL: no gradient => 0.5 * 65535 = 32767 assert_eq!(q.channels[qualia_dim::TEMPORAL], 32767); @@ -444,10 +444,10 @@ mod tests { assert_eq!(q.channels[qualia_dim::PITCH], 0); // no transitions (xor of identical bytes = 0) assert_eq!(q.channels[qualia_dim::WARMTH], 65535); assert_eq!(q.channels[qualia_dim::PRESSURE], 0); // zero variance - // SOCIAL: all pairs identical => max agreement + // SOCIAL: all pairs identical => max agreement assert_eq!(q.channels[qualia_dim::SOCIAL], 65535); assert_eq!(q.channels[qualia_dim::TEMPORAL], 32767); // no gradient - // SACREDNESS: only one byte value => entropy = 0 + // SACREDNESS: only one byte value => entropy = 0 assert_eq!(q.channels[qualia_dim::SACREDNESS], 0); // AROUSAL: 1.0^2 = 1.0 assert_eq!(q.channels[qualia_dim::AROUSAL], 65535); @@ -487,12 +487,7 @@ mod tests { let b = qualia_color(&data_b); let d_ab = qualia_distance(&a, &b); let d_ba = qualia_distance(&b, &a); - assert!( - (d_ab - d_ba).abs() < 1e-6, - "distance should be symmetric: {} vs {}", - d_ab, - d_ba - ); + assert!((d_ab - d_ba).abs() < 1e-6, "distance should be symmetric: {} vs {}", d_ab, d_ba); } #[test] @@ -500,11 +495,7 @@ mod tests { let a = QualiaVector { channels: [0; 16] }; let b = QualiaVector { channels: [65535; 16] }; let d = qualia_distance(&a, &b); - assert!( - (d - 1.0).abs() < 1e-5, - "max distance should be ~1.0, got {}", - d - ); + assert!((d - 1.0).abs() < 1e-5, "max distance should be ~1.0, got {}", d); } #[test] @@ -537,10 +528,7 @@ mod tests { let b = qualia_color(&data_b); let d = qualia_distance(&a, &b); let s = qualia_similarity(&a, &b); - assert!( - (d + s - 1.0).abs() < 1e-6, - "distance + similarity should = 1.0" - ); + assert!((d + s - 1.0).abs() < 1e-6, "distance + similarity should = 1.0"); } #[test] @@ -565,10 +553,7 @@ mod tests { #[test] fn test_qualia_roundtrip_packed() { 
// Create a PackedQualia with known values - let packed = PackedQualia::new( - [0, 10, -10, 50, -50, 100, -100, 127, -128, 1, -1, 42, -42, 63, -63, 0], - 0.0, - ); + let packed = PackedQualia::new([0, 10, -10, 50, -50, 100, -100, 127, -128, 1, -1, 42, -42, 63, -63, 0], 0.0); let qv = qualia_from_packed(&packed); let packed2 = qualia_to_packed(&qv); diff --git a/src/hpc/qualia_gate.rs b/src/hpc/qualia_gate.rs index 749062c1..dbf78bf6 100644 --- a/src/hpc/qualia_gate.rs +++ b/src/hpc/qualia_gate.rs @@ -259,10 +259,7 @@ mod tests { fn test_gate_level_from_str() { assert_eq!(QualiaGateLevel::parse("flow"), Some(QualiaGateLevel::Flow)); assert_eq!(QualiaGateLevel::parse("hold"), Some(QualiaGateLevel::Hold)); - assert_eq!( - QualiaGateLevel::parse("block"), - Some(QualiaGateLevel::Block) - ); + assert_eq!(QualiaGateLevel::parse("block"), Some(QualiaGateLevel::Block)); assert_eq!(QualiaGateLevel::parse("invalid"), None); } diff --git a/src/hpc/quantized.rs b/src/hpc/quantized.rs index 124efc28..50bf9432 100644 --- a/src/hpc/quantized.rs +++ b/src/hpc/quantized.rs @@ -441,16 +441,7 @@ impl_half_ops!(F16); /// BF16 GEMM with f32 accumulation: C = alpha * A * B + beta * C /// /// A and B are BF16, C is f32. Accumulation done in f32 for precision. -pub fn bf16_gemm_f32( - a: &[BF16], - b: &[BF16], - c: &mut [f32], - m: usize, - n: usize, - k: usize, - alpha: f32, - beta: f32, -) { +pub fn bf16_gemm_f32(a: &[BF16], b: &[BF16], c: &mut [f32], m: usize, n: usize, k: usize, alpha: f32, beta: f32) { // Apply beta if beta == 0.0 { for v in c.iter_mut() { @@ -477,8 +468,7 @@ pub fn bf16_gemm_f32( for p in 0..kb { let a_val = alpha * a[(ii + i) * k + (kk + p)].to_f32(); for j in 0..jb { - c[(ii + i) * n + (jj + j)] += - a_val * b[(kk + p) * n + (jj + j)].to_f32(); + c[(ii + i) * n + (jj + j)] += a_val * b[(kk + p) * n + (jj + j)].to_f32(); } } } @@ -492,14 +482,7 @@ pub fn bf16_gemm_f32( /// Mixed precision GEMM: f32 inputs, BF16 compute, f32 output. pub fn mixed_precision_gemm( - a_f32: &[f32], - b_f32: &[f32], - c: &mut [f32], - m: usize, - n: usize, - k: usize, - alpha: f32, - beta: f32, + a_f32: &[f32], b_f32: &[f32], c: &mut [f32], m: usize, n: usize, k: usize, alpha: f32, beta: f32, ) { let a_bf16 = f32_vec_to_bf16(a_f32); let b_bf16 = f32_vec_to_bf16(b_f32); @@ -547,7 +530,15 @@ pub fn quantize_f32_to_u8(data: &[f32]) -> (Vec, QuantParams) { .map(|&v| ((v / scale + zero_point as f32).round() as i32).clamp(0, 255) as u8) .collect(); - (quantized, QuantParams { scale, zero_point, min_val, max_val }) + ( + quantized, + QuantParams { + scale, + zero_point, + min_val, + max_val, + }, + ) } /// Quantize f32 to i8. @@ -562,7 +553,15 @@ pub fn quantize_f32_to_i8(data: &[f32]) -> (Vec, QuantParams) { .map(|&v| (v / scale).round().clamp(-128.0, 127.0) as i8) .collect(); - (quantized, QuantParams { scale, zero_point: 0, min_val, max_val }) + ( + quantized, + QuantParams { + scale, + zero_point: 0, + min_val, + max_val, + }, + ) } /// Dequantize i8 codes back to f32 using the [`QuantParams`] from @@ -593,11 +592,7 @@ pub fn dequantize_i8_to_f32(codes: &[i8], params: &QuantParams, n: usize) -> Vec } /// Per-channel i8 quantization (per row). 
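quantize_f32_to_u8 above maps v to round(v / scale + zero_point) clamped to 0..=255. A sketch of the full affine round trip with the usual min/max-derived scale; the crate's actual QuantParams derivation sits outside the hunk shown, so treat the scale and zero-point recipe here as illustrative:

/// Affine u8 quantization sketch: q = round(v / scale + zero_point), clamped
/// to 0..=255. The min/max scale recipe is illustrative, not the crate's.
fn quantize_u8(data: &[f32]) -> (Vec<u8>, f32, i32) {
    let min = data.iter().cloned().fold(f32::INFINITY, f32::min);
    let max = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let scale = ((max - min) / 255.0).max(f32::EPSILON);
    let zero_point = (-min / scale).round() as i32;
    let q = data
        .iter()
        .map(|&v| ((v / scale + zero_point as f32).round() as i32).clamp(0, 255) as u8)
        .collect();
    (q, scale, zero_point)
}

fn dequantize_u8(q: &[u8], scale: f32, zero_point: i32) -> Vec<f32> {
    q.iter().map(|&c| (c as i32 - zero_point) as f32 * scale).collect()
}

fn main() {
    let xs = vec![-1.0f32, -0.25, 0.0, 0.5, 1.0];
    let (q, s, zp) = quantize_u8(&xs);
    let back = dequantize_u8(&q, s, zp);
    for (a, b) in xs.iter().zip(&back) {
        assert!((a - b).abs() <= s, "round-trip error bounded by one quantization step");
    }
}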
-pub fn quantize_per_channel_i8( - data: &[f32], - rows: usize, - cols: usize, -) -> (Vec, PerChannelQuantParams) { +pub fn quantize_per_channel_i8(data: &[f32], rows: usize, cols: usize) -> (Vec, PerChannelQuantParams) { let mut quantized = vec![0i8; data.len()]; let mut scales = Vec::with_capacity(rows); let mut zero_points = Vec::with_capacity(rows); @@ -636,15 +631,7 @@ pub fn int8_gemm_i32(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: u /// Int8 GEMM with f32 dequantization. pub fn int8_gemm_f32( - a: &[u8], - b: &[i8], - c: &mut [f32], - m: usize, - n: usize, - k: usize, - scale_a: f32, - zero_point_a: i32, - scale_b: f32, + a: &[u8], b: &[i8], c: &mut [f32], m: usize, n: usize, k: usize, scale_a: f32, zero_point_a: i32, scale_b: f32, ) { let mut c_i32 = vec![0i32; m * n]; int8_gemm_i32(a, b, &mut c_i32, m, n, k); @@ -666,14 +653,7 @@ pub fn int8_gemm_f32( /// Per-channel int8 GEMM with f32 output. pub fn int8_gemm_per_channel_f32( - a: &[u8], - b: &[i8], - c: &mut [f32], - m: usize, - n: usize, - k: usize, - a_scales: &[f32], - a_zero_points: &[i32], + a: &[u8], b: &[i8], c: &mut [f32], m: usize, n: usize, k: usize, a_scales: &[f32], a_zero_points: &[i32], b_scales: &[f32], ) { for i in 0..m { @@ -721,11 +701,7 @@ pub fn dequantize_i4_to_f32(packed: &[u8], params: &QuantParams, len: usize) -> let mut result = Vec::with_capacity(len); for i in 0..len { let byte = packed[i / 2]; - let nibble = if i % 2 == 0 { - byte & 0x0F - } else { - byte >> 4 - }; + let nibble = if i % 2 == 0 { byte & 0x0F } else { byte >> 4 }; // Sign-extend from 4 bits let val = if nibble & 0x08 != 0 { nibble as i8 | !0x0F_u8 as i8 @@ -1051,12 +1027,7 @@ mod tests { for &v in &approx { let h = F16::from_f32(v); let back = h.to_f32(); - assert!( - (back - v).abs() / v.abs().max(1.0) < 0.001, - "F16 roundtrip {} → {}", - v, - back - ); + assert!((back - v).abs() / v.abs().max(1.0) < 0.001, "F16 roundtrip {} → {}", v, back); } } diff --git a/src/hpc/renderer.rs b/src/hpc/renderer.rs index b7cab349..eab6a0a0 100644 --- a/src/hpc/renderer.rs +++ b/src/hpc/renderer.rs @@ -101,10 +101,7 @@ impl RenderFrame { /// Total bytes resident for this frame (debug / health). pub fn byte_footprint(&self) -> usize { - self.positions.len() * 4 - + self.velocities.len() * 4 - + self.charges.len() * 4 - + self.fingerprints.len() * 8 + self.positions.len() * 4 + self.velocities.len() * 4 + self.charges.len() * 4 + self.fingerprints.len() * 8 } } @@ -135,10 +132,7 @@ impl Renderer { /// Allocate a renderer with capacity for `n` nodes per frame. pub fn with_capacity(n: usize) -> Self { Self { - frames: [ - RwLock::new(RenderFrame::with_capacity(n)), - RwLock::new(RenderFrame::with_capacity(n)), - ], + frames: [RwLock::new(RenderFrame::with_capacity(n)), RwLock::new(RenderFrame::with_capacity(n))], front_idx: AtomicUsize::new(0), tick_count: AtomicU64::new(0), } @@ -158,12 +152,16 @@ impl Renderer { /// Read-lock the front frame (for REST / SSE consumers). pub fn read_front(&self) -> std::sync::RwLockReadGuard<'_, RenderFrame> { - self.frames[self.front_index()].read().expect("front lock poisoned") + self.frames[self.front_index()] + .read() + .expect("front lock poisoned") } /// Write-lock the back frame (for the shader cycle to mutate). pub fn write_back(&self) -> std::sync::RwLockWriteGuard<'_, RenderFrame> { - self.frames[self.back_index()].write().expect("back lock poisoned") + self.frames[self.back_index()] + .write() + .expect("back lock poisoned") } /// Atomically swap front and back. 
Readers acquired BEFORE the swap @@ -186,7 +184,12 @@ impl Renderer { pub fn tick(&self, dt: f32, damping: f32) { { let mut back = self.write_back(); - let RenderFrame { positions, velocities, tick, .. } = &mut *back; + let RenderFrame { + positions, + velocities, + tick, + .. + } = &mut *back; integrate_simd(positions, velocities, dt, damping); *tick = self.tick_count.load(Ordering::Acquire) + 1; } @@ -206,8 +209,7 @@ impl Default for Renderer { /// Capacity is bootstrapped at 4096 nodes (rounded up to PREFERRED_F32_LANES). /// Consumers wanting a different capacity should construct their own /// `Renderer::with_capacity(...)` in their binary, not touch this static. -pub static GLOBAL_RENDERER: LazyLock = - LazyLock::new(|| Renderer::with_capacity(4096)); +pub static GLOBAL_RENDERER: LazyLock = LazyLock::new(|| Renderer::with_capacity(4096)); // ───────────────────────────────────────────────────────────────────── // SIMD hot path — integrate_simd dispatches via crate::simd::F32x16 @@ -406,7 +408,7 @@ mod tests { let mut velocities = vec![0.0f32; 48]; apply_uniform_force(&mut velocities, [1.0, 2.0, 3.0], 0.5); for n in 0..16 { - assert!((velocities[n * 3] - 0.5).abs() < 1e-6); // X: 1·0.5 + assert!((velocities[n * 3] - 0.5).abs() < 1e-6); // X: 1·0.5 assert!((velocities[n * 3 + 1] - 1.0).abs() < 1e-6); // Y: 2·0.5 assert!((velocities[n * 3 + 2] - 1.5).abs() < 1e-6); // Z: 3·0.5 } @@ -459,10 +461,15 @@ static SPLAT_15: LazyLock = LazyLock::new(|| F32x16::splat(DT_15)); #[inline] pub fn cached_splat(dt: f32) -> F32x16 { const TOL: f32 = 2e-6; - if (dt - DT_60).abs() < TOL { *SPLAT_60 } - else if (dt - DT_30).abs() < TOL { *SPLAT_30 } - else if (dt - DT_15).abs() < TOL { *SPLAT_15 } - else { F32x16::splat(dt) } + if (dt - DT_60).abs() < TOL { + *SPLAT_60 + } else if (dt - DT_30).abs() < TOL { + *SPLAT_30 + } else if (dt - DT_15).abs() < TOL { + *SPLAT_15 + } else { + F32x16::splat(dt) + } } // ───────────────────────────────────────────────────────────────────── @@ -491,7 +498,12 @@ pub struct Viewport { impl Viewport { /// Default: 4.0 unit foveal, 16.0 peripheral, 64.0 cull. 
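classify_priorities, in the next hunk, bands nodes by squared distance so no per-node sqrt is needed. A compact sketch of the same banding using the default radii from default_at (Band and band are sketch names):

// Squared-distance banding as used by classify_priorities below: compare
// d^2 against r^2 thresholds. Radii follow Viewport::default_at (4 / 16 / 64);
// Band/band are sketch names, not the crate's UpdatePriority API.
#[derive(Debug, PartialEq)]
enum Band {
    Foveal,
    Peripheral,
    Distant,
    OffScreen,
}

fn band(pos: [f32; 3], center: [f32; 3]) -> Band {
    let f2 = 4.0f32 * 4.0;
    let p2 = 16.0f32 * 16.0;
    let c2 = 64.0f32 * 64.0;
    let d2: f32 = (0..3).map(|i| (pos[i] - center[i]) * (pos[i] - center[i])).sum();
    if d2 <= f2 {
        Band::Foveal
    } else if d2 <= p2 {
        Band::Peripheral
    } else if d2 <= c2 {
        Band::Distant
    } else {
        Band::OffScreen
    }
}

fn main() {
    let c = [0.0, 0.0, 0.0];
    assert_eq!(band([3.0, 0.0, 0.0], c), Band::Foveal); // matches the zone test below
    assert_eq!(band([8.0, 0.0, 0.0], c), Band::Peripheral);
    assert_eq!(band([70.0, 0.0, 0.0], c), Band::OffScreen);
}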
pub fn default_at(center: [f32; 3]) -> Self { - Self { center, foveal_radius: 4.0, peripheral_radius: 16.0, cull_radius: 64.0 } + Self { + center, + foveal_radius: 4.0, + peripheral_radius: 16.0, + cull_radius: 64.0, + } } } @@ -525,7 +537,11 @@ impl UpdatePriority { #[inline] pub fn should_update(self, tick: u64) -> bool { let stride = self.tick_stride(); - if stride == u64::MAX { false } else { tick % stride == 0 } + if stride == u64::MAX { + false + } else { + tick % stride == 0 + } } } @@ -539,16 +555,19 @@ pub fn classify_priorities(positions: &[f32], len: usize, vp: &Viewport) -> Vec< let p2 = vp.peripheral_radius * vp.peripheral_radius; let c2 = vp.cull_radius * vp.cull_radius; for i in 0..len { - let dx = positions[i * POSITION_DIMS] - vp.center[0]; + let dx = positions[i * POSITION_DIMS] - vp.center[0]; let dy = positions[i * POSITION_DIMS + 1] - vp.center[1]; let dz = positions[i * POSITION_DIMS + 2] - vp.center[2]; let d2 = dx * dx + dy * dy + dz * dz; - out.push( - if d2 <= f2 { UpdatePriority::Foveal } - else if d2 <= p2 { UpdatePriority::Peripheral } - else if d2 <= c2 { UpdatePriority::Distant } - else { UpdatePriority::OffScreen } - ); + out.push(if d2 <= f2 { + UpdatePriority::Foveal + } else if d2 <= p2 { + UpdatePriority::Peripheral + } else if d2 <= c2 { + UpdatePriority::Distant + } else { + UpdatePriority::OffScreen + }); } out } @@ -563,12 +582,7 @@ pub fn classify_priorities(positions: &[f32], len: usize, vp: &Viewport) -> Vec< /// random-priority graphs, foveated savings drop toward zero (worst case /// is the same cost as `integrate_simd`). pub fn integrate_foveated( - positions: &mut [f32], - velocities: &mut [f32], - priorities: &[UpdatePriority], - tick: u64, - dt: f32, - damping: f32, + positions: &mut [f32], velocities: &mut [f32], priorities: &[UpdatePriority], tick: u64, dt: f32, damping: f32, ) { debug_assert_eq!(positions.len(), velocities.len()); debug_assert_eq!(positions.len() % PREFERRED_F32_LANES, 0); @@ -591,9 +605,10 @@ pub fn integrate_foveated( let node_hi = (node_lo + nodes_per_chunk).min(priorities.len()); // Skip only if every node in the band agrees to skip THIS tick. - let all_skip = (node_lo..node_hi) - .all(|n| !priorities[n].should_update(tick)); - if all_skip { continue; } + let all_skip = (node_lo..node_hi).all(|n| !priorities[n].should_update(tick)); + if all_skip { + continue; + } let pv = F32x16::from_array(*p); let vv = F32x16::from_array(*v); @@ -661,7 +676,7 @@ impl FpsController { match self.target_hz() { 60 => DT_60, 30 => DT_30, - _ => DT_15, + _ => DT_15, } } @@ -688,8 +703,7 @@ impl FpsController { let next = if prev == 0 { duration_ns } else { - prev + (duration_ns.saturating_sub(prev) / 8) - - (prev.saturating_sub(duration_ns) / 8) + prev + (duration_ns.saturating_sub(prev) / 8) - (prev.saturating_sub(duration_ns) / 8) }; self.avg_tick_ns.store(next, Ordering::Release); @@ -698,14 +712,22 @@ impl FpsController { if duration_ns > budget { self.under_budget_streak.store(0, Ordering::Release); // Step down 60 → 30 → 15. Don't go below 15. 
- let new_hz = match cur { 60 => 30, 30 => 15, _ => 15 }; + let new_hz = match cur { + 60 => 30, + 30 => 15, + _ => 15, + }; if new_hz != cur { self.target_hz.store(new_hz, Ordering::Release); } } else { let streak = self.under_budget_streak.fetch_add(1, Ordering::AcqRel) + 1; if streak >= 60 { - let new_hz = match cur { 15 => 30, 30 => 60, _ => 60 }; + let new_hz = match cur { + 15 => 30, + 30 => 60, + _ => 60, + }; if new_hz != cur { self.target_hz.store(new_hz, Ordering::Release); } @@ -754,7 +776,13 @@ impl Renderer { let tick_now = self.tick_count.load(Ordering::Acquire) + 1; { let mut back = self.write_back(); - let RenderFrame { positions, velocities, len, tick, .. } = &mut *back; + let RenderFrame { + positions, + velocities, + len, + tick, + .. + } = &mut *back; let priorities = classify_priorities(positions, *len, vp); integrate_foveated(positions, velocities, &priorities, tick_now, dt, damping); *tick = tick_now; @@ -787,7 +815,9 @@ mod adaptive_tests { let v = cached_splat(0.0314); let mut out = [0.0f32; 16]; v.copy_to_slice(&mut out); - for x in out { assert!((x - 0.0314).abs() < 1e-6); } + for x in out { + assert!((x - 0.0314).abs() < 1e-6); + } } #[test] @@ -826,10 +856,10 @@ mod adaptive_tests { fn classify_priorities_assigns_zones() { // 4 nodes: at center, foveal-edge, peripheral-zone, off-screen. let positions = vec![ - 0.0, 0.0, 0.0, // node 0 — at center → Foveal - 3.0, 0.0, 0.0, // node 1 — within foveal radius (4) - 8.0, 0.0, 0.0, // node 2 — within peripheral (16) - 70.0, 0.0, 0.0, // node 3 — beyond cull (64) + 0.0, 0.0, 0.0, // node 0 — at center → Foveal + 3.0, 0.0, 0.0, // node 1 — within foveal radius (4) + 8.0, 0.0, 0.0, // node 2 — within peripheral (16) + 70.0, 0.0, 0.0, // node 3 — beyond cull (64) ]; let vp = Viewport::default_at([0.0, 0.0, 0.0]); let p = classify_priorities(&positions, 4, &vp); @@ -846,8 +876,12 @@ mod adaptive_tests { let mut velocities = vec![1.0f32; 32]; let priorities = vec![UpdatePriority::OffScreen; 12]; // covers both chunks integrate_foveated(&mut positions, &mut velocities, &priorities, 0, 0.5, 0.9); - for &p in &positions { assert_eq!(p, 1.0); } // unchanged - for &v in &velocities { assert_eq!(v, 1.0); } // unchanged + for &p in &positions { + assert_eq!(p, 1.0); + } // unchanged + for &v in &velocities { + assert_eq!(v, 1.0); + } // unchanged } #[test] @@ -871,10 +905,14 @@ mod adaptive_tests { let priorities = vec![UpdatePriority::Peripheral; 12]; // Tick 1 (odd) — peripheral skips integrate_foveated(&mut positions, &mut velocities, &priorities, 1, 0.5, 0.9); - for &p in &positions { assert_eq!(p, 0.0); } + for &p in &positions { + assert_eq!(p, 0.0); + } // Tick 2 (even) — peripheral updates integrate_foveated(&mut positions, &mut velocities, &priorities, 2, 0.5, 0.9); - for &p in &positions { assert!((p - 0.5).abs() < 1e-6); } + for &p in &positions { + assert!((p - 0.5).abs() < 1e-6); + } } #[test] @@ -900,10 +938,14 @@ mod adaptive_tests { fn fps_controller_steps_up_on_sustained_under_budget() { let c = FpsController::new(15); // Record 60 fast ticks → climb to 30. - for _ in 0..60 { c.record_tick(1_000_000); } // 1ms each + for _ in 0..60 { + c.record_tick(1_000_000); + } // 1ms each assert_eq!(c.target_hz(), 30); // Another 60 fast → climb to 60. 
- for _ in 0..60 { c.record_tick(1_000_000); } + for _ in 0..60 { + c.record_tick(1_000_000); + } assert_eq!(c.target_hz(), 60); } diff --git a/src/hpc/safetensors.rs b/src/hpc/safetensors.rs index 5d70218d..f9fb29c2 100644 --- a/src/hpc/safetensors.rs +++ b/src/hpc/safetensors.rs @@ -21,7 +21,7 @@ //! safetensors stores full BF16 precision, while GGUF Q8_0 introduces //! quantization noise. BF16→BF16 diff gives cleaner causal attribution. -use super::gguf::{GgufFile, TensorInfo, GgmlType}; +use super::gguf::{GgmlType, GgufFile, TensorInfo}; use std::collections::HashMap; use std::io::{Read, Seek, SeekFrom}; @@ -68,13 +68,21 @@ fn parse_safetensors_json(json: &str) -> Result, String> { while pos < len - 1 { // Skip whitespace and commas - while pos < len && (bytes[pos] == b' ' || bytes[pos] == b'\n' || - bytes[pos] == b'\r' || bytes[pos] == b'\t' || - bytes[pos] == b',') { + while pos < len + && (bytes[pos] == b' ' + || bytes[pos] == b'\n' + || bytes[pos] == b'\r' + || bytes[pos] == b'\t' + || bytes[pos] == b',') + { pos += 1; } - if pos >= len - 1 { break; } - if bytes[pos] == b'}' { break; } + if pos >= len - 1 { + break; + } + if bytes[pos] == b'}' { + break; + } // Read key (tensor name) if bytes[pos] != b'"' { @@ -84,19 +92,22 @@ fn parse_safetensors_json(json: &str) -> Result, String> { let key_start = pos + 1; pos += 1; while pos < len && bytes[pos] != b'"' { - if bytes[pos] == b'\\' { pos += 1; } // skip escaped char + if bytes[pos] == b'\\' { + pos += 1; + } // skip escaped char pos += 1; } let key = &json[key_start..pos]; pos += 1; // skip closing " // Skip colon - while pos < len && bytes[pos] != b':' { pos += 1; } + while pos < len && bytes[pos] != b':' { + pos += 1; + } pos += 1; // skip : // Skip whitespace - while pos < len && (bytes[pos] == b' ' || bytes[pos] == b'\n' || - bytes[pos] == b'\r' || bytes[pos] == b'\t') { + while pos < len && (bytes[pos] == b' ' || bytes[pos] == b'\n' || bytes[pos] == b'\r' || bytes[pos] == b'\t') { pos += 1; } @@ -107,12 +118,18 @@ fn parse_safetensors_json(json: &str) -> Result, String> { let mut depth = 1; pos += 1; while pos < len && depth > 0 { - if bytes[pos] == b'{' { depth += 1; } - if bytes[pos] == b'}' { depth -= 1; } + if bytes[pos] == b'{' { + depth += 1; + } + if bytes[pos] == b'}' { + depth -= 1; + } if bytes[pos] == b'"' { pos += 1; while pos < len && bytes[pos] != b'"' { - if bytes[pos] == b'\\' { pos += 1; } + if bytes[pos] == b'\\' { + pos += 1; + } pos += 1; } } @@ -125,7 +142,9 @@ fn parse_safetensors_json(json: &str) -> Result, String> { // Parse tensor value object: { "dtype": "...", "shape": [...], "data_offsets": [...] 
} if bytes[pos] != b'{' { // Not an object — skip until next comma or closing brace - while pos < len && bytes[pos] != b',' && bytes[pos] != b'}' { pos += 1; } + while pos < len && bytes[pos] != b',' && bytes[pos] != b'}' { + pos += 1; + } continue; } @@ -134,12 +153,18 @@ fn parse_safetensors_json(json: &str) -> Result, String> { let mut depth = 1; pos += 1; while pos < len && depth > 0 { - if bytes[pos] == b'{' { depth += 1; } - if bytes[pos] == b'}' { depth -= 1; } + if bytes[pos] == b'{' { + depth += 1; + } + if bytes[pos] == b'}' { + depth -= 1; + } if bytes[pos] == b'"' { pos += 1; while pos < len && bytes[pos] != b'"' { - if bytes[pos] == b'\\' { pos += 1; } + if bytes[pos] == b'\\' { + pos += 1; + } pos += 1; } } @@ -201,7 +226,8 @@ fn extract_json_array_u64(obj: &str, key: &str) -> Option> { let bracket_close = after_key.find(']')?; let array_str = &after_key[bracket_open + 1..bracket_close]; - let values: Vec = array_str.split(',') + let values: Vec = array_str + .split(',') .filter_map(|s| s.trim().parse().ok()) .collect(); @@ -222,7 +248,9 @@ fn extract_json_array_u64(obj: &str, key: &str) -> Option> { pub fn read_safetensors_header(reader: &mut R) -> Result { // Read header size (first 8 bytes, u64 LE) let mut size_buf = [0u8; 8]; - reader.read_exact(&mut size_buf).map_err(|e| format!("read header size: {}", e))?; + reader + .read_exact(&mut size_buf) + .map_err(|e| format!("read header size: {}", e))?; let header_size = u64::from_le_bytes(size_buf); if header_size > 100_000_000 { @@ -231,7 +259,9 @@ pub fn read_safetensors_header(reader: &mut R) -> Result(reader: &mut R) -> Result(reader: &mut R) -> Result( - reader: &mut R, - writer: &mut W, - octave_stride: usize, + reader: &mut R, writer: &mut W, octave_stride: usize, callback: Option<&dyn Fn(&str, &super::gguf_indexer::LayerType, usize, usize)>, ) -> Result { // Parse safetensors header (produces GgufFile-compatible struct) @@ -276,9 +303,7 @@ pub fn stream_index_safetensors_bf16( // Delegate to the existing BF16-direct chunked indexer // The indexer uses: header.tensors, header.tensor_data_offset, tensor.offset, tensor.dtype // All of these are populated by read_safetensors_header identically to read_gguf_header. 
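// Editorial sketch (not part of the diff): the on-disk layout assumed by the parser
// above. A .safetensors file is an 8-byte little-endian header length, followed by a
// JSON header of exactly that many bytes, followed by raw tensor data; each entry's
// "data_offsets" are relative to the end of the header. A minimal prefix writer:
fn safetensors_prefix(header_json: &str) -> Vec<u8> {
    let mut out = (header_json.len() as u64).to_le_bytes().to_vec();
    out.extend_from_slice(header_json.as_bytes());
    out // raw tensor bytes follow immediately after this prefix
}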
- super::gguf_indexer::stream_index_gguf_bf16_with_header( - reader, writer, &header, octave_stride, callback, - ) + super::gguf_indexer::stream_index_gguf_bf16_with_header(reader, writer, &header, octave_stride, callback) } // ============================================================================ @@ -387,7 +412,8 @@ mod tests { .output() .map(|o| String::from_utf8_lossy(&o.stdout).to_string()) .unwrap_or_default(); - let size: u64 = size_str.lines() + let size: u64 = size_str + .lines() .filter(|l| l.to_lowercase().starts_with("content-length:")) .last() .and_then(|l| l.split(':').nth(1)) @@ -399,17 +425,22 @@ mod tests { let mut writer = BufWriter::new(out); let stats = stream_index_safetensors_bf16( - &mut reader, &mut writer, 16, + &mut reader, + &mut writer, + 16, Some(&|name, lt, orig, comp| { let ratio = if comp > 0 { orig as f64 / comp as f64 } else { 0.0 }; eprintln!(" {:50} {:>12} → {:>8} ({:.0}×)", name, orig, comp, ratio); }), - ).expect("safetensors indexing failed"); + ) + .expect("safetensors indexing failed"); drop(writer); - eprintln!(" → {:.2} MB, {} tensors", + eprintln!( + " → {:.2} MB, {} tensors", std::fs::metadata(&out_path).map(|m| m.len()).unwrap_or(0) as f64 / 1e6, - stats.tensors_indexed); + stats.tensors_indexed + ); } } @@ -417,10 +448,7 @@ mod tests { /// Helper: index safetensors shards from a HuggingFace repo. fn index_safetensors_shards( - repo: &str, - filenames: &[&str], - out_prefix: &str, - octave_stride: usize, + repo: &str, filenames: &[&str], out_prefix: &str, octave_stride: usize, ) -> Vec { use super::super::http_reader::HttpRangeReader; use std::io::BufWriter; @@ -449,7 +477,8 @@ mod tests { .output() .map(|o| String::from_utf8_lossy(&o.stdout).to_string()) .unwrap_or_default(); - let size: u64 = size_str.lines() + let size: u64 = size_str + .lines() .filter(|l| l.to_lowercase().starts_with("content-length:")) .last() .and_then(|l| l.split(':').nth(1)) @@ -461,17 +490,24 @@ mod tests { let mut writer = BufWriter::new(out); let stats = super::stream_index_safetensors_bf16( - &mut reader, &mut writer, octave_stride, + &mut reader, + &mut writer, + octave_stride, Some(&|name, lt, orig, comp| { let ratio = if comp > 0 { orig as f64 / comp as f64 } else { 0.0 }; eprintln!(" {:50} {:>12} → {:>8} ({:.0}×)", name, orig, comp, ratio); }), - ).expect("safetensors indexing failed"); + ) + .expect("safetensors indexing failed"); drop(writer); let out_size = std::fs::metadata(&out_path).map(|m| m.len()).unwrap_or(0); - eprintln!(" → {:.2} MB, {} tensors, {:.0}×", - out_size as f64 / 1e6, stats.tensors_indexed, stats.overall_ratio()); + eprintln!( + " → {:.2} MB, {} tensors, {:.0}×", + out_size as f64 / 1e6, + stats.tensors_indexed, + stats.overall_ratio() + ); all_stats.push(stats); } @@ -483,12 +519,13 @@ mod tests { #[ignore] // Streams ~35 GB from HuggingFace fn test_stream_index_hidream_transformer() { let repo = "HiDream-ai/HiDream-I1-Full"; - let shards: Vec<&str> = (1..=7).map(|i| { - // Leak the string so it lives long enough — test only - Box::leak(format!( - "transformer/diffusion_pytorch_model-{:05}-of-00007.safetensors", i - ).into_boxed_str()) as &str - }).collect(); + let shards: Vec<&str> = (1..=7) + .map(|i| { + // Leak the string so it lives long enough — test only + Box::leak(format!("transformer/diffusion_pytorch_model-{:05}-of-00007.safetensors", i).into_boxed_str()) + as &str + }) + .collect(); let stats = index_safetensors_shards(repo, &shards, "/tmp/hidream_transformer", 16); @@ -513,33 +550,29 @@ mod tests { // CLIP-L 
eprintln!("━━━ CLIP-L ━━━"); - index_safetensors_shards(repo, - &["text_encoder/model.safetensors"], - "/tmp/hidream_clip_l", 16); + index_safetensors_shards(repo, &["text_encoder/model.safetensors"], "/tmp/hidream_clip_l", 16); // CLIP-G eprintln!("━━━ CLIP-G ━━━"); - index_safetensors_shards(repo, - &["text_encoder_2/model.safetensors"], - "/tmp/hidream_clip_g", 16); + index_safetensors_shards(repo, &["text_encoder_2/model.safetensors"], "/tmp/hidream_clip_g", 16); // Llama-3.1-8B text encoder (2 shards) eprintln!("━━━ Llama-3.1-8B (HiDream text encoder) ━━━"); - index_safetensors_shards(repo, - &["text_encoder_3/model-00001-of-00002.safetensors", - "text_encoder_3/model-00002-of-00002.safetensors"], - "/tmp/hidream_llama_enc", 16); + index_safetensors_shards( + repo, + &["text_encoder_3/model-00001-of-00002.safetensors", "text_encoder_3/model-00002-of-00002.safetensors"], + "/tmp/hidream_llama_enc", + 16, + ); } #[test] #[ignore] // Streams ~16 GB (base Llama-3.1-8B) fn test_stream_index_llama31_8b_base() { let repo = "unsloth/Llama-3.1-8B"; - let shards: Vec<&str> = (1..=4).map(|i| { - Box::leak(format!( - "model-{:05}-of-00004.safetensors", i - ).into_boxed_str()) as &str - }).collect(); + let shards: Vec<&str> = (1..=4) + .map(|i| Box::leak(format!("model-{:05}-of-00004.safetensors", i).into_boxed_str()) as &str) + .collect(); index_safetensors_shards(repo, &shards, "/tmp/llama31_8b_base", 16); } @@ -547,7 +580,7 @@ mod tests { #[test] #[ignore] // Requires: HiDream Llama enc + base Llama indexed fn test_hidream_llama_diff() { - use super::super::causal_diff::{causal_diff, print_diff_summary, find_reasoning_scaffold}; + use super::super::causal_diff::{causal_diff, find_reasoning_scaffold, print_diff_summary}; // Compare HiDream's Llama-3.1-8B (image-conditioned) vs base // Shards need to be concatenated or diffed per-shard @@ -568,7 +601,9 @@ mod tests { let (edges, stats) = causal_diff(base, dist, 100).expect("diff failed"); print_diff_summary( &format!("Llama-3.1-8B: base vs HiDream image encoder ({})", label), - &stats, edges.len()); + &stats, + edges.len(), + ); let scaffold = find_reasoning_scaffold(&edges, 0.3); eprintln!(" Visual grounding scaffold blocks: {:?}", scaffold); @@ -580,9 +615,12 @@ mod tests { if total_compared > 0 { eprintln!(); eprintln!("━━━ Cross-Domain Insight ━━━"); - eprintln!(" Total rows shifted: {}/{} ({:.1}%)", - total_shifted, total_compared, - total_shifted as f64 / total_compared as f64 * 100.0); + eprintln!( + " Total rows shifted: {}/{} ({:.1}%)", + total_shifted, + total_compared, + total_shifted as f64 / total_compared as f64 * 100.0 + ); eprintln!(" → These shifts = what 'visual grounding' looks like in LLM weight space"); } } @@ -594,12 +632,7 @@ mod tests { fn test_stream_index_reader_lm() { // jinaai/reader-lm-1.5b: 1 shard, 1.54B params, 3.1 GB BF16 // Produces ~30 MB bgz7 for local HTML→Markdown palette routing - index_safetensors_shards( - "jinaai/reader-lm-1.5b", - &["model.safetensors"], - "/tmp/reader_lm_1_5b", - 16, - ); + index_safetensors_shards("jinaai/reader-lm-1.5b", &["model.safetensors"], "/tmp/reader_lm_1_5b", 16); } // ── BGE-M3: multilingual embedding model (GGUF path) ── @@ -639,16 +672,23 @@ mod tests { let mut writer = BufWriter::new(out); let stats = super::super::gguf_indexer::stream_index_gguf_bf16( - &mut reader, &mut writer, 16, + &mut reader, + &mut writer, + 16, Some(&|name, _lt, orig, comp| { let ratio = if comp > 0 { orig as f64 / comp as f64 } else { 0.0 }; eprintln!(" {:50} {:>12} → {:>8} ({:.0}×)", name, orig, 
comp, ratio); }), - ).expect("GGUF indexing failed"); + ) + .expect("GGUF indexing failed"); drop(writer); let out_size = std::fs::metadata(out_path).map(|m| m.len()).unwrap_or(0); - eprintln!(" → {:.2} MB, {} tensors, {:.0}×", - out_size as f64 / 1e6, stats.tensors_indexed, stats.overall_ratio()); + eprintln!( + " → {:.2} MB, {} tensors, {:.0}×", + out_size as f64 / 1e6, + stats.tensors_indexed, + stats.overall_ratio() + ); } } diff --git a/src/hpc/simd_dispatch.rs b/src/hpc/simd_dispatch.rs index 7b456b9a..122f42cd 100644 --- a/src/hpc/simd_dispatch.rs +++ b/src/hpc/simd_dispatch.rs @@ -23,8 +23,8 @@ //! //! On wasm32 (future): tier would be WASM SIMD (128-bit, `+simd128`). -use std::sync::LazyLock; use super::simd_caps::simd_caps; +use std::sync::LazyLock; /// The selected SIMD tier, frozen at first access. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -159,10 +159,7 @@ impl SimdDispatch { // will be wired when simd_neon.rs types are activated. For now, // dispatch to scalar which auto-vectorizes well on aarch64 with // `-C target-feature=+neon` (mandatory on aarch64). - Self { - tier, - ..Self::scalar() - } + Self { tier, ..Self::scalar() } } #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] @@ -194,7 +191,9 @@ impl SimdDispatch { // ── byte_scan wrappers ── fn byte_find_all_scalar(haystack: &[u8], needle: u8) -> Vec { - haystack.iter().enumerate() + haystack + .iter() + .enumerate() .filter(|(_, &b)| b == needle) .map(|(i, _)| i) .collect() @@ -231,12 +230,15 @@ fn byte_count_avx2_wrapper(haystack: &[u8], needle: u8) -> usize { // ── distance wrappers ── fn squared_distances_scalar(query: [f32; 3], points: &[[f32; 3]]) -> Vec { - points.iter().map(|p| { - let dx = query[0] - p[0]; - let dy = query[1] - p[1]; - let dz = query[2] - p[2]; - dx * dx + dy * dy + dz * dz - }).collect() + points + .iter() + .map(|p| { + let dx = query[0] - p[0]; + let dy = query[1] - p[1]; + let dz = query[2] - p[2]; + dx * dx + dy * dy + dz * dz + }) + .collect() } #[cfg(target_arch = "x86_64")] diff --git a/src/hpc/spatial_hash.rs b/src/hpc/spatial_hash.rs index b69d125a..ae4303a2 100644 --- a/src/hpc/spatial_hash.rs +++ b/src/hpc/spatial_hash.rs @@ -98,12 +98,7 @@ impl SpatialHash { /// in `positions` are considered. Returns `(entity_id, squared_distance)` /// sorted ascending by distance. pub fn query_radius( - &self, - x: f32, - y: f32, - z: f32, - radius: f32, - positions: &HashMap, + &self, x: f32, y: f32, z: f32, radius: f32, positions: &HashMap, ) -> Vec<(u32, f32)> { let radius_sq = radius * radius; let query = [x, y, z]; @@ -144,14 +139,7 @@ impl SpatialHash { /// Uses expanding-ring search: starts at the cell containing the query /// point and expands outward until at least K candidates are found, then /// refines. Returns `(entity_id, squared_distance)` sorted ascending. - pub fn query_knn( - &self, - x: f32, - y: f32, - z: f32, - k: usize, - positions: &HashMap, - ) -> Vec<(u32, f32)> { + pub fn query_knn(&self, x: f32, y: f32, z: f32, k: usize, positions: &HashMap) -> Vec<(u32, f32)> { if k == 0 { return Vec::new(); } @@ -196,9 +184,7 @@ impl SpatialHash { // is closer than the nearest possible point in the next ring. // If so, no further ring can improve the result. 
if candidates.len() >= k { - candidates.sort_by(|a, b| { - a.1.partial_cmp(&b.1).unwrap_or(core::cmp::Ordering::Equal) - }); + candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(core::cmp::Ordering::Equal)); let worst = candidates[k - 1].1; // The nearest point in ring+1 is at least (ring * cell_size) away. let next_ring_min = (ring as f32) * self.cell_size; @@ -221,12 +207,7 @@ impl SpatialHash { /// Collects candidate positions from relevant cells, then uses SIMD to /// compute squared distances and filter in bulk. pub fn query_radius_simd( - &self, - x: f32, - y: f32, - z: f32, - radius: f32, - positions: &HashMap, + &self, x: f32, y: f32, z: f32, radius: f32, positions: &HashMap, ) -> Vec<(u32, f32)> { let radius_sq = radius * radius; let query = [x, y, z]; @@ -286,11 +267,7 @@ impl SpatialHash { /// Batch squared-distance filter: compute squared distances from `query` to /// each position in `candidates`, returning `(index, sq_dist)` for entries /// within `radius_sq`. -fn batch_sq_dist_filter( - query: [f32; 3], - candidates: &[[f32; 3]], - radius_sq: f32, -) -> Vec<(usize, f32)> { +fn batch_sq_dist_filter(query: [f32; 3], candidates: &[[f32; 3]], radius_sq: f32) -> Vec<(usize, f32)> { #[cfg(target_arch = "x86_64")] { if candidates.len() >= 8 && super::simd_caps::simd_caps().avx2 { @@ -301,11 +278,7 @@ fn batch_sq_dist_filter( batch_sq_dist_scalar(query, candidates, radius_sq) } -pub(crate) fn batch_sq_dist_scalar( - query: [f32; 3], - candidates: &[[f32; 3]], - radius_sq: f32, -) -> Vec<(usize, f32)> { +pub(crate) fn batch_sq_dist_scalar(query: [f32; 3], candidates: &[[f32; 3]], radius_sq: f32) -> Vec<(usize, f32)> { let mut result = Vec::new(); for (i, pos) in candidates.iter().enumerate() { let d2 = sq_dist_f32(query, *pos); @@ -323,11 +296,7 @@ pub(crate) fn batch_sq_dist_scalar( /// Caller must ensure AVX2 is available. 
#[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx2")] -pub(crate) unsafe fn batch_sq_dist_avx2( - query: [f32; 3], - candidates: &[[f32; 3]], - radius_sq: f32, -) -> Vec<(usize, f32)> { +pub(crate) unsafe fn batch_sq_dist_avx2(query: [f32; 3], candidates: &[[f32; 3]], radius_sq: f32) -> Vec<(usize, f32)> { use crate::simd::F32x8; let mut result = Vec::new(); @@ -460,12 +429,7 @@ mod tests { #[test] fn test_query_radius_basic() { let mut sh = SpatialHash::new(10.0); - let pts = vec![ - (0u32, [0.0f32, 0.0, 0.0]), - (1, [5.0, 0.0, 0.0]), - (2, [20.0, 0.0, 0.0]), - (3, [100.0, 0.0, 0.0]), - ]; + let pts = vec![(0u32, [0.0f32, 0.0, 0.0]), (1, [5.0, 0.0, 0.0]), (2, [20.0, 0.0, 0.0]), (3, [100.0, 0.0, 0.0])]; for &(id, pos) in &pts { sh.insert(id, pos[0], pos[1], pos[2]); } @@ -481,11 +445,7 @@ mod tests { #[test] fn test_query_radius_sorted_by_distance() { let mut sh = SpatialHash::new(5.0); - let pts = vec![ - (0u32, [10.0f32, 0.0, 0.0]), - (1, [3.0, 0.0, 0.0]), - (2, [1.0, 0.0, 0.0]), - ]; + let pts = vec![(0u32, [10.0f32, 0.0, 0.0]), (1, [3.0, 0.0, 0.0]), (2, [1.0, 0.0, 0.0])]; for &(id, pos) in &pts { sh.insert(id, pos[0], pos[1], pos[2]); } @@ -510,12 +470,7 @@ mod tests { #[test] fn test_knn_basic() { let mut sh = SpatialHash::new(10.0); - let pts = vec![ - (0u32, [30.0f32, 0.0, 0.0]), - (1, [10.0, 0.0, 0.0]), - (2, [20.0, 0.0, 0.0]), - (3, [5.0, 0.0, 0.0]), - ]; + let pts = vec![(0u32, [30.0f32, 0.0, 0.0]), (1, [10.0, 0.0, 0.0]), (2, [20.0, 0.0, 0.0]), (3, [5.0, 0.0, 0.0])]; for &(id, pos) in &pts { sh.insert(id, pos[0], pos[1], pos[2]); } @@ -556,7 +511,10 @@ mod tests { assert!( (r.1 - b.1).abs() < 1e-3, "knn dist mismatch: spatial_hash=({},{:.2}) brute=({},{:.2})", - r.0, r.1, b.0, b.1 + r.0, + r.1, + b.0, + b.1 ); } } @@ -583,10 +541,7 @@ mod tests { #[test] fn test_negative_coordinates() { let mut sh = SpatialHash::new(10.0); - let pts = vec![ - (0u32, [-5.0f32, -5.0, -5.0]), - (1, [5.0, 5.0, 5.0]), - ]; + let pts = vec![(0u32, [-5.0f32, -5.0, -5.0]), (1, [5.0, 5.0, 5.0])]; for &(id, pos) in &pts { sh.insert(id, pos[0], pos[1], pos[2]); } @@ -606,12 +561,7 @@ mod tests { #[test] fn test_query_radius_simd_basic() { let mut sh = SpatialHash::new(10.0); - let pts = vec![ - (0u32, [0.0f32, 0.0, 0.0]), - (1, [5.0, 0.0, 0.0]), - (2, [20.0, 0.0, 0.0]), - (3, [100.0, 0.0, 0.0]), - ]; + let pts = vec![(0u32, [0.0f32, 0.0, 0.0]), (1, [5.0, 0.0, 0.0]), (2, [20.0, 0.0, 0.0]), (3, [100.0, 0.0, 0.0])]; for &(id, pos) in &pts { sh.insert(id, pos[0], pos[1], pos[2]); } @@ -645,12 +595,7 @@ mod tests { assert_eq!(scalar.len(), simd.len(), "result count mismatch"); for (s, r) in scalar.iter().zip(simd.iter()) { assert_eq!(s.0, r.0, "id mismatch"); - assert!( - (s.1 - r.1).abs() < 1e-3, - "distance mismatch: scalar={:.4} simd={:.4}", - s.1, - r.1 - ); + assert!((s.1 - r.1).abs() < 1e-3, "distance mismatch: scalar={:.4} simd={:.4}", s.1, r.1); } } diff --git a/src/hpc/spo_bundle.rs b/src/hpc/spo_bundle.rs index 59a46e42..2b2f6665 100644 --- a/src/hpc/spo_bundle.rs +++ b/src/hpc/spo_bundle.rs @@ -18,7 +18,11 @@ const PHI: f64 = std::f64::consts::GOLDEN_RATIO; /// floor(d / φ²), rounded to nearest odd. pub const fn golden_shift(d: usize) -> usize { let raw = (d as f64 / (PHI * PHI)) as usize; - if raw % 2 == 0 { raw + 1 } else { raw } + if raw % 2 == 0 { + raw + 1 + } else { + raw + } } /// Level A constants (8Kbit = 128 × u64) @@ -60,11 +64,7 @@ pub fn cyclic_shift(bits: &[u64; N], shift: usize) -> [u64; N] { /// Majority vote of 3 binary vectors: output bit = 1 if ≥2 of 3 inputs are 1. 
/// Bit-parallel: (a&b) | (a&c) | (b&c) -pub fn majority_vote_3( - a: &[u64; N], - b: &[u64; N], - c: &[u64; N], -) -> [u64; N] { +pub fn majority_vote_3(a: &[u64; N], b: &[u64; N], c: &[u64; N]) -> [u64; N] { let mut result = [0u64; N]; for i in 0..N { result[i] = (a[i] & b[i]) | (a[i] & c[i]) | (b[i] & c[i]); @@ -162,7 +162,9 @@ pub fn recover_o_16k(bundle: &[u64; 256]) -> [u64; 256] { /// Simple deterministic PRNG for reproducible tests. fn prng_next(state: &mut u64) -> u64 { - *state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407); + *state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); *state } @@ -271,19 +273,11 @@ mod tests { // FIX: golden_shift(8192) = 3129, which is odd → gcd(3129, 8192) = 1 assert_eq!(gcd(SHIFT_META, D_META), 1, "meta shift must be coprime"); - assert!( - SHIFT_META % 2 == 1, - "shift must be odd, got {}", - SHIFT_META - ); + assert!(SHIFT_META % 2 == 1, "shift must be odd, got {}", SHIFT_META); // FIX: golden_shift(16384) is odd → gcd with 16384 = 1 assert_eq!(gcd(SHIFT_FULL, D_FULL), 1, "full shift must be coprime"); - assert!( - SHIFT_FULL % 2 == 1, - "full shift must be odd, got {}", - SHIFT_FULL - ); + assert!(SHIFT_FULL % 2 == 1, "full shift must be odd, got {}", SHIFT_FULL); // Verify full orbit at d=8192: shift visits all positions let mut visited = vec![false; D_META]; @@ -320,30 +314,16 @@ mod tests { let dist_odd = hamming(&alternating, &shifted_odd); eprintln!("\n EXPERIMENT 1: GCD Verification"); - eprintln!( - " shift=3130 (even): hamming(alternating, shift) = {} (VULNERABLE: same pattern)", - dist_even - ); + eprintln!(" shift=3130 (even): hamming(alternating, shift) = {} (VULNERABLE: same pattern)", dist_even); eprintln!( " shift={} (odd): hamming(alternating, shift) = {} (SAFE: fully decorrelated)", SHIFT_META, dist_odd ); eprintln!(" gcd(3130, 8192) = {} → orbit len = {}", gcd(3130, 8192), orbit_3130); - eprintln!( - " gcd({}, 8192) = {} → orbit len = {}", - SHIFT_META, - gcd(SHIFT_META, D_META), - orbit_good - ); + eprintln!(" gcd({}, 8192) = {} → orbit len = {}", SHIFT_META, gcd(SHIFT_META, D_META), orbit_good); - assert_eq!( - dist_even, 0, - "even shift on alternating should give hamming=0" - ); - assert_eq!( - dist_odd, D_META as u32, - "odd shift on alternating should give hamming=d" - ); + assert_eq!(dist_even, 0, "even shift on alternating should give hamming=0"); + assert_eq!(dist_odd, D_META as u32, "odd shift on alternating should give hamming=d"); } // ======================================================================== @@ -386,21 +366,9 @@ mod tests { eprintln!(" P error rate: {:.4} (expected 0.25)", p_err); eprintln!(" O error rate: {:.4} (expected 0.25)", o_err); - assert!( - (s_err - 0.25).abs() < 0.02, - "S error rate {:.4} too far from 0.25", - s_err - ); - assert!( - (p_err - 0.25).abs() < 0.02, - "P error rate {:.4} too far from 0.25", - p_err - ); - assert!( - (o_err - 0.25).abs() < 0.02, - "O error rate {:.4} too far from 0.25", - o_err - ); + assert!((s_err - 0.25).abs() < 0.02, "S error rate {:.4} too far from 0.25", s_err); + assert!((p_err - 0.25).abs() < 0.02, "P error rate {:.4} too far from 0.25", p_err); + assert!((o_err - 0.25).abs() < 0.02, "O error rate {:.4} too far from 0.25", o_err); } #[test] @@ -424,11 +392,7 @@ mod tests { eprintln!("\n EXPERIMENT 2b: Recovery Rate (16Kbit, {} trials)", n_trials); eprintln!(" S error rate: {:.4} (expected 0.25)", err); - assert!( - (err - 0.25).abs() < 0.02, - "16K error rate {:.4} too far from 0.25", - 
err - ); + assert!((err - 0.25).abs() < 0.02, "16K error rate {:.4} too far from 0.25", err); } // ======================================================================== @@ -483,10 +447,8 @@ mod tests { // Recall@k for &k in &[1, 5, 10, 20] { - let top_sep: HashSet = - dists_sep[..k].iter().map(|&(i, _)| i).collect(); - let top_bun: HashSet = - dists_bun[..k].iter().map(|&(i, _)| i).collect(); + let top_sep: HashSet = dists_sep[..k].iter().map(|&(i, _)| i).collect(); + let top_bun: HashSet = dists_bun[..k].iter().map(|&(i, _)| i).collect(); let recall = top_sep.intersection(&top_bun).count() as f64 / k as f64; eprintln!(" Query {}: Recall@{} = {:.2}", qi, k, recall); } @@ -546,13 +508,8 @@ mod tests { } let slope = if var_t > 0.0 { cov / var_t } else { 0.0 }; let intercept = mean_b - slope * mean_t; - eprintln!( - " Linear fit: bundle_dist = {:.4} × true_dist + {:.1}", - slope, intercept - ); - eprintln!( - " Expected: bundle_dist = ~0.167 × true_dist + ~3072" - ); + eprintln!(" Linear fit: bundle_dist = {:.4} × true_dist + {:.1}", slope, intercept); + eprintln!(" Expected: bundle_dist = ~0.167 × true_dist + ~3072"); // For 3-component sum: the contraction is on the combined distance // which has range [0, 3d]. The slope should be ~1/6 and intercept ~3d×0.375/3 @@ -562,11 +519,7 @@ mod tests { // Majority vote creates cross-component interference that compresses // the distance range. Still useful as cascade stroke — monotonic ranking. // THRESHOLD: Spearman > 0.70 (adjusted for 3-component mixing) - assert!( - mean_rho > 0.70, - "Mean Spearman ρ = {:.4} < 0.70 threshold", - mean_rho - ); + assert!(mean_rho > 0.70, "Mean Spearman ρ = {:.4} < 0.70 threshold", mean_rho); } // ======================================================================== @@ -635,16 +588,10 @@ mod tests { max_auto = max_auto.max(auto); min_auto = min_auto.min(auto); if i < 5 { - eprintln!( - " Text {} autocorrelation at lag {}: {:.4} (expect ~0.50)", - i, SHIFT_FULL, auto - ); + eprintln!(" Text {} autocorrelation at lag {}: {:.4} (expect ~0.50)", i, SHIFT_FULL, auto); } } - eprintln!( - " Autocorrelation range: [{:.4}, {:.4}]", - min_auto, max_auto - ); + eprintln!(" Autocorrelation range: [{:.4}, {:.4}]", min_auto, max_auto); // Bundle and recover — measure error rate on structured data let n = planes.len(); @@ -663,22 +610,11 @@ mod tests { total_bits += D_FULL as u64; } let error_rate = total_errors as f64 / total_bits as f64; - eprintln!( - " Structured recovery error: {:.4} (expected ~0.25)", - error_rate - ); + eprintln!(" Structured recovery error: {:.4} (expected ~0.25)", error_rate); // Flag if autocorrelation is outside safe range - assert!( - min_auto > 0.40, - "Autocorrelation too low: {:.4} (periodic structure detected)", - min_auto - ); - assert!( - max_auto < 0.60, - "Autocorrelation too high: {:.4}", - max_auto - ); + assert!(min_auto > 0.40, "Autocorrelation too low: {:.4} (periodic structure detected)", min_auto); + assert!(max_auto < 0.60, "Autocorrelation too high: {:.4}", max_auto); assert!( (error_rate - 0.25).abs() < 0.05, "Structured error rate {:.4} deviates from expected 0.25", @@ -743,22 +679,12 @@ mod tests { } let total = nodes.len(); - let integrated: Vec<[u64; 256]> = nodes - .iter() - .map(|(s, p, o)| bundle_16k(s, p, o)) - .collect(); + let integrated: Vec<[u64; 256]> = nodes.iter().map(|(s, p, o)| bundle_16k(s, p, o)).collect(); - eprintln!( - "\n EXPERIMENT 5: Holographic Resonance (n={}, 4 clusters + {} random)", - total, n_random - ); + eprintln!("\n EXPERIMENT 5: Holographic 
Resonance (n={}, 4 clusters + {} random)", total, n_random); // Query from cluster 3 (similar S+P) - for &qi in &[ - n_per_cluster * 2, - n_per_cluster * 2 + 10, - n_per_cluster * 2 + 20, - ] { + for &qi in &[n_per_cluster * 2, n_per_cluster * 2 + 10, n_per_cluster * 2 + 20] { let mut dists_sep: Vec<(usize, u32)> = (0..total) .filter(|&i| i != qi) .map(|i| { @@ -786,10 +712,7 @@ mod tests { .filter(|&&(i, _)| labels[i] == 3) .count(); - eprintln!( - " Query {} (cluster 3): sep purity={}/{} int purity={}/{}", - qi, purity_sep, k, purity_int, k - ); + eprintln!(" Query {} (cluster 3): sep purity={}/{} int purity={}/{}", qi, purity_sep, k, purity_int, k); } // Also test: do integrated and separate produce same top-20? @@ -832,14 +755,8 @@ mod tests { }) .collect(); - let bundles_8k: Vec<[u64; 128]> = nodes - .iter() - .map(|(s, p, o)| bundle_8k(s, p, o)) - .collect(); - let integrated_16k: Vec<[u64; 256]> = nodes - .iter() - .map(|(s, p, o)| bundle_16k(s, p, o)) - .collect(); + let bundles_8k: Vec<[u64; 128]> = nodes.iter().map(|(s, p, o)| bundle_8k(s, p, o)).collect(); + let integrated_16k: Vec<[u64; 256]> = nodes.iter().map(|(s, p, o)| bundle_16k(s, p, o)).collect(); eprintln!("\n EXPERIMENT 6: Cascade Coherence (n={})", n_nodes); @@ -871,8 +788,12 @@ mod tests { eprintln!(" ρ(8K bundle → exact): {:.4}", rho_13); // Cascade simulation: 500 → top 50 by 8K → top 10 by 16K → verify exact - let mut l1_ranked: Vec<(usize, f64)> = - l1.iter().copied().enumerate().map(|(i, d)| (i + 1, d)).collect(); + let mut l1_ranked: Vec<(usize, f64)> = l1 + .iter() + .copied() + .enumerate() + .map(|(i, d)| (i + 1, d)) + .collect(); l1_ranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); let l1_top50: Vec = l1_ranked[..50].iter().map(|&(i, _)| i).collect(); @@ -884,22 +805,19 @@ mod tests { let l2_top10: HashSet = l2_filtered[..10].iter().map(|&(i, _)| i).collect(); // Ground truth top-10 - let mut l3_ranked: Vec<(usize, f64)> = - l3.iter().copied().enumerate().map(|(i, d)| (i + 1, d)).collect(); + let mut l3_ranked: Vec<(usize, f64)> = l3 + .iter() + .copied() + .enumerate() + .map(|(i, d)| (i + 1, d)) + .collect(); l3_ranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); let gt_top10: HashSet = l3_ranked[..10].iter().map(|&(i, _)| i).collect(); let cascade_recall = gt_top10.intersection(&l2_top10).count() as f64 / 10.0; - eprintln!( - "\n Cascade recall@10 (500→50→10): {:.2}", - cascade_recall - ); + eprintln!("\n Cascade recall@10 (500→50→10): {:.2}", cascade_recall); - assert!( - rho_23 > 0.30, - "16K→exact ρ = {:.4} < 0.30 threshold", - rho_23 - ); + assert!(rho_23 > 0.30, "16K→exact ρ = {:.4} < 0.30 threshold", rho_23); } // ======================================================================== @@ -947,10 +865,7 @@ mod tests { let int_tree = ClamTree::build(&integrated_bytes, 2048, 5); let content_tree = ClamTree::build(&content_bytes, 2048, 5); - eprintln!( - "\n EXPERIMENT 7: CLAM Clustering ({} clusters × {} nodes)", - n_clusters, n_per_cluster - ); + eprintln!("\n EXPERIMENT 7: CLAM Clustering ({} clusters × {} nodes)", n_clusters, n_per_cluster); eprintln!(" Integrated tree: {} nodes", int_tree.nodes.len()); eprintln!(" Content tree: {} nodes", content_tree.nodes.len()); @@ -1004,10 +919,7 @@ mod tests { let n_trials = 100; eprintln!("\n EXPERIMENT 8: Bias Resilience"); - eprintln!( - " {:>6} {:>12} {:>12} {:>12}", - "Bias", "Actual Err", "Predicted", "P(correct)" - ); + eprintln!(" {:>6} {:>12} {:>12} {:>12}", "Bias", "Actual Err", "Predicted", "P(correct)"); for bias_pct in [30, 35, 40, 45, 
50, 55, 60, 65, 70, 75, 80] { let bias = bias_pct as f64 / 100.0; @@ -1028,10 +940,7 @@ mod tests { let predicted = bias * (1.0 - bias); // P(error) = p(1-p) let p_correct = 1.0 - predicted; - eprintln!( - " {:.2} {:.4} {:.4} {:.4}", - bias, actual_err, predicted, p_correct - ); + eprintln!(" {:.2} {:.4} {:.4} {:.4}", bias, actual_err, predicted, p_correct); } } @@ -1050,10 +959,7 @@ mod tests { }) .collect(); - let bundles: Vec<[u64; 256]> = nodes - .iter() - .map(|(s, p, o)| bundle_16k(s, p, o)) - .collect(); + let bundles: Vec<[u64; 256]> = nodes.iter().map(|(s, p, o)| bundle_16k(s, p, o)).collect(); eprintln!("\n EXPERIMENT 9: Multi-Hop Query (n={})", n_nodes); @@ -1088,11 +994,7 @@ mod tests { let rho = spearman(&hop_f64, &exact_f64); eprintln!(" Multi-hop Spearman ρ: {:.4}", rho); - assert!( - rho > 0.50, - "Multi-hop ρ = {:.4} too low (error propagation)", - rho - ); + assert!(rho > 0.50, "Multi-hop ρ = {:.4} too low (error propagation)", rho); } // ======================================================================== @@ -1101,11 +1003,7 @@ mod tests { #[test] fn exp10_accumulator_capacity() { - let target: [u64; 256] = bundle_16k( - &random_bits(42), - &random_bits(43), - &random_bits(44), - ); + let target: [u64; 256] = bundle_16k(&random_bits(42), &random_bits(43), &random_bits(44)); let mut acc = vec![0i32; D_FULL]; @@ -1144,13 +1042,7 @@ mod tests { let snr = (D_FULL as f64 / (std::f64::consts::PI * (n + 1) as f64)).sqrt(); if n <= 10 || n % 50 == 0 || error_rate > 0.40 { - eprintln!( - " {:>6} {:>8} {:>7.1}% {:>8.1}", - n, - dist, - error_rate * 100.0, - snr - ); + eprintln!(" {:>6} {:>8} {:>7.1}% {:>8.1}", n, dist, error_rate * 100.0, snr); } if error_rate > 0.45 && capacity_limit == 0 { @@ -1299,8 +1191,13 @@ mod tests { let po_close = ((dp + d_o) < 2 * thresh) as u8; let spo_close = ((ds + dp + d_o) < 3 * thresh) as u8; - s_close | (p_close << 1) | (o_close << 2) | (sp_close << 3) - | (so_close << 4) | (po_close << 5) | (spo_close << 6) + s_close + | (p_close << 1) + | (o_close << 2) + | (sp_close << 3) + | (so_close << 4) + | (po_close << 5) + | (spo_close << 6) } /// ZeckF64: 8 bytes = scent + 7 resolution quantiles. 
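// Editorial sketch (not part of the diff): the ~25% recovery error measured in the
// experiments above follows from one bit-level identity of the 3-way majority bundle.
// Where the two known components disagree, the bundled bit equals the third component
// exactly; where they agree (about half the positions), the third component is lost
// and any guess is wrong half the time, giving the 0.5 × 0.5 = 0.25 expected error.
fn maj3_word(a: u64, b: u64, c: u64) -> u64 {
    (a & b) | (a & c) | (b & c)
}
fn decisive_bits_for_a(b: u64, c: u64) -> u64 {
    // on these positions maj3_word(a, b, c) == a bit-for-bit
    b ^ c
}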
@@ -1314,8 +1211,14 @@ mod tests { let byte6 = (dp as u64 * 255 / d_max as u64).min(255); let byte7 = (ds as u64 * 255 / d_max as u64).min(255); - byte0 | (byte1 << 8) | (byte2 << 16) | (byte3 << 24) - | (byte4 << 32) | (byte5 << 40) | (byte6 << 48) | (byte7 << 56) + byte0 + | (byte1 << 8) + | (byte2 << 16) + | (byte3 << 24) + | (byte4 << 32) + | (byte5 << 40) + | (byte6 << 48) + | (byte7 << 56) } fn zeckf64_l1(a: u64, b: u64) -> u32 { @@ -1353,13 +1256,11 @@ mod tests { let mut bundle_16k_dists = Vec::with_capacity(n_pairs); // Pre-build bundles - let bundles_8k: Vec<[u64; 128]> = nodes.iter() - .map(|(s, p, o)| bundle_8k(s, p, o)).collect(); - let bundles_16k: Vec<[u64; 256]> = nodes.iter() - .map(|(s, p, o)| bundle_16k(s, p, o)).collect(); + let bundles_8k: Vec<[u64; 128]> = nodes.iter().map(|(s, p, o)| bundle_8k(s, p, o)).collect(); + let bundles_16k: Vec<[u64; 256]> = nodes.iter().map(|(s, p, o)| bundle_16k(s, p, o)).collect(); for i in 0..n_nodes { - for j in (i+1)..n_nodes { + for j in (i + 1)..n_nodes { let ds = hamming(&nodes[i].0, &nodes[j].0); let dp = hamming(&nodes[i].1, &nodes[j].1); let d_o = hamming(&nodes[i].2, &nodes[j].2); @@ -1415,12 +1316,21 @@ mod tests { println!(" ─────────────────────────────────────────────────────"); println!(" Method Bits Spearman ρ Verdict"); println!(" ─────────────────────────────────────────────────────"); - println!(" ZeckF64 (8 bytes) 64 {:.4} {}", rho_z64, - if rho_z64 > 0.90 {"GO ✓"} else {"CHECK"}); - println!(" Bundle 16K (maj3) 16,384 {:.4} {}", rho_b16k, - if rho_b16k > 0.80 {"GO ✓"} else {"DEAD ZONE"}); - println!(" Bundle 8K (fold+maj) 8,192 {:.4} {}", rho_b8k, - if rho_b8k > 0.60 {"GO ✓"} else {"DEAD ZONE"}); + println!( + " ZeckF64 (8 bytes) 64 {:.4} {}", + rho_z64, + if rho_z64 > 0.90 { "GO ✓" } else { "CHECK" } + ); + println!( + " Bundle 16K (maj3) 16,384 {:.4} {}", + rho_b16k, + if rho_b16k > 0.80 { "GO ✓" } else { "DEAD ZONE" } + ); + println!( + " Bundle 8K (fold+maj) 8,192 {:.4} {}", + rho_b8k, + if rho_b8k > 0.60 { "GO ✓" } else { "DEAD ZONE" } + ); println!(" Exact S+P+O 49,152 1.0000 reference"); println!(" ─────────────────────────────────────────────────────"); @@ -1498,8 +1408,12 @@ mod tests { let top_zf64: HashSet = zf64_full[..k].iter().map(|&(i, _)| i).collect(); let recall = top_exact.intersection(&top_zf64).count() as f64 / k as f64; - if k == 1 { all_recall_1.push(recall); } - if k == 10 { all_recall_10.push(recall); } + if k == 1 { + all_recall_1.push(recall); + } + if k == 10 { + all_recall_10.push(recall); + } } } @@ -1507,8 +1421,7 @@ mod tests { let mean_r10 = all_recall_10.iter().sum::() / all_recall_10.len() as f64; println!(" ZeckF64 Recall@1: {:.3}", mean_r1); println!(" ZeckF64 Recall@10: {:.3}", mean_r10); - println!(" Recall@1 > 0.80: {}", if mean_r1 > 0.80 {"GO ✓"} else {"CHECK"}); - println!(" Recall@10 > 0.70: {}", if mean_r10 > 0.70 {"GO ✓"} else {"CHECK"}); + println!(" Recall@1 > 0.80: {}", if mean_r1 > 0.80 { "GO ✓" } else { "CHECK" }); + println!(" Recall@10 > 0.70: {}", if mean_r10 > 0.70 { "GO ✓" } else { "CHECK" }); } - } diff --git a/src/hpc/stable_diffusion/api.rs b/src/hpc/stable_diffusion/api.rs index fa890c73..19f029d2 100644 --- a/src/hpc/stable_diffusion/api.rs +++ b/src/hpc/stable_diffusion/api.rs @@ -2,11 +2,11 @@ //! //! 
Endpoint: `/v1/images/generations` -use crate::hpc::models::api_types::*; use super::clip::ClipEncoder; use super::scheduler::{DdimScheduler, SchedulerConfig}; use super::unet; use super::vae; +use crate::hpc::models::api_types::*; /// Stable Diffusion API wrapper. pub struct StableDiffusionApi { @@ -16,7 +16,10 @@ pub struct StableDiffusionApi { impl StableDiffusionApi { pub fn new(clip: ClipEncoder) -> Self { - Self { clip, scheduler_config: SchedulerConfig::default() } + Self { + clip, + scheduler_config: SchedulerConfig::default(), + } } /// `/v1/images/generations` @@ -56,7 +59,10 @@ impl StableDiffusionApi { }); } - ImageResponse { created: 0, data: images } + ImageResponse { + created: 0, + data: images, + } } /// `/v1/models/{id}` @@ -82,8 +88,8 @@ fn base64_placeholder(rgb: &[u8]) -> String { #[cfg(test)] mod tests { + use super::super::clip::{ClipWeights, CLIP_EMBED_DIM, CLIP_MAX_SEQ, CLIP_VOCAB_SIZE}; use super::*; - use super::super::clip::{ClipWeights, CLIP_VOCAB_SIZE, CLIP_EMBED_DIM, CLIP_MAX_SEQ}; fn dummy_clip() -> ClipEncoder { ClipEncoder::new(ClipWeights { diff --git a/src/hpc/stable_diffusion/clip.rs b/src/hpc/stable_diffusion/clip.rs index f238e923..e0f24595 100644 --- a/src/hpc/stable_diffusion/clip.rs +++ b/src/hpc/stable_diffusion/clip.rs @@ -70,8 +70,7 @@ impl ClipEncoder { let hid_off = t * CLIP_EMBED_DIM; for d in 0..CLIP_EMBED_DIM { hidden[hid_off + d] = - self.weights.token_embedding[tok_off + d] - + self.weights.position_embedding[pos_off + d]; + self.weights.token_embedding[tok_off + d] + self.weights.position_embedding[pos_off + d]; } } @@ -94,12 +93,7 @@ impl ClipEncoder { } /// One transformer layer (bidirectional self-attention + MLP). - fn transformer_layer( - &self, - layer: &ClipLayerWeights, - hidden: &mut [f32], - seq_len: usize, - ) { + fn transformer_layer(&self, layer: &ClipLayerWeights, hidden: &mut [f32], seq_len: usize) { // Process each position through attention + MLP // For the scaffold: simplified single-token path. // Full implementation would do batched multi-head attention. 
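// Editorial note (illustrative, not part of the diff): the encoder above keeps hidden
// state in one flat [seq_len * CLIP_EMBED_DIM] buffer, which is why it indexes with
// hid_off = t * CLIP_EMBED_DIM; the vector for position t is simply this slice:
fn token_slice(hidden: &[f32], t: usize, embed_dim: usize) -> &[f32] {
    &hidden[t * embed_dim..(t + 1) * embed_dim]
}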
@@ -111,8 +105,7 @@ impl ClipEncoder { // Self-attention (simplified: each position attends to itself for scaffold) let mut attn_out = vec![0.0f32; CLIP_EMBED_DIM]; layers::matmul_vec( - &normed, &layer.attn_out_weight, &layer.attn_out_bias, - &mut attn_out, CLIP_EMBED_DIM, CLIP_EMBED_DIM, + &normed, &layer.attn_out_weight, &layer.attn_out_bias, &mut attn_out, CLIP_EMBED_DIM, CLIP_EMBED_DIM, ); // Residual @@ -125,11 +118,15 @@ impl ClipEncoder { layers::layer_norm(&mut normed2, &layer.ln2_weight, &layer.ln2_bias); let mut fc_out = vec![0.0f32; CLIP_MLP_DIM]; - layers::matmul_vec(&normed2, &layer.mlp_fc_weight, &layer.mlp_fc_bias, &mut fc_out, CLIP_EMBED_DIM, CLIP_MLP_DIM); + layers::matmul_vec( + &normed2, &layer.mlp_fc_weight, &layer.mlp_fc_bias, &mut fc_out, CLIP_EMBED_DIM, CLIP_MLP_DIM, + ); layers::gelu(&mut fc_out); let mut proj_out = vec![0.0f32; CLIP_EMBED_DIM]; - layers::matmul_vec(&fc_out, &layer.mlp_proj_weight, &layer.mlp_proj_bias, &mut proj_out, CLIP_MLP_DIM, CLIP_EMBED_DIM); + layers::matmul_vec( + &fc_out, &layer.mlp_proj_weight, &layer.mlp_proj_bias, &mut proj_out, CLIP_MLP_DIM, CLIP_EMBED_DIM, + ); // Residual for d in 0..CLIP_EMBED_DIM { diff --git a/src/hpc/stable_diffusion/scheduler.rs b/src/hpc/stable_diffusion/scheduler.rs index b40cce4e..cd0e3971 100644 --- a/src/hpc/stable_diffusion/scheduler.rs +++ b/src/hpc/stable_diffusion/scheduler.rs @@ -44,9 +44,7 @@ impl DdimScheduler { // Linear beta schedule let betas: Vec = (0..n) - .map(|i| { - config.beta_start + (config.beta_end - config.beta_start) * i as f32 / (n - 1) as f32 - }) + .map(|i| config.beta_start + (config.beta_end - config.beta_start) * i as f32 / (n - 1) as f32) .collect(); // Alphas = 1 - beta @@ -67,7 +65,11 @@ impl DdimScheduler { .map(|i| i * step_size) .collect(); - Self { config, alphas_cumprod, timesteps } + Self { + config, + alphas_cumprod, + timesteps, + } } /// Single denoising step: given model noise prediction, compute x_{t-1} from x_t. @@ -98,9 +100,11 @@ impl DdimScheduler { let sqrt_alpha = alpha.sqrt(); let sqrt_one_minus = (1.0 - alpha).sqrt(); - original.iter().zip(noise).map(|(&x, &n)| { - sqrt_alpha * x + sqrt_one_minus * n - }).collect() + original + .iter() + .zip(noise) + .map(|(&x, &n)| sqrt_alpha * x + sqrt_one_minus * n) + .collect() } /// Number of inference steps. diff --git a/src/hpc/stable_diffusion/unet.rs b/src/hpc/stable_diffusion/unet.rs index b955c545..8bbf5a2e 100644 --- a/src/hpc/stable_diffusion/unet.rs +++ b/src/hpc/stable_diffusion/unet.rs @@ -42,13 +42,7 @@ pub fn timestep_embedding(timestep: f32, dim: usize) -> Vec { /// Operates on [channels, height, width] layout. /// Minimal implementation — no dilation, no groups beyond depthwise. pub fn conv2d_3x3( - input: &[f32], - weight: &[f32], - bias: &[f32], - in_channels: usize, - out_channels: usize, - h: usize, - w: usize, + input: &[f32], weight: &[f32], bias: &[f32], in_channels: usize, out_channels: usize, h: usize, w: usize, ) -> Vec { let mut output = vec![0.0f32; out_channels * h * w]; @@ -127,11 +121,7 @@ impl ResBlockWeights { /// /// This is the scaffold — full implementation would chain: /// down_blocks → mid_block → up_blocks with skip connections. 
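// Editorial sketch (assumption, not taken from the diff): the `layers::matmul_vec`
// calls above use the argument order (input, weight, bias, output, in_dim, out_dim)
// as a single-vector dense layer. A reference implementation consistent with those
// call sites, assuming the weight is stored row-major as [out_dim, in_dim], would be:
fn matmul_vec_ref(x: &[f32], w: &[f32], b: &[f32], out: &mut [f32], in_dim: usize, out_dim: usize) {
    for o in 0..out_dim {
        let row = &w[o * in_dim..(o + 1) * in_dim];
        out[o] = b[o] + row.iter().zip(x.iter()).map(|(wv, xv)| wv * xv).sum::<f32>();
    }
}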
-pub fn predict_noise( - noisy_latent: &[f32], - text_embeddings: &[f32], - timestep: f32, -) -> Vec { +pub fn predict_noise(noisy_latent: &[f32], text_embeddings: &[f32], timestep: f32) -> Vec { let _t_emb = timestep_embedding(timestep, MODEL_CHANNELS); // Scaffold: return zero noise prediction (actual UNet weights needed) vec![0.0f32; noisy_latent.len()] diff --git a/src/hpc/stable_diffusion/vae.rs b/src/hpc/stable_diffusion/vae.rs index c9a27ab9..8781fe31 100644 --- a/src/hpc/stable_diffusion/vae.rs +++ b/src/hpc/stable_diffusion/vae.rs @@ -41,8 +41,7 @@ pub fn decode(latent: &[f32], h: usize, w: usize) -> Vec { for ow in 0..out_w { let ih = oh / VAE_SCALE_FACTOR; let iw = ow / VAE_SCALE_FACTOR; - upsampled[c * out_h * out_w + oh * out_w + ow] = - scaled[c * h * w + ih * w + iw]; + upsampled[c * out_h * out_w + oh * out_w + ow] = scaled[c * h * w + ih * w + iw]; } } } @@ -94,16 +93,16 @@ mod tests { let tensor = vec![0.0f32, 0.5, 1.0, 0.0, 0.5, 1.0, 0.0, 0.5, 1.0]; // 3ch, 1×3 let rgb = to_rgb_u8(&tensor, 3, 1, 3); assert_eq!(rgb.len(), 9); - assert_eq!(rgb[0], 0); // R of pixel 0 - assert_eq!(rgb[1], 0); // G of pixel 0 - assert_eq!(rgb[2], 0); // B of pixel 0 + assert_eq!(rgb[0], 0); // R of pixel 0 + assert_eq!(rgb[1], 0); // G of pixel 0 + assert_eq!(rgb[2], 0); // B of pixel 0 } #[test] fn test_to_rgb_u8_clamp() { let tensor = vec![-1.0f32, 2.0, 0.5]; // 3ch, 1×1 let rgb = to_rgb_u8(&tensor, 3, 1, 1); - assert_eq!(rgb[0], 0); // clamped from -1 + assert_eq!(rgb[0], 0); // clamped from -1 assert_eq!(rgb[1], 255); // clamped from 2 assert_eq!(rgb[2], 127); // 0.5 * 255 = 127.5 → 127 } diff --git a/src/hpc/stable_diffusion/weights.rs b/src/hpc/stable_diffusion/weights.rs index 18eaf46a..93db5e08 100644 --- a/src/hpc/stable_diffusion/weights.rs +++ b/src/hpc/stable_diffusion/weights.rs @@ -6,8 +6,8 @@ //! No weights are stored in this crate — they're loaded at runtime //! from user-provided safetensors files (disk space conscious). -use crate::hpc::models::safetensors::{SafeTensorsFile, transpose_matrix}; use super::clip::*; +use crate::hpc::models::safetensors::{transpose_matrix, SafeTensorsFile}; /// Load CLIP text encoder weights from a safetensors file. /// diff --git a/src/hpc/statistics.rs b/src/hpc/statistics.rs index 4c430c16..14d97010 100644 --- a/src/hpc/statistics.rs +++ b/src/hpc/statistics.rs @@ -4,8 +4,8 @@ //! ported from rustynum. use crate::imp_prelude::*; +use core::ops::{Add, Div, Mul, Sub}; use num_traits::{Float, FromPrimitive, Zero}; -use core::ops::{Add, Div, Sub, Mul}; /// Statistical operations on arrays. 
/// @@ -64,8 +64,15 @@ pub trait Statistics { impl Statistics for ArrayBase where - A: Float + FromPrimitive + Zero + Add + Sub - + Mul + Div + PartialOrd + 'static, + A: Float + + FromPrimitive + + Zero + + Add + + Sub + + Mul + + Div + + PartialOrd + + 'static, S: Data, D: Dimension, { @@ -233,7 +240,10 @@ where } fn cosine_similarity(&self, other: &Self) -> A { - let dot: A = self.iter().zip(other.iter()).fold(A::zero(), |acc, (&a, &b)| acc + a * b); + let dot: A = self + .iter() + .zip(other.iter()) + .fold(A::zero(), |acc, (&a, &b)| acc + a * b); let norm_a: A = self.iter().fold(A::zero(), |acc, &v| acc + v * v).sqrt(); let norm_b: A = other.iter().fold(A::zero(), |acc, &v| acc + v * v).sqrt(); if norm_a == A::zero() || norm_b == A::zero() { @@ -249,12 +259,8 @@ where // L0 "norm": count of non-zero elements A::from_usize(self.iter().filter(|&&v| v != A::zero()).count()).unwrap() } - 1 => { - self.iter().fold(A::zero(), |acc, &v| acc + v.abs()) - } - 2 => { - self.iter().fold(A::zero(), |acc, &v| acc + v * v).sqrt() - } + 1 => self.iter().fold(A::zero(), |acc, &v| acc + v.abs()), + 2 => self.iter().fold(A::zero(), |acc, &v| acc + v * v).sqrt(), _ => { let p_f = A::from_u32(p).unwrap(); let inv_p = A::one() / p_f; diff --git a/src/hpc/styles/amp.rs b/src/hpc/styles/amp.rs index 41c37718..f438418f 100644 --- a/src/hpc/styles/amp.rs +++ b/src/hpc/styles/amp.rs @@ -2,19 +2,33 @@ //! Science: Sutton & Barto (2018), Ashby (1956), Kahneman (2011). #[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum GateState { Flow, Hold, Block } +pub enum GateState { + Flow, + Hold, + Block, +} #[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum StyleRecommendation { KeepCurrent, TryNeighbor, RadicalShift } +pub enum StyleRecommendation { + KeepCurrent, + TryNeighbor, + RadicalShift, +} pub fn adaptive_style_select(gate_history: &[GateState]) -> StyleRecommendation { - if gate_history.is_empty() { return StyleRecommendation::KeepCurrent; } + if gate_history.is_empty() { + return StyleRecommendation::KeepCurrent; + } let recent: Vec<&GateState> = gate_history.iter().rev().take(3).collect(); let blocks = recent.iter().filter(|g| ***g == GateState::Block).count(); let flows = recent.iter().filter(|g| ***g == GateState::Flow).count(); - if blocks >= 3 { StyleRecommendation::RadicalShift } - else if blocks >= 1 { StyleRecommendation::TryNeighbor } - else { StyleRecommendation::KeepCurrent } + if blocks >= 3 { + StyleRecommendation::RadicalShift + } else if blocks >= 1 { + StyleRecommendation::TryNeighbor + } else { + StyleRecommendation::KeepCurrent + } } #[cfg(test)] @@ -30,6 +44,9 @@ mod tests { } #[test] fn test_mixed() { - assert_eq!(adaptive_style_select(&[GateState::Flow, GateState::Block, GateState::Flow]), StyleRecommendation::TryNeighbor); + assert_eq!( + adaptive_style_select(&[GateState::Flow, GateState::Block, GateState::Flow]), + StyleRecommendation::TryNeighbor + ); } } diff --git a/src/hpc/styles/are.rs b/src/hpc/styles/are.rs index f09cb0c8..bd01da78 100644 --- a/src/hpc/styles/are.rs +++ b/src/hpc/styles/are.rs @@ -11,17 +11,27 @@ pub enum TransformationType { } pub fn identify_transformation(inputs: &[Base17], outputs: &[Base17]) -> TransformationType { - if inputs.is_empty() || inputs.len() != outputs.len() { return TransformationType::Unknown; } + if inputs.is_empty() || inputs.len() != outputs.len() { + return TransformationType::Unknown; + } // Check if outputs = inputs + constant offset let mut offset = [0i16; 17]; - for d in 0..17 { offset[d] = 
outputs[0].dims[d].wrapping_sub(inputs[0].dims[d]); } - let is_offset = inputs.iter().zip(outputs.iter()).all(|(i, o)| { - (0..17).all(|d| o.dims[d].wrapping_sub(i.dims[d]) == offset[d]) - }); + for d in 0..17 { + offset[d] = outputs[0].dims[d].wrapping_sub(inputs[0].dims[d]); + } + let is_offset = inputs + .iter() + .zip(outputs.iter()) + .all(|(i, o)| (0..17).all(|d| o.dims[d].wrapping_sub(i.dims[d]) == offset[d])); if is_offset { - if offset == [0i16; 17] { TransformationType::Identity } - else { TransformationType::Offset(offset) } - } else { TransformationType::Unknown } + if offset == [0i16; 17] { + TransformationType::Identity + } else { + TransformationType::Offset(offset) + } + } else { + TransformationType::Unknown + } } #[cfg(test)] diff --git a/src/hpc/styles/cdi.rs b/src/hpc/styles/cdi.rs index 0b4cbf85..37947f63 100644 --- a/src/hpc/styles/cdi.rs +++ b/src/hpc/styles/cdi.rs @@ -13,12 +13,17 @@ pub fn induce_dissonance(belief: &Base17, truth: &NarsTruth, corpus: &[Base17]) let similarity = 1.0 - belief.l1(c) as f32 / max_l1; if similarity > 0.3 && similarity < 0.7 { let tension = similarity * (1.0 - similarity); // Maximum at 0.5 - if tension > best_tension { best_tension = tension; best = c.clone(); } + if tension > best_tension { + best_tension = tension; + best = c.clone(); + } } } // Dissonance = midpoint between belief and its tension partner let mut dims = [0i16; 17]; - for d in 0..17 { dims[d] = ((belief.dims[d] as i32 + best.dims[d] as i32) / 2) as i16; } + for d in 0..17 { + dims[d] = ((belief.dims[d] as i32 + best.dims[d] as i32) / 2) as i16; + } let dissonance = Base17 { dims }; let dissonant_truth = NarsTruth::new(0.5, truth.confidence * 0.5); (dissonance, dissonant_truth) diff --git a/src/hpc/styles/cdt.rs b/src/hpc/styles/cdt.rs index 8e0e7145..91dfd5c3 100644 --- a/src/hpc/styles/cdt.rs +++ b/src/hpc/styles/cdt.rs @@ -9,8 +9,11 @@ pub fn oscillate(query: &Base17, corpus: &[Base17], rounds: usize) -> (Base17, V for round in 0..rounds { if round % 2 == 0 { // Diverge: bundle with farthest neighbors (mean of distant items) - let mut farthest: Vec<(u32, usize)> = corpus.iter().enumerate() - .map(|(i, c)| (current.l1(c), i)).collect(); + let mut farthest: Vec<(u32, usize)> = corpus + .iter() + .enumerate() + .map(|(i, c)| (current.l1(c), i)) + .collect(); farthest.sort_by(|a, b| b.0.cmp(&a.0)); let top5: Vec<&Base17> = farthest.iter().take(5).map(|(_, i)| &corpus[*i]).collect(); current = bundle_base17(&top5, ¤t); @@ -21,7 +24,10 @@ pub fn oscillate(query: &Base17, corpus: &[Base17], rounds: usize) -> (Base17, V let mut best = current.clone(); for c in corpus { let d = current.l1(c); - if d < best_dist && d > 0 { best_dist = d; best = c.clone(); } + if d < best_dist && d > 0 { + best_dist = d; + best = c.clone(); + } } current = best; ratios.push(0.0); @@ -33,10 +39,18 @@ pub fn oscillate(query: &Base17, corpus: &[Base17], rounds: usize) -> (Base17, V fn bundle_base17(items: &[&Base17], seed: &Base17) -> Base17 { let n = items.len() as i32 + 1; let mut dims = [0i32; 17]; - for d in 0..17 { dims[d] += seed.dims[d] as i32; } - for item in items { for d in 0..17 { dims[d] += item.dims[d] as i32; } } + for d in 0..17 { + dims[d] += seed.dims[d] as i32; + } + for item in items { + for d in 0..17 { + dims[d] += item.dims[d] as i32; + } + } let mut result = [0i16; 17]; - for d in 0..17 { result[d] = (dims[d] / n) as i16; } + for d in 0..17 { + result[d] = (dims[d] / n) as i16; + } Base17 { dims: result } } @@ -46,7 +60,13 @@ mod tests { #[test] fn test_oscillate() { let 
query = Base17 { dims: [100; 17] }; - let corpus: Vec = (0..20).map(|i| { let mut d = [0i16; 17]; d[0] = (i*50) as i16; Base17 { dims: d } }).collect(); + let corpus: Vec = (0..20) + .map(|i| { + let mut d = [0i16; 17]; + d[0] = (i * 50) as i16; + Base17 { dims: d } + }) + .collect(); let (result, ratios) = oscillate(&query, &corpus, 4); assert_eq!(ratios.len(), 4); assert_eq!(ratios[0], 1.0); // diverge diff --git a/src/hpc/styles/cur.rs b/src/hpc/styles/cur.rs index 902045ee..af34dc0c 100644 --- a/src/hpc/styles/cur.rs +++ b/src/hpc/styles/cur.rs @@ -18,7 +18,15 @@ mod tests { use super::*; #[test] fn test_uncertainty_decreases() { - let dist = ClusterDistribution { mu: 5000.0, sigma: 1000.0, p25: 4325.5, p50: 5000.0, p75: 5674.5, p95: 6644.9, p99: 7326.3 }; + let dist = ClusterDistribution { + mu: 5000.0, + sigma: 1000.0, + p25: 4325.5, + p50: 5000.0, + p75: 5674.5, + p95: 6644.9, + p99: 7326.3, + }; let levels = cascading_uncertainty(&dist); assert!(levels[0].1 > levels[3].1); // coarsest has most uncertainty } diff --git a/src/hpc/styles/cws.rs b/src/hpc/styles/cws.rs index 61ec4f56..2c6099ce 100644 --- a/src/hpc/styles/cws.rs +++ b/src/hpc/styles/cws.rs @@ -8,7 +8,9 @@ pub struct Snapshot { } pub fn snapshot_region(corpus: &[(u16, Base17)]) -> Snapshot { - Snapshot { entries: corpus.to_vec() } + Snapshot { + entries: corpus.to_vec(), + } } pub fn restore_region(snapshot: &Snapshot) -> Vec<(u16, Base17)> { @@ -18,7 +20,9 @@ pub fn restore_region(snapshot: &Snapshot) -> Vec<(u16, Base17)> { pub fn merge_snapshots(a: &Snapshot, b: &Snapshot) -> Snapshot { let mut merged = a.entries.clone(); for (addr, fp) in &b.entries { - if !merged.iter().any(|(a, _)| a == addr) { merged.push((*addr, fp.clone())); } + if !merged.iter().any(|(a, _)| a == addr) { + merged.push((*addr, fp.clone())); + } } Snapshot { entries: merged } } diff --git a/src/hpc/styles/dtmf.rs b/src/hpc/styles/dtmf.rs index e00c2e65..1de3167c 100644 --- a/src/hpc/styles/dtmf.rs +++ b/src/hpc/styles/dtmf.rs @@ -10,14 +10,30 @@ pub struct FrameShift { } pub fn dynamic_reframe(gate_history: &[GateState]) -> FrameShift { - let recent_blocks = gate_history.iter().rev().take(3) - .filter(|g| **g == GateState::Block).count(); + let recent_blocks = gate_history + .iter() + .rev() + .take(3) + .filter(|g| **g == GateState::Block) + .count(); if recent_blocks >= 3 { - FrameShift { occurred: true, rung_jump: 3, style_flip: true } + FrameShift { + occurred: true, + rung_jump: 3, + style_flip: true, + } } else if recent_blocks >= 2 { - FrameShift { occurred: true, rung_jump: 1, style_flip: false } + FrameShift { + occurred: true, + rung_jump: 1, + style_flip: false, + } } else { - FrameShift { occurred: false, rung_jump: 0, style_flip: false } + FrameShift { + occurred: false, + rung_jump: 0, + style_flip: false, + } } } diff --git a/src/hpc/styles/etd.rs b/src/hpc/styles/etd.rs index 82b7339a..228ef465 100644 --- a/src/hpc/styles/etd.rs +++ b/src/hpc/styles/etd.rs @@ -3,7 +3,10 @@ use super::super::bgz17_bridge::Base17; -pub struct Subtask { pub fingerprint: Base17, pub relevance: f32 } +pub struct Subtask { + pub fingerprint: Base17, + pub relevance: f32, +} pub fn emergent_decompose(task: &Base17, corpus: &[Base17], max_subtasks: usize) -> Vec { let max_l1 = (17u32 * 65535) as f32; @@ -12,22 +15,40 @@ pub fn emergent_decompose(task: &Base17, corpus: &[Base17], max_subtasks: usize) let mut min_dists = vec![u32::MAX; corpus.len()]; for _ in 0..max_subtasks.min(corpus.len()) { // Update distances - let anchor = if selected.is_empty() { 
task } else { &corpus[*selected.last().unwrap()] }; + let anchor = if selected.is_empty() { + task + } else { + &corpus[*selected.last().unwrap()] + }; for (i, c) in corpus.iter().enumerate() { let d = anchor.l1(c); - if d < min_dists[i] { min_dists[i] = d; } + if d < min_dists[i] { + min_dists[i] = d; + } } // Pick farthest - let best = min_dists.iter().enumerate() + let best = min_dists + .iter() + .enumerate() .filter(|(i, _)| !selected.contains(i)) .max_by_key(|(_, d)| *d) .map(|(i, _)| i); - if let Some(idx) = best { selected.push(idx); } else { break; } + if let Some(idx) = best { + selected.push(idx); + } else { + break; + } } - selected.iter().map(|&i| { - let relevance = 1.0 - task.l1(&corpus[i]) as f32 / max_l1; - Subtask { fingerprint: corpus[i].clone(), relevance } - }).collect() + selected + .iter() + .map(|&i| { + let relevance = 1.0 - task.l1(&corpus[i]) as f32 / max_l1; + Subtask { + fingerprint: corpus[i].clone(), + relevance, + } + }) + .collect() } #[cfg(test)] @@ -36,7 +57,13 @@ mod tests { #[test] fn test_emergent_decompose() { let task = Base17 { dims: [100; 17] }; - let corpus: Vec = (0..20).map(|i| { let mut d = [0i16; 17]; d[0] = (i*100) as i16; Base17 { dims: d } }).collect(); + let corpus: Vec = (0..20) + .map(|i| { + let mut d = [0i16; 17]; + d[0] = (i * 100) as i16; + Base17 { dims: d } + }) + .collect(); let subtasks = emergent_decompose(&task, &corpus, 5); assert_eq!(subtasks.len(), 5); } diff --git a/src/hpc/styles/hkf.rs b/src/hpc/styles/hkf.rs index a6c3b5ae..e30c910b 100644 --- a/src/hpc/styles/hkf.rs +++ b/src/hpc/styles/hkf.rs @@ -16,7 +16,9 @@ pub fn cross_domain_fuse(domain_a: &Base17, domain_b: &Base17, relation: &Base17 let max_l1 = (17u32 * 65535) as f32; let mut fused_dims = [0i16; 17]; for d in 0..17 { - fused_dims[d] = domain_a.dims[d].wrapping_add(relation.dims[d]).wrapping_add(domain_b.dims[d]); + fused_dims[d] = domain_a.dims[d] + .wrapping_add(relation.dims[d]) + .wrapping_add(domain_b.dims[d]); } let fused = Base17 { dims: fused_dims }; @@ -24,8 +26,12 @@ pub fn cross_domain_fuse(domain_a: &Base17, domain_b: &Base17, relation: &Base17 let mut ra_dims = [0i16; 17]; let mut rb_dims = [0i16; 17]; for d in 0..17 { - ra_dims[d] = fused.dims[d].wrapping_sub(relation.dims[d]).wrapping_sub(domain_b.dims[d]); - rb_dims[d] = fused.dims[d].wrapping_sub(relation.dims[d]).wrapping_sub(domain_a.dims[d]); + ra_dims[d] = fused.dims[d] + .wrapping_sub(relation.dims[d]) + .wrapping_sub(domain_b.dims[d]); + rb_dims[d] = fused.dims[d] + .wrapping_sub(relation.dims[d]) + .wrapping_sub(domain_a.dims[d]); } let ra = Base17 { dims: ra_dims }; let rb = Base17 { dims: rb_dims }; diff --git a/src/hpc/styles/htd.rs b/src/hpc/styles/htd.rs index fa2aef12..ac7cc7a2 100644 --- a/src/hpc/styles/htd.rs +++ b/src/hpc/styles/htd.rs @@ -17,11 +17,7 @@ pub struct DecompositionTree { /// Hierarchical decompose: CLAM-style bipolar split. /// Find medoid, find farthest, partition into two clusters, recurse. /// Science: Ishaq et al. (2019), Dasgupta & Long (2005), Simon (1962). 
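The loop in emergent_decompose is a greedy max-min (farthest-point) selection: each round refreshes every candidate's distance to the most recently chosen anchor (the task on the first round) and then picks the candidate that is farthest from everything selected so far. The core of that procedure, with the Base17 L1 metric abstracted into caller-supplied closures (the closure-based shape is illustrative, not the crate's signature):

fn farthest_points(
    k: usize,
    n: usize,
    dist_to_task: impl Fn(usize) -> u32,
    dist: impl Fn(usize, usize) -> u32,
) -> Vec<usize> {
    let mut selected: Vec<usize> = Vec::new();
    let mut min_dists = vec![u32::MAX; n];
    for _ in 0..k.min(n) {
        for i in 0..n {
            // distance to the latest anchor only, as in etd.rs
            let d = match selected.last() {
                Some(&last) => dist(last, i),
                None => dist_to_task(i),
            };
            min_dists[i] = min_dists[i].min(d);
        }
        let best = (0..n)
            .filter(|i| !selected.contains(i))
            .max_by_key(|&i| min_dists[i]);
        match best {
            Some(idx) => selected.push(idx),
            None => break,
        }
    }
    selected
}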
-pub fn hierarchical_decompose( - _query: &Base17, - corpus: &[Base17], - max_levels: usize, -) -> DecompositionTree { +pub fn hierarchical_decompose(_query: &Base17, corpus: &[Base17], max_levels: usize) -> DecompositionTree { let root = decompose_recursive(corpus, max_levels, 0); let depth = tree_depth(&root); DecompositionTree { root, depth } @@ -31,7 +27,9 @@ fn decompose_recursive(items: &[Base17], max_levels: usize, level: usize) -> Dec if items.is_empty() { return DecompositionNode { centroid: Base17 { dims: [0; 17] }, - radius: 0, count: 0, children: Vec::new(), + radius: 0, + count: 0, + children: Vec::new(), }; } @@ -40,13 +38,21 @@ fn decompose_recursive(items: &[Base17], max_levels: usize, level: usize) -> Dec let radius = items.iter().map(|i| centroid.l1(i)).max().unwrap_or(0); if items.len() <= 2 || level >= max_levels { - return DecompositionNode { centroid, radius, count: items.len(), children: Vec::new() }; + return DecompositionNode { + centroid, + radius, + count: items.len(), + children: Vec::new(), + }; } // Bipolar split: find farthest from centroid, partition - let farthest_idx = items.iter().enumerate() + let farthest_idx = items + .iter() + .enumerate() .max_by_key(|(_, i)| centroid.l1(i)) - .map(|(idx, _)| idx).unwrap_or(0); + .map(|(idx, _)| idx) + .unwrap_or(0); let pole = &items[farthest_idx]; let mut left = Vec::new(); @@ -61,7 +67,12 @@ fn decompose_recursive(items: &[Base17], max_levels: usize, level: usize) -> Dec // Guard against degenerate splits if left.is_empty() || right.is_empty() { - return DecompositionNode { centroid, radius, count: items.len(), children: Vec::new() }; + return DecompositionNode { + centroid, + radius, + count: items.len(), + children: Vec::new(), + }; } let children = vec![ @@ -69,23 +80,35 @@ fn decompose_recursive(items: &[Base17], max_levels: usize, level: usize) -> Dec decompose_recursive(&right, max_levels, level + 1), ]; - DecompositionNode { centroid, radius, count: items.len(), children } + DecompositionNode { + centroid, + radius, + count: items.len(), + children, + } } fn compute_centroid(items: &[Base17]) -> Base17 { let n = items.len() as i32; let mut dims = [0i32; 17]; for item in items { - for d in 0..17 { dims[d] += item.dims[d] as i32; } + for d in 0..17 { + dims[d] += item.dims[d] as i32; + } } let mut result = [0i16; 17]; - for d in 0..17 { result[d] = (dims[d] / n) as i16; } + for d in 0..17 { + result[d] = (dims[d] / n) as i16; + } Base17 { dims: result } } fn tree_depth(node: &DecompositionNode) -> usize { - if node.children.is_empty() { 0 } - else { 1 + node.children.iter().map(tree_depth).max().unwrap_or(0) } + if node.children.is_empty() { + 0 + } else { + 1 + node.children.iter().map(tree_depth).max().unwrap_or(0) + } } #[cfg(test)] @@ -94,11 +117,13 @@ mod tests { #[test] fn test_decompose_basic() { - let corpus: Vec = (0..20).map(|i| { - let mut dims = [0i16; 17]; - dims[0] = (i * 100) as i16; - Base17 { dims } - }).collect(); + let corpus: Vec = (0..20) + .map(|i| { + let mut dims = [0i16; 17]; + dims[0] = (i * 100) as i16; + Base17 { dims } + }) + .collect(); let query = Base17 { dims: [500; 17] }; let tree = hierarchical_decompose(&query, &corpus, 4); diff --git a/src/hpc/styles/icr.rs b/src/hpc/styles/icr.rs index 661d78f4..a745b247 100644 --- a/src/hpc/styles/icr.rs +++ b/src/hpc/styles/icr.rs @@ -11,31 +11,36 @@ pub struct CounterfactualWorld { pub truth: NarsTruth, } -pub fn iterate_counterfactuals( - base: &Base17, - interventions: &[Base17], - corpus: &[Base17], -) -> Vec { +pub fn 
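For the bipolar split in htd.rs, the hunk shows the centroid, the farthest pole, and the degenerate-split guard, but not the assignment rule itself. A common CLAM-style choice, assumed here, is to send each item to whichever of the two poles is nearer; the sketch below makes that assumption explicit and is not lifted from the module.

fn partition_bipolar<'a>(
    items: &'a [[i16; 17]],
    centroid: &[i16; 17],
    pole: &[i16; 17],
    l1: impl Fn(&[i16; 17], &[i16; 17]) -> u32,
) -> (Vec<&'a [i16; 17]>, Vec<&'a [i16; 17]>) {
    let mut left = Vec::new();
    let mut right = Vec::new();
    for item in items {
        // tie-breaking toward the centroid side is an assumption
        if l1(item, centroid) <= l1(item, pole) {
            left.push(item);
        } else {
            right.push(item);
        }
    }
    (left, right)
}

Note that compute_centroid divides by items.len(); it is only safe because decompose_recursive returns early on an empty slice before calling it.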
iterate_counterfactuals(base: &Base17, interventions: &[Base17], corpus: &[Base17]) -> Vec { let max_l1 = (17u32 * 65535) as f32; - interventions.iter().enumerate().map(|(idx, intervention)| { - let mut modified_dims = [0i16; 17]; - for d in 0..17 { modified_dims[d] = base.dims[d].wrapping_add(intervention.dims[d]); } - let modified = Base17 { dims: modified_dims }; - let mut best_dist = u32::MAX; - let mut best = modified.clone(); - for c in corpus { - let d = modified.l1(c); - if d < best_dist { best_dist = d; best = c.clone(); } - } - let divergence = base.l1(&best) as f32 / max_l1; - let confidence = if best_dist < (max_l1 as u32) / 2 { 0.8 } else { 0.3 }; - CounterfactualWorld { - intervention_idx: idx, - resulting: best, - divergence, - truth: NarsTruth::new(1.0 - divergence, confidence), - } - }).collect() + interventions + .iter() + .enumerate() + .map(|(idx, intervention)| { + let mut modified_dims = [0i16; 17]; + for d in 0..17 { + modified_dims[d] = base.dims[d].wrapping_add(intervention.dims[d]); + } + let modified = Base17 { dims: modified_dims }; + let mut best_dist = u32::MAX; + let mut best = modified.clone(); + for c in corpus { + let d = modified.l1(c); + if d < best_dist { + best_dist = d; + best = c.clone(); + } + } + let divergence = base.l1(&best) as f32 / max_l1; + let confidence = if best_dist < (max_l1 as u32) / 2 { 0.8 } else { 0.3 }; + CounterfactualWorld { + intervention_idx: idx, + resulting: best, + divergence, + truth: NarsTruth::new(1.0 - divergence, confidence), + } + }) + .collect() } #[cfg(test)] diff --git a/src/hpc/styles/idr.rs b/src/hpc/styles/idr.rs index 6bc2b79a..35c4b423 100644 --- a/src/hpc/styles/idr.rs +++ b/src/hpc/styles/idr.rs @@ -4,20 +4,42 @@ use super::super::bgz17_bridge::Base17; #[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Intent { Analytical, Creative, Reflective, Focused, Default } +pub enum Intent { + Analytical, + Creative, + Reflective, + Focused, + Default, +} pub fn detect_intent(query: &Base17) -> Intent { // Use Base17 dimension distribution as proxy for intent let mean = query.dims.iter().map(|d| *d as i32).sum::() / 17; - let variance = query.dims.iter().map(|d| (*d as i32 - mean).pow(2)).sum::() / 17; - let activation = query.dims.iter().map(|d| d.unsigned_abs() as u32).sum::() / 17; + let variance = query + .dims + .iter() + .map(|d| (*d as i32 - mean).pow(2)) + .sum::() + / 17; + let activation = query + .dims + .iter() + .map(|d| d.unsigned_abs() as u32) + .sum::() + / 17; let direction = query.dims[0]; // sign of dim0 = dominant direction - if variance > 5000 && activation > 100 { Intent::Creative } - else if direction < 0 { Intent::Reflective } - else if activation < 100 { Intent::Focused } - else if variance < 1000 { Intent::Analytical } - else { Intent::Default } + if variance > 5000 && activation > 100 { + Intent::Creative + } else if direction < 0 { + Intent::Reflective + } else if activation < 100 { + Intent::Focused + } else if variance < 1000 { + Intent::Analytical + } else { + Intent::Default + } } #[cfg(test)] @@ -26,7 +48,9 @@ mod tests { #[test] fn test_high_variance_creative() { let mut dims = [0i16; 17]; - dims[0] = 1000; dims[1] = -1000; dims[2] = 500; + dims[0] = 1000; + dims[1] = -1000; + dims[2] = 500; assert_eq!(detect_intent(&Base17 { dims }), Intent::Creative); } #[test] diff --git a/src/hpc/styles/irs.rs b/src/hpc/styles/irs.rs index 1c5a9243..25e17b4c 100644 --- a/src/hpc/styles/irs.rs +++ b/src/hpc/styles/irs.rs @@ -12,11 +12,7 @@ pub struct PerspectiveResult { /// Perspective sweep: 
each role modulates the query via XOR-analog (dim-wise add), /// then the nearest in corpus is found. Novelty = L1 from accumulated perspectives. /// Science: Kanerva (2009) XOR binding, De Bono (1985), Galton (1907). -pub fn perspective_sweep( - query: &Base17, - roles: &[Base17], - corpus: &[Base17], -) -> Vec { +pub fn perspective_sweep(query: &Base17, roles: &[Base17], corpus: &[Base17]) -> Vec { let max_l1 = (17u32 * 65535) as f32; let mut results = Vec::new(); let mut seen = query.clone(); @@ -41,7 +37,11 @@ pub fn perspective_sweep( // Novelty: how different from accumulated perspectives let novelty = best.l1(&seen) as f32 / max_l1; - results.push(PerspectiveResult { role_idx: idx, result: best.clone(), novelty }); + results.push(PerspectiveResult { + role_idx: idx, + result: best.clone(), + novelty, + }); // Accumulate: running mean for d in 0..17 { @@ -60,16 +60,14 @@ mod tests { #[test] fn test_perspective_sweep() { let query = Base17 { dims: [100; 17] }; - let roles = vec![ - Base17 { dims: [10; 17] }, - Base17 { dims: [-50; 17] }, - Base17 { dims: [200; 17] }, - ]; - let corpus: Vec = (0..20).map(|i| { - let mut dims = [0i16; 17]; - dims[0] = (i * 50) as i16; - Base17 { dims } - }).collect(); + let roles = vec![Base17 { dims: [10; 17] }, Base17 { dims: [-50; 17] }, Base17 { dims: [200; 17] }]; + let corpus: Vec = (0..20) + .map(|i| { + let mut dims = [0i16; 17]; + dims[0] = (i * 50) as i16; + Base17 { dims } + }) + .collect(); let results = perspective_sweep(&query, &roles, &corpus); assert_eq!(results.len(), 3); @@ -82,12 +80,9 @@ mod tests { let query = Base17 { dims: [0; 17] }; let roles = vec![ Base17 { dims: [0; 17] }, // same as query -> low novelty - Base17 { dims: [10000; 17] }, // very different -> high novelty - ]; - let corpus = vec![ - Base17 { dims: [0; 17] }, - Base17 { dims: [10000; 17] }, + Base17 { dims: [10000; 17] }, // very different -> high novelty ]; + let corpus = vec![Base17 { dims: [0; 17] }, Base17 { dims: [10000; 17] }]; let results = perspective_sweep(&query, &roles, &corpus); assert!(results[0].novelty > results[1].novelty); } diff --git a/src/hpc/styles/lsi.rs b/src/hpc/styles/lsi.rs index bfb484ce..42cae744 100644 --- a/src/hpc/styles/lsi.rs +++ b/src/hpc/styles/lsi.rs @@ -4,24 +4,57 @@ pub struct ClusterDistribution { pub mu: f32, pub sigma: f32, - pub p25: f32, pub p50: f32, pub p75: f32, pub p95: f32, pub p99: f32, + pub p25: f32, + pub p50: f32, + pub p75: f32, + pub p95: f32, + pub p99: f32, } impl ClusterDistribution { pub fn from_distances(distances: &[u32]) -> Self { - if distances.is_empty() { return Self { mu: 0.0, sigma: 0.0, p25: 0.0, p50: 0.0, p75: 0.0, p95: 0.0, p99: 0.0 }; } + if distances.is_empty() { + return Self { + mu: 0.0, + sigma: 0.0, + p25: 0.0, + p50: 0.0, + p75: 0.0, + p95: 0.0, + p99: 0.0, + }; + } let n = distances.len() as f32; let mu = distances.iter().sum::() as f32 / n; - let sigma = (distances.iter().map(|d| (*d as f32 - mu).powi(2)).sum::() / n).sqrt(); - Self { mu, sigma, p25: mu - 0.6745 * sigma, p50: mu, p75: mu + 0.6745 * sigma, p95: mu + 1.6449 * sigma, p99: mu + 2.3263 * sigma } + let sigma = (distances + .iter() + .map(|d| (*d as f32 - mu).powi(2)) + .sum::() + / n) + .sqrt(); + Self { + mu, + sigma, + p25: mu - 0.6745 * sigma, + p50: mu, + p75: mu + 0.6745 * sigma, + p95: mu + 1.6449 * sigma, + p99: mu + 2.3263 * sigma, + } } pub fn mexican_hat(&self, distance: f32) -> f32 { - if distance < self.p25 { 1.0 } - else if distance < self.p75 { 0.5 } - else if distance < self.p95 { 0.0 } - else if distance < 
self.p99 { -0.5 } - else { -1.0 } + if distance < self.p25 { + 1.0 + } else if distance < self.p75 { + 0.5 + } else if distance < self.p95 { + 0.0 + } else if distance < self.p99 { + -0.5 + } else { + -1.0 + } } } diff --git a/src/hpc/styles/mcp.rs b/src/hpc/styles/mcp.rs index 6431abdf..12cc9fa6 100644 --- a/src/hpc/styles/mcp.rs +++ b/src/hpc/styles/mcp.rs @@ -19,7 +19,11 @@ pub struct MetaCognition { impl MetaCognition { pub fn new(max_history: usize) -> Self { - Self { history: Vec::new(), max_history, calibration_error: 0.5 } + Self { + history: Vec::new(), + max_history, + calibration_error: 0.5, + } } /// Assess meta-confidence: how reliable is our confidence? @@ -31,9 +35,7 @@ impl MetaCognition { } let mean = self.history.iter().sum::() / self.history.len() as f32; - let variance = self.history.iter() - .map(|c| (c - mean).powi(2)) - .sum::() / self.history.len() as f32; + let variance = self.history.iter().map(|c| (c - mean).powi(2)).sum::() / self.history.len() as f32; let meta_confidence = 1.0 - variance.sqrt(); diff --git a/src/hpc/styles/mct.rs b/src/hpc/styles/mct.rs index 384a2006..3d15893a 100644 --- a/src/hpc/styles/mct.rs +++ b/src/hpc/styles/mct.rs @@ -4,12 +4,19 @@ use super::super::bgz17_bridge::Base17; #[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Modality { Text, Image, Audio, Code } +pub enum Modality { + Text, + Image, + Audio, + Code, +} pub fn cross_modal_bind(text: &Base17, image: &Base17, relation: &Base17) -> Base17 { let mut dims = [0i16; 17]; for d in 0..17 { - dims[d] = text.dims[d].wrapping_add(relation.dims[d]).wrapping_add(image.dims[d]); + dims[d] = text.dims[d] + .wrapping_add(relation.dims[d]) + .wrapping_add(image.dims[d]); } Base17 { dims } } @@ -25,7 +32,11 @@ pub fn fusion_quality(fused: &Base17, parent_a: &Base17, parent_b: &Base17, rela fn recover(fused: &Base17, other: &Base17, relation: &Base17) -> Base17 { let mut dims = [0i16; 17]; - for d in 0..17 { dims[d] = fused.dims[d].wrapping_sub(relation.dims[d]).wrapping_sub(other.dims[d]); } + for d in 0..17 { + dims[d] = fused.dims[d] + .wrapping_sub(relation.dims[d]) + .wrapping_sub(other.dims[d]); + } Base17 { dims } } diff --git a/src/hpc/styles/mpc.rs b/src/hpc/styles/mpc.rs index 40f9a685..c29dcec2 100644 --- a/src/hpc/styles/mpc.rs +++ b/src/hpc/styles/mpc.rs @@ -4,16 +4,22 @@ use super::super::bgz17_bridge::Base17; pub fn weighted_bundle(items: &[(&Base17, f32)]) -> Base17 { - if items.is_empty() { return Base17 { dims: [0; 17] }; } + if items.is_empty() { + return Base17 { dims: [0; 17] }; + } let mut dims = [0f64; 17]; let mut total_weight = 0f64; for (fp, weight) in items { - for d in 0..17 { dims[d] += fp.dims[d] as f64 * *weight as f64; } + for d in 0..17 { + dims[d] += fp.dims[d] as f64 * *weight as f64; + } total_weight += *weight as f64; } let mut result = [0i16; 17]; if total_weight > 0.0 { - for d in 0..17 { result[d] = (dims[d] / total_weight).round() as i16; } + for d in 0..17 { + result[d] = (dims[d] / total_weight).round() as i16; + } } Base17 { dims: result } } diff --git a/src/hpc/styles/pso.rs b/src/hpc/styles/pso.rs index 739d102c..d223eb92 100644 --- a/src/hpc/styles/pso.rs +++ b/src/hpc/styles/pso.rs @@ -13,7 +13,15 @@ pub struct FieldModulation { impl Default for FieldModulation { fn default() -> Self { - Self { resonance_threshold: 0.7, fan_out: 6, depth_bias: 0.5, breadth_bias: 0.5, noise_tolerance: 0.3, speed_bias: 0.5, exploration: 0.3 } + Self { + resonance_threshold: 0.7, + fan_out: 6, + depth_bias: 0.5, + breadth_bias: 0.5, + noise_tolerance: 0.3, + 
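ClusterDistribution::from_distances does not sort the samples and read empirical percentiles; it treats the distances as roughly Gaussian and places p25/p75/p95/p99 at the standard normal z-scores 0.6745, 1.6449 and 2.3263 around the mean. Spelled out (the quantile helper name is illustrative; the constants and the mexican-hat steps come from lsi.rs and its tests):

fn quantile(mu: f32, sigma: f32, z: f32) -> f32 {
    mu + z * sigma
}

fn mexican_hat(p25: f32, p75: f32, p95: f32, p99: f32, distance: f32) -> f32 {
    // +1.0 inside the core, tapering to -1.0 in the far tail, as in lsi.rs
    if distance < p25 { 1.0 }
    else if distance < p75 { 0.5 }
    else if distance < p95 { 0.0 }
    else if distance < p99 { -0.5 }
    else { -1.0 }
}

// With mu = 5000 and sigma = 1000 (the values used in the tests):
// p75 = quantile(5000.0, 1000.0, 0.6745) = 5674.5 and p99 = 7326.3,
// matching the literals in cur.rs and sdd.rs.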
speed_bias: 0.5, + exploration: 0.3, + } } } diff --git a/src/hpc/styles/rte.rs b/src/hpc/styles/rte.rs index 7a3277f3..69b9c4fe 100644 --- a/src/hpc/styles/rte.rs +++ b/src/hpc/styles/rte.rs @@ -20,7 +20,10 @@ pub struct ExpansionTrace { impl RecursiveExpansion { pub fn new(max_depth: u8, convergence_threshold: f32) -> Self { - Self { max_depth, convergence_threshold } + Self { + max_depth, + convergence_threshold, + } } /// Apply recursive expansion: output of depth N becomes input to depth N+1. @@ -42,11 +45,19 @@ impl RecursiveExpansion { } let max_l1 = (17 * 65535) as f32; let delta = best_dist as f32 / max_l1; - steps.push(ExpansionStep { depth, delta, fingerprint: best.clone() }); - if delta < self.convergence_threshold { break; } + steps.push(ExpansionStep { + depth, + delta, + fingerprint: best.clone(), + }); + if delta < self.convergence_threshold { + break; + } current = best; } - let converged = steps.last().map_or(false, |s| s.delta < self.convergence_threshold); + let converged = steps + .last() + .map_or(false, |s| s.delta < self.convergence_threshold); ExpansionTrace { steps, converged } } } @@ -58,12 +69,16 @@ mod tests { #[test] fn test_recursive_expansion_converges() { let seed = Base17 { dims: [100; 17] }; - let corpus: Vec = (0..10).map(|i| { - let mut dims = [0i16; 17]; - dims[0] = 100 - (i * 5) as i16; - for d in 1..17 { dims[d] = 100 - (i * 3) as i16; } - Base17 { dims } - }).collect(); + let corpus: Vec = (0..10) + .map(|i| { + let mut dims = [0i16; 17]; + dims[0] = 100 - (i * 5) as i16; + for d in 1..17 { + dims[d] = 100 - (i * 3) as i16; + } + Base17 { dims } + }) + .collect(); let re = RecursiveExpansion::new(7, 0.001); let trace = re.expand(&seed, &corpus); diff --git a/src/hpc/styles/sdd.rs b/src/hpc/styles/sdd.rs index 8d9aef18..adf12591 100644 --- a/src/hpc/styles/sdd.rs +++ b/src/hpc/styles/sdd.rs @@ -16,7 +16,11 @@ pub fn detect_distortion(original: &Base17, transformed: &Base17, dist: &Cluster DistortionReport { information_loss: (raw - noise_floor).max(0.0) / (17.0 * 65535.0), structural_drift: if dist.mu > 0.0 { raw / dist.mu } else { 0.0 }, - z_score: if dist.sigma > 0.0 { (raw - dist.p50) / dist.sigma } else { 0.0 }, + z_score: if dist.sigma > 0.0 { + (raw - dist.p50) / dist.sigma + } else { + 0.0 + }, } } @@ -26,7 +30,15 @@ mod tests { #[test] fn test_no_distortion() { let a = Base17 { dims: [100; 17] }; - let dist = ClusterDistribution { mu: 5000.0, sigma: 1000.0, p25: 4325.0, p50: 5000.0, p75: 5675.0, p95: 6645.0, p99: 7326.0 }; + let dist = ClusterDistribution { + mu: 5000.0, + sigma: 1000.0, + p25: 4325.0, + p50: 5000.0, + p75: 5675.0, + p95: 6645.0, + p99: 7326.0, + }; let report = detect_distortion(&a, &a, &dist); assert_eq!(report.information_loss, 0.0); assert!(report.z_score < 0.0); // below median @@ -35,7 +47,15 @@ mod tests { fn test_high_distortion() { let a = Base17 { dims: [0; 17] }; let b = Base17 { dims: [30000; 17] }; - let dist = ClusterDistribution { mu: 5000.0, sigma: 1000.0, p25: 4325.0, p50: 5000.0, p75: 5675.0, p95: 6645.0, p99: 7326.0 }; + let dist = ClusterDistribution { + mu: 5000.0, + sigma: 1000.0, + p25: 4325.0, + p50: 5000.0, + p75: 5675.0, + p95: 6645.0, + p99: 7326.0, + }; let report = detect_distortion(&a, &b, &dist); assert!(report.z_score > 2.0); // way above p99 assert!(report.structural_drift > 1.0); diff --git a/src/hpc/styles/smad.rs b/src/hpc/styles/smad.rs index 8830136d..3e9e0e7b 100644 --- a/src/hpc/styles/smad.rs +++ b/src/hpc/styles/smad.rs @@ -1,7 +1,7 @@ //! 
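The three numbers that detect_distortion reports reduce to simple ratios of the raw L1 distance; assuming `raw` is original.l1(transformed) as f32 and `noise_floor` a small constant (both are computed just above the visible hunk), the arithmetic is:

fn distortion(raw: f32, noise_floor: f32, mu: f32, sigma: f32, p50: f32) -> (f32, f32, f32) {
    let max_l1 = 17.0 * 65535.0;
    let information_loss = (raw - noise_floor).max(0.0) / max_l1;
    let structural_drift = if mu > 0.0 { raw / mu } else { 0.0 };
    let z_score = if sigma > 0.0 { (raw - p50) / sigma } else { 0.0 };
    (information_loss, structural_drift, z_score)
}

// Worked check against test_no_distortion: identical fingerprints give raw = 0,
// so information_loss = 0 and z_score = (0 - 5000) / 1000 = -5.0, i.e. below the median.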
#3 Structured Multi-Agent Debate — bundle + NARS revision on Base17 fingerprints. use super::super::bgz17_bridge::Base17; -use super::super::nars::{NarsTruth, nars_revision}; +use super::super::nars::{nars_revision, NarsTruth}; /// One proposition in a debate: a fingerprint + truth value. pub struct Proposition { @@ -20,11 +20,7 @@ pub struct DebateResult { /// Run structured debate: each "agent" is a Base17 perspective. /// Perspectives are bundled (majority vote per dim), truth values revised. /// Science: Wang (2006), Du et al. (2023), Kanerva (2009). -pub fn debate( - input: &Base17, - perspectives: &[Base17], - rounds: u8, -) -> DebateResult { +pub fn debate(input: &Base17, perspectives: &[Base17], rounds: u8) -> DebateResult { let mut propositions = Vec::new(); for perspective in perspectives { @@ -34,14 +30,22 @@ pub fn debate( let max_l1 = (17u32 * 65535) as f32; let resonance = 1.0 - (dist as f32 / max_l1); let truth = NarsTruth::from_evidence( - resonance * 10.0, // positive evidence proportional to resonance + resonance * 10.0, // positive evidence proportional to resonance (1.0 - resonance) * 10.0, // negative evidence proportional to distance ); - propositions.push(Proposition { fingerprint: perspective.clone(), truth }); + propositions.push(Proposition { + fingerprint: perspective.clone(), + truth, + }); } // Bundle: majority vote per dimension (mean of i16 values) - let consensus = bundle_base17(&propositions.iter().map(|p| &p.fingerprint).collect::>()); + let consensus = bundle_base17( + &propositions + .iter() + .map(|p| &p.fingerprint) + .collect::>(), + ); // NARS revision across all truth values let mut consensus_truth = NarsTruth::new(0.5, 0.0); @@ -49,19 +53,30 @@ pub fn debate( consensus_truth = nars_revision(consensus_truth, prop.truth); } - DebateResult { consensus, truth: consensus_truth, rounds, propositions } + DebateResult { + consensus, + truth: consensus_truth, + rounds, + propositions, + } } /// Bundle Base17 fingerprints: mean per dimension (majority vote analog). 
fn bundle_base17(fps: &[&Base17]) -> Base17 { - if fps.is_empty() { return Base17 { dims: [0; 17] }; } + if fps.is_empty() { + return Base17 { dims: [0; 17] }; + } let n = fps.len() as i32; let mut sums = [0i32; 17]; for fp in fps { - for d in 0..17 { sums[d] += fp.dims[d] as i32; } + for d in 0..17 { + sums[d] += fp.dims[d] as i32; + } } let mut dims = [0i16; 17]; - for d in 0..17 { dims[d] = (sums[d] / n) as i16; } + for d in 0..17 { + dims[d] = (sums[d] / n) as i16; + } Base17 { dims } } @@ -72,11 +87,7 @@ mod tests { #[test] fn test_debate_consensus() { let input = Base17 { dims: [100; 17] }; - let perspectives = vec![ - Base17 { dims: [90; 17] }, - Base17 { dims: [110; 17] }, - Base17 { dims: [100; 17] }, - ]; + let perspectives = vec![Base17 { dims: [90; 17] }, Base17 { dims: [110; 17] }, Base17 { dims: [100; 17] }]; let result = debate(&input, &perspectives, 1); assert_eq!(result.propositions.len(), 3); assert!(result.truth.confidence > 0.0); @@ -87,11 +98,13 @@ mod tests { #[test] fn test_debate_truth_accumulates() { let input = Base17 { dims: [50; 17] }; - let perspectives: Vec = (0..5).map(|i| { - let mut dims = [50i16; 17]; - dims[0] += (i * 10) as i16; - Base17 { dims } - }).collect(); + let perspectives: Vec = (0..5) + .map(|i| { + let mut dims = [50i16; 17]; + dims[0] += (i * 10) as i16; + Base17 { dims } + }) + .collect(); let result = debate(&input, &perspectives, 1); // More perspectives → higher confidence assert!(result.truth.confidence > 0.5); diff --git a/src/hpc/styles/spp.rs b/src/hpc/styles/spp.rs index 1c8b4188..a21ce618 100644 --- a/src/hpc/styles/spp.rs +++ b/src/hpc/styles/spp.rs @@ -9,7 +9,9 @@ pub struct ShadowResult { pub fn precompute_shadows(current: &Base17, corpus: &[Base17], depth: usize, top_k: usize) -> Vec { // Level 1: neighbors of current - let mut neighbors: Vec<(usize, u32)> = corpus.iter().enumerate() + let mut neighbors: Vec<(usize, u32)> = corpus + .iter() + .enumerate() .map(|(i, c)| (i, current.l1(c))) .collect(); neighbors.sort_by_key(|&(_, d)| d); @@ -19,7 +21,9 @@ pub fn precompute_shadows(current: &Base17, corpus: &[Base17], depth: usize, top // Level 2: for each neighbor, find ITS neighbors if depth > 1 { for &(idx, _) in &neighbors { - let mut sub: Vec<(usize, u32)> = corpus.iter().enumerate() + let mut sub: Vec<(usize, u32)> = corpus + .iter() + .enumerate() .map(|(i, c)| (i, corpus[idx].l1(c))) .collect(); sub.sort_by_key(|&(_, d)| d); @@ -39,7 +43,13 @@ mod tests { #[test] fn test_shadow_precompute() { let current = Base17 { dims: [100; 17] }; - let corpus: Vec = (0..20).map(|i| { let mut d = [0i16; 17]; d[0] = (i*50) as i16; Base17 { dims: d } }).collect(); + let corpus: Vec = (0..20) + .map(|i| { + let mut d = [0i16; 17]; + d[0] = (i * 50) as i16; + Base17 { dims: d } + }) + .collect(); let shadows = precompute_shadows(¤t, &corpus, 2, 5); assert!(!shadows.is_empty()); assert!(shadows[0].predictions.len() <= 5); diff --git a/src/hpc/styles/ssam.rs b/src/hpc/styles/ssam.rs index 8a679963..2b00e562 100644 --- a/src/hpc/styles/ssam.rs +++ b/src/hpc/styles/ssam.rs @@ -11,19 +11,36 @@ pub struct AnalogyResult { pub fn structural_analogy(relation: &Base17, domain: &[Base17], corpus: &[Base17]) -> Vec { let max_l1 = (17u32 * 65535) as f32; - domain.iter().enumerate().filter_map(|(idx, c)| { - let mut predicted_dims = [0i16; 17]; - for d in 0..17 { predicted_dims[d] = c.dims[d].wrapping_add(relation.dims[d]); } - let predicted = Base17 { dims: predicted_dims }; - let mut best_dist = u32::MAX; - let mut best = predicted.clone(); - for target 
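debate() builds each proposition's truth from evidence counts and then folds them with nars_revision. The crate's nars module is not part of this diff, so the sketch below uses the textbook NAL formulas — frequency w⁺/(w⁺+w⁻), confidence w/(w+k) with k = 1, revision by pooling evidence — purely to show why more agreeing perspectives raise confidence; the actual constants and guards in src/hpc/nars.rs may differ.

#[derive(Clone, Copy, Debug)]
struct Truth {
    frequency: f32,
    confidence: f32,
}

const K: f32 = 1.0; // evidential horizon; assumed, not read from the crate

fn from_evidence(w_plus: f32, w_minus: f32) -> Truth {
    let w = w_plus + w_minus;
    Truth {
        frequency: if w > 0.0 { w_plus / w } else { 0.5 },
        confidence: w / (w + K),
    }
}

fn revision(a: Truth, b: Truth) -> Truth {
    // convert each truth back to an evidence weight, pool, convert forward again;
    // assumes confidence < 1.0 for both inputs
    let wa = K * a.confidence / (1.0 - a.confidence);
    let wb = K * b.confidence / (1.0 - b.confidence);
    from_evidence(
        a.frequency * wa + b.frequency * wb,
        (1.0 - a.frequency) * wa + (1.0 - b.frequency) * wb,
    )
}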
in corpus { - let d = predicted.l1(target); - if d < best_dist { best_dist = d; best = target.clone(); } - } - let strength = 1.0 - best_dist as f32 / max_l1; - if strength > 0.6 { Some(AnalogyResult { source_idx: idx, predicted: best, strength }) } else { None } - }).collect() + domain + .iter() + .enumerate() + .filter_map(|(idx, c)| { + let mut predicted_dims = [0i16; 17]; + for d in 0..17 { + predicted_dims[d] = c.dims[d].wrapping_add(relation.dims[d]); + } + let predicted = Base17 { dims: predicted_dims }; + let mut best_dist = u32::MAX; + let mut best = predicted.clone(); + for target in corpus { + let d = predicted.l1(target); + if d < best_dist { + best_dist = d; + best = target.clone(); + } + } + let strength = 1.0 - best_dist as f32 / max_l1; + if strength > 0.6 { + Some(AnalogyResult { + source_idx: idx, + predicted: best, + strength, + }) + } else { + None + } + }) + .collect() } #[cfg(test)] diff --git a/src/hpc/styles/ssr.rs b/src/hpc/styles/ssr.rs index 6c81b7c7..721a806b 100644 --- a/src/hpc/styles/ssr.rs +++ b/src/hpc/styles/ssr.rs @@ -9,11 +9,19 @@ pub struct SkepticismSchedule { } impl SkepticismSchedule { - pub fn new(base: f32) -> Self { Self { consecutive_confident: 0, base_skepticism: base } } + pub fn new(base: f32) -> Self { + Self { + consecutive_confident: 0, + base_skepticism: base, + } + } pub fn update(&mut self, truth: &NarsTruth) -> f32 { - if truth.confidence > 0.8 { self.consecutive_confident += 1; } - else { self.consecutive_confident = 0; } + if truth.confidence > 0.8 { + self.consecutive_confident += 1; + } else { + self.consecutive_confident = 0; + } self.base_skepticism + (self.consecutive_confident as f32 + 1.0).ln() * 0.1 } } diff --git a/src/hpc/styles/tca.rs b/src/hpc/styles/tca.rs index 4c6d8a93..fb91d1fd 100644 --- a/src/hpc/styles/tca.rs +++ b/src/hpc/styles/tca.rs @@ -19,11 +19,7 @@ pub struct TemporalFingerprint { /// Augment a Base17 fingerprint with temporal context. /// Recency decays with time distance from reference. /// Science: Reichenbach (1947), Kamp & Reyle (1993 Ch.5), Vendler (1957). 
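The schedule in ssr.rs raises skepticism logarithmically in the length of the current high-confidence streak; pulled out on its own, with a few values of the ln term worked in the comment:

fn skepticism(base: f32, consecutive_confident: u32) -> f32 {
    base + ((consecutive_confident as f32) + 1.0).ln() * 0.1
}

// consecutive = 0 -> base + 0.000
// consecutive = 1 -> base + ~0.069
// consecutive = 4 -> base + ~0.161
// consecutive = 9 -> base + ~0.230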
-pub fn temporalize( - base: &Base17, - event_time: u64, - reference_time: u64, -) -> TemporalFingerprint { +pub fn temporalize(base: &Base17, event_time: u64, reference_time: u64) -> TemporalFingerprint { let speech_time = reference_time; // default: now = reference let time_delta = if event_time > reference_time { event_time - reference_time @@ -36,7 +32,11 @@ pub fn temporalize( TemporalFingerprint { base: base.clone(), - temporal: TemporalContext { event_time, reference_time, speech_time }, + temporal: TemporalContext { + event_time, + reference_time, + speech_time, + }, recency, } } diff --git a/src/hpc/styles/tcf.rs b/src/hpc/styles/tcf.rs index db0ebec7..5e04ab17 100644 --- a/src/hpc/styles/tcf.rs +++ b/src/hpc/styles/tcf.rs @@ -4,12 +4,11 @@ use super::super::bgz17_bridge::Base17; pub fn cascade_filter( - query: &Base17, - corpus: &[Base17], - quality_fn: &dyn Fn(&Base17) -> f32, - top_k: usize, + query: &Base17, corpus: &[Base17], quality_fn: &dyn Fn(&Base17) -> f32, top_k: usize, ) -> Vec<(usize, f32)> { - let mut scored: Vec<(usize, f32)> = corpus.iter().enumerate() + let mut scored: Vec<(usize, f32)> = corpus + .iter() + .enumerate() .map(|(i, c)| (i, quality_fn(c))) .collect(); scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); @@ -23,7 +22,13 @@ mod tests { #[test] fn test_cascade_filter() { let query = Base17 { dims: [100; 17] }; - let corpus: Vec = (0..10).map(|i| { let mut d = [0i16; 17]; d[0] = (i * 100) as i16; Base17 { dims: d } }).collect(); + let corpus: Vec = (0..10) + .map(|i| { + let mut d = [0i16; 17]; + d[0] = (i * 100) as i16; + Base17 { dims: d } + }) + .collect(); let results = cascade_filter(&query, &corpus, &|c| -(query.l1(c) as f32), 3); assert_eq!(results.len(), 3); assert!(results[0].1 >= results[1].1); // sorted by quality diff --git a/src/hpc/styles/tcp.rs b/src/hpc/styles/tcp.rs index 8c642462..c07ec4c1 100644 --- a/src/hpc/styles/tcp.rs +++ b/src/hpc/styles/tcp.rs @@ -10,7 +10,10 @@ pub struct ChainPruner { impl ChainPruner { /// Default: Berry-Esseen noise floor at d=17 (Base17 dimensions). pub fn new(max_branches: usize) -> Self { - Self { noise_floor: 0.01, max_branches } + Self { + noise_floor: 0.01, + max_branches, + } } /// Prune chain: keep branches where L1 from accumulated bundle exceeds noise floor. 
@@ -29,7 +32,9 @@ impl ChainPruner { bundle.dims[d] = ((bundle.dims[d] as i32 + chain[i].dims[d] as i32) / 2) as i16; } } - if kept.len() >= self.max_branches { break; } + if kept.len() >= self.max_branches { + break; + } } kept } @@ -43,9 +48,9 @@ mod tests { fn test_prune_keeps_novel() { let chain = vec![ Base17 { dims: [0; 17] }, - Base17 { dims: [1000; 17] }, // very different -> keep - Base17 { dims: [1; 17] }, // near duplicate of bundle -> prune - Base17 { dims: [2000; 17] }, // very different -> keep + Base17 { dims: [1000; 17] }, // very different -> keep + Base17 { dims: [1; 17] }, // near duplicate of bundle -> prune + Base17 { dims: [2000; 17] }, // very different -> keep ]; let pruner = ChainPruner::new(10); let kept = pruner.prune(&chain); @@ -56,9 +61,11 @@ mod tests { #[test] fn test_prune_respects_max() { - let chain: Vec = (0..20).map(|i| { - Base17 { dims: [(i * 1000) as i16; 17] } - }).collect(); + let chain: Vec = (0..20) + .map(|i| Base17 { + dims: [(i * 1000) as i16; 17], + }) + .collect(); let pruner = ChainPruner::new(3); let kept = pruner.prune(&chain); assert!(kept.len() <= 3); diff --git a/src/hpc/styles/zcf.rs b/src/hpc/styles/zcf.rs index 2854ba0c..00ce3c34 100644 --- a/src/hpc/styles/zcf.rs +++ b/src/hpc/styles/zcf.rs @@ -13,7 +13,9 @@ pub struct FusionResult { pub fn fuse(a: &Base17, b: &Base17) -> FusionResult { let max_l1 = (17u32 * 65535) as f32; let mut fused_dims = [0i16; 17]; - for d in 0..17 { fused_dims[d] = a.dims[d].wrapping_add(b.dims[d]); } + for d in 0..17 { + fused_dims[d] = a.dims[d].wrapping_add(b.dims[d]); + } let fused = Base17 { dims: fused_dims }; let mut recover_a_dims = [0i16; 17]; let mut recover_b_dims = [0i16; 17]; diff --git a/src/hpc/substrate.rs b/src/hpc/substrate.rs index b2697ec6..a78713cf 100644 --- a/src/hpc/substrate.rs +++ b/src/hpc/substrate.rs @@ -65,12 +65,8 @@ pub enum Substrate { impl Substrate { /// All four substrates in priority order. - pub const ALL: [Substrate; 4] = [ - Substrate::Structural, - Substrate::Soaking, - Substrate::Evidential, - Substrate::Semantic, - ]; + pub const ALL: [Substrate; 4] = + [Substrate::Structural, Substrate::Soaking, Substrate::Evidential, Substrate::Semantic]; /// Convert from raw u8. #[inline] @@ -129,12 +125,7 @@ impl SubstrateRoute { /// Create a dual-substrate route. #[inline] - pub fn dual( - primary: Substrate, - secondary: Substrate, - primary_weight: f32, - parallel: bool, - ) -> Self { + pub fn dual(primary: Substrate, secondary: Substrate, primary_weight: f32, parallel: bool) -> Self { Self { primary, secondary: Some(secondary), @@ -147,11 +138,7 @@ impl SubstrateRoute { /// Create a triple-substrate route. #[inline] pub fn triple( - primary: Substrate, - secondary: Substrate, - tertiary: Substrate, - primary_weight: f32, - parallel: bool, + primary: Substrate, secondary: Substrate, tertiary: Substrate, primary_weight: f32, parallel: bool, ) -> Self { Self { primary, @@ -183,9 +170,7 @@ impl SubstrateRoute { /// Whether this route uses a specific substrate. #[inline] pub fn uses(&self, substrate: Substrate) -> bool { - self.primary == substrate - || self.secondary == Some(substrate) - || self.tertiary == Some(substrate) + self.primary == substrate || self.secondary == Some(substrate) || self.tertiary == Some(substrate) } } @@ -375,10 +360,7 @@ pub fn coherence(signals: &SubstrateSignals) -> Coherence { /// but this provides the substrate-level recommendation. /// /// Returns `None` if no transition is warranted. 
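The fuse/recover pair in zcf.rs (and the analogous code in hkf.rs and mct.rs) relies on wrapping i16 arithmetic being a group operation mod 2^16: binding with wrapping_add and unbinding with wrapping_sub round-trips exactly even when individual dimensions overflow. A standalone sketch of that property on plain arrays:

fn fuse(a: &[i16; 17], b: &[i16; 17]) -> [i16; 17] {
    let mut out = [0i16; 17];
    for d in 0..17 {
        out[d] = a[d].wrapping_add(b[d]);
    }
    out
}

fn unbind(fused: &[i16; 17], known: &[i16; 17]) -> [i16; 17] {
    let mut out = [0i16; 17];
    for d in 0..17 {
        out[d] = fused[d].wrapping_sub(known[d]);
    }
    out
}

// unbind(&fuse(&a, &b), &b) == a holds for every input, overflow included,
// which is why the recover checks in the zcf tests can expect exact equality.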
-pub fn recommend_transition( - current_primary: Substrate, - signals: &SubstrateSignals, -) -> Option { +pub fn recommend_transition(current_primary: Substrate, signals: &SubstrateSignals) -> Option { // === CRYSTALLIZATION: soaking saturated → move to structural === if signals.soaking_saturation > 0.85 && signals.theta_average > 100.0 @@ -388,10 +370,7 @@ pub fn recommend_transition( } // === DOUBT: NARS confidence dropping → move to evidential === - if signals.nars_confidence < 0.3 - && signals.nars_contradictions > 2 - && current_primary == Substrate::Structural - { + if signals.nars_confidence < 0.3 && signals.nars_contradictions > 2 && current_primary == Substrate::Structural { return Some(Substrate::Evidential); // need to re-examine } @@ -413,10 +392,7 @@ pub fn recommend_transition( } // === ASSOCIATION: structural miss but semantic hit → move to semantic === - if signals.structural_hits == 0 - && signals.semantic_nearest > 0.5 - && current_primary == Substrate::Structural - { + if signals.structural_hits == 0 && signals.semantic_nearest > 0.5 && current_primary == Substrate::Structural { return Some(Substrate::Semantic); // try analogy instead } @@ -482,16 +458,11 @@ impl SubstrateSnapshot { let coh = coherence(signals); // Capture soaking state if soaking was active - let (theta, maturity, saturation) = - if route.uses(Substrate::Soaking) && signals.has_soaking() { - ( - Some(signals.theta_average), - Some(signals.maturity_average as u8), - Some(signals.soaking_saturation), - ) - } else { - (None, None, None) - }; + let (theta, maturity, saturation) = if route.uses(Substrate::Soaking) && signals.has_soaking() { + (Some(signals.theta_average), Some(signals.maturity_average as u8), Some(signals.soaking_saturation)) + } else { + (None, None, None) + }; // Capture NARS state if evidential was active let nars = if route.uses(Substrate::Evidential) && signals.has_evidential() { @@ -622,13 +593,8 @@ mod tests { #[test] fn test_substrate_route_triple() { - let route = SubstrateRoute::triple( - Substrate::Structural, - Substrate::Evidential, - Substrate::Semantic, - 0.6, - true, - ); + let route = + SubstrateRoute::triple(Substrate::Structural, Substrate::Evidential, Substrate::Semantic, 0.6, true); assert_eq!(route.depth(), 3); assert!(route.uses(Substrate::Structural)); assert!(route.uses(Substrate::Evidential)); diff --git a/src/hpc/surround_metadata.rs b/src/hpc/surround_metadata.rs index b4ddceda..9c702983 100644 --- a/src/hpc/surround_metadata.rs +++ b/src/hpc/surround_metadata.rs @@ -55,10 +55,7 @@ impl SurroundBundler { phases.push(Self::compute_phases(i, n_atoms)); } let noise_floor = Self::euler_gamma_noise_floor(); - SurroundBundler { - phases, - noise_floor, - } + SurroundBundler { phases, noise_floor } } /// Compute phase angles for atom `i` across all rotation planes. @@ -281,14 +278,8 @@ impl SurroundMetadata { /// Create from 7 components (S, P, O, T, K, M, L). /// S, P, O are dense Fingerprint<256> (16Kbit). /// T, K, M, L are sparse Fingerprint<256> (16Kbit, low density). 
- pub fn from_components( - components: &[Fingerprint<256>; 7], - bundler: &SurroundBundler, - ) -> Self { - let f64_atoms: Vec> = components - .iter() - .map(fingerprint_to_f64) - .collect(); + pub fn from_components(components: &[Fingerprint<256>; 7], bundler: &SurroundBundler) -> Self { + let f64_atoms: Vec> = components.iter().map(fingerprint_to_f64).collect(); let f64_bundle = bundler.bundle_raw(&f64_atoms); let bundle = f64_to_fingerprint_128(&f64_bundle); SurroundMetadata { bundle, f64_bundle } @@ -363,7 +354,9 @@ pub fn random_fingerprint_256(seed: u64) -> Fingerprint<256> { let mut words = [0u64; 256]; let mut state = seed.wrapping_mul(0x9E3779B97F4A7C15).wrapping_add(1); for w in words.iter_mut() { - state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407); + state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); *w = state; } Fingerprint::from_words(words) @@ -377,7 +370,9 @@ pub fn sparse_fingerprint_256(seed: u64, density: f64) -> Fingerprint<256> { for w in words.iter_mut() { let mut word = 0u64; for bit in 0..64 { - state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407); + state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); if state < threshold { word |= 1u64 << bit; } @@ -452,9 +447,9 @@ mod tests { // Create 7 components with realistic density let components: [Fingerprint<256>; 7] = [ - random_fingerprint_256(100), // S: dense - random_fingerprint_256(200), // P: dense - random_fingerprint_256(300), // O: dense + random_fingerprint_256(100), // S: dense + random_fingerprint_256(200), // P: dense + random_fingerprint_256(300), // O: dense sparse_fingerprint_256(400, 0.1), // T: sparse sparse_fingerprint_256(500, 0.1), // K: sparse sparse_fingerprint_256(600, 0.1), // M: sparse @@ -462,10 +457,7 @@ mod tests { ]; // Convert to f64 - let f64_atoms: Vec> = components - .iter() - .map(fingerprint_to_f64) - .collect(); + let f64_atoms: Vec> = components.iter().map(fingerprint_to_f64).collect(); // Bundle let bundle = bundler.bundle_raw(&f64_atoms); @@ -501,29 +493,15 @@ mod tests { ); } - eprintln!( - "\n Mean fidelity: {:.6}", - fidelities.iter().sum::() / fidelities.len() as f64 - ); - eprintln!( - " Min fidelity: {:.6}", - fidelities.iter().cloned().fold(f64::INFINITY, f64::min) - ); + eprintln!("\n Mean fidelity: {:.6}", fidelities.iter().sum::() / fidelities.len() as f64); + eprintln!(" Min fidelity: {:.6}", fidelities.iter().cloned().fold(f64::INFINITY, f64::min)); // THRESHOLD: 100% classification accuracy (surround must beat mono) - assert!( - all_pass, - "FAIL: Not all 7 components correctly classified after recovery" - ); + assert!(all_pass, "FAIL: Not all 7 components correctly classified after recovery"); // Record: fidelity for each component for (i, f) in fidelities.iter().enumerate() { - assert!( - *f > 0.05, - "Component {} fidelity too low: {}", - COMPONENT_NAMES[i], - f - ); + assert!(*f > 0.05, "Component {} fidelity too low: {}", COMPONENT_NAMES[i], f); } } @@ -550,10 +528,7 @@ mod tests { sparse_fingerprint_256(base_seed + 6, 0.1), ]; - let f64_atoms: Vec> = components - .iter() - .map(fingerprint_to_f64) - .collect(); + let f64_atoms: Vec> = components.iter().map(fingerprint_to_f64).collect(); let bundle = bundler.bundle_raw(&f64_atoms); @@ -584,11 +559,7 @@ mod tests { ); // THRESHOLD: >95% classification accuracy across 700 recoveries - assert!( - accuracy > 0.95, - "FAIL: Classification accuracy {:.2}% < 95%", - accuracy * 100.0 - ); 
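random_fingerprint_256 and sparse_fingerprint_256 above are deterministic: the seed is scrambled by the golden-ratio constant 0x9E3779B97F4A7C15 and then stepped through Knuth's MMIX LCG (multiplier 6364136223846793005, increment 1442695040888963407). The sparse variant's threshold is computed before the loop and is not visible in this hunk, so the derivation below (density scaled onto the u64 range) is an assumption; the const-generic word count is also just for compactness.

fn sparse_words<const N: usize>(seed: u64, density: f64) -> [u64; N] {
    // assumed mapping of density to an LCG-state threshold; the dense generator
    // simply writes the raw state into each word instead of testing bits
    let threshold = (density * u64::MAX as f64) as u64;
    let mut state = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15).wrapping_add(1);
    let mut words = [0u64; N];
    for w in words.iter_mut() {
        let mut word = 0u64;
        for bit in 0..64 {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            if state < threshold {
                word |= 1u64 << bit;
            }
        }
        *w = word;
    }
    words
}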
+ assert!(accuracy > 0.95, "FAIL: Classification accuracy {:.2}% < 95%", accuracy * 100.0); } // ======================================================================== @@ -636,19 +607,16 @@ mod tests { eprintln!("\n Full 7-component density mismatch:"); let bundler7 = SurroundBundler::new(7); let components: [Fingerprint<256>; 7] = [ - random_fingerprint_256(10), // S: 50% fill - random_fingerprint_256(20), // P: 50% fill - random_fingerprint_256(30), // O: 50% fill - sparse_fingerprint_256(40, 0.05), // T: 5% fill - sparse_fingerprint_256(50, 0.05), // K: 5% fill - sparse_fingerprint_256(60, 0.05), // M: 5% fill - sparse_fingerprint_256(70, 0.05), // L: 5% fill + random_fingerprint_256(10), // S: 50% fill + random_fingerprint_256(20), // P: 50% fill + random_fingerprint_256(30), // O: 50% fill + sparse_fingerprint_256(40, 0.05), // T: 5% fill + sparse_fingerprint_256(50, 0.05), // K: 5% fill + sparse_fingerprint_256(60, 0.05), // M: 5% fill + sparse_fingerprint_256(70, 0.05), // L: 5% fill ]; - let f64_atoms: Vec> = components - .iter() - .map(fingerprint_to_f64) - .collect(); + let f64_atoms: Vec> = components.iter().map(fingerprint_to_f64).collect(); let bundle = bundler7.bundle_raw(&f64_atoms); let mut all_classified = true; @@ -668,16 +636,10 @@ mod tests { if !ok { all_classified = false; } - eprintln!( - " {}: fid={:.4}, classified={} (best_match={})", - name, fid, ok, COMPONENT_NAMES[best_match] - ); + eprintln!(" {}: fid={:.4}, classified={} (best_match={})", name, fid, ok, COMPONENT_NAMES[best_match]); } - assert!( - all_classified, - "FAIL: Density-mismatched components not all correctly classified" - ); + assert!(all_classified, "FAIL: Density-mismatched components not all correctly classified"); } // ======================================================================== @@ -705,10 +667,7 @@ mod tests { sparse_fingerprint_256(seed + 5, 0.1), sparse_fingerprint_256(seed + 6, 0.1), ]; - let f64_atoms: Vec> = components - .iter() - .map(fingerprint_to_f64) - .collect(); + let f64_atoms: Vec> = components.iter().map(fingerprint_to_f64).collect(); let sm = SurroundMetadata::from_components(&components, &bundler); all_f64.push(f64_atoms); all_components.push(components); @@ -734,8 +693,7 @@ mod tests { } all_components[i][0] = Fingerprint::from_words(blended_words); all_f64[i][0] = fingerprint_to_f64(&all_components[i][0]); - all_bundles[i] = - SurroundMetadata::from_components(&all_components[i], &bundler); + all_bundles[i] = SurroundMetadata::from_components(&all_components[i], &bundler); } // Query: find nodes similar in S component @@ -775,9 +733,7 @@ mod tests { let mut full_separate: Vec<(usize, u32)> = (0..n_nodes) .map(|i| { let total: u32 = (0..N_ATOMS) - .map(|c| { - all_components[query_idx][c].hamming_distance(&all_components[i][c]) - }) + .map(|c| all_components[query_idx][c].hamming_distance(&all_components[i][c])) .sum(); (i, total) }) @@ -861,20 +817,11 @@ mod tests { finest = d; } } - let bf16 = bf16_from_projections( - &bands, - finest, - 16384, - super::super::causality::CausalityDirection::None, - ); + let bf16 = bf16_from_projections(&bands, finest, 16384, super::super::causality::CausalityDirection::None); // BF16 distance: simple XOR popcount - let bf16_zero = bf16_from_projections( - &[Band::Foveal; 7], - 0, - 16384, - super::super::causality::CausalityDirection::None, - ); + let bf16_zero = + bf16_from_projections(&[Band::Foveal; 7], 0, 16384, super::super::causality::CausalityDirection::None); let bf16_d = (bf16 ^ bf16_zero).count_ones(); 
bf16_dists.push(bf16_d as f64); } @@ -902,11 +849,7 @@ mod tests { // it means BF16 is NOT a compressed version of the bundle. // They are complementary cascade levels, not redundant ones. // The assertion below documents this finding — any positive correlation is bonus. - assert!( - rho > -0.50, - "BF16-bundle correlation is strongly negative: ρ={:.4} (unexpected)", - rho - ); + assert!(rho > -0.50, "BF16-bundle correlation is strongly negative: ρ={:.4} (unexpected)", rho); } // ======================================================================== @@ -923,8 +866,9 @@ mod tests { let n_nodes = clusters * per_cluster; // Create ground-truth clusters: each cluster shares similar S planes - let cluster_centers: Vec> = - (0..clusters).map(|c| random_fingerprint_256(c as u64 * 10000 + 7777)).collect(); + let cluster_centers: Vec> = (0..clusters) + .map(|c| random_fingerprint_256(c as u64 * 10000 + 7777)) + .collect(); let mut labels = Vec::with_capacity(n_nodes); let mut bundle_bytes = Vec::with_capacity(n_nodes * 1024); @@ -1032,10 +976,7 @@ mod tests { eprintln!(" Content CLAM tree nodes: {}", content_tree.nodes.len()); eprintln!(" Bundle CLAM tree nodes: {}", bundle_tree.nodes.len()); eprintln!(" k-NN Recall@{} (content vs bundle): {:.2}", k, knn_recall); - eprintln!( - " Cluster purity: content={:.2}, bundle={:.2}", - content_purity, bundle_purity - ); + eprintln!(" Cluster purity: content={:.2}, bundle={:.2}", content_purity, bundle_purity); if bundle_purity > 0.9 * content_purity { eprintln!(" Verdict: GO — bundle CLAM ≥90% of content CLAM purity"); @@ -1118,18 +1059,10 @@ mod tests { finest = d; } } - let bf16_q = bf16_from_projections( - &[Band::Foveal; 7], - 0, - 16384, - super::super::causality::CausalityDirection::None, - ); - let bf16_c = bf16_from_projections( - &bands, - finest, - 16384, - super::super::causality::CausalityDirection::None, - ); + let bf16_q = + bf16_from_projections(&[Band::Foveal; 7], 0, 16384, super::super::causality::CausalityDirection::None); + let bf16_c = + bf16_from_projections(&bands, finest, 16384, super::super::causality::CausalityDirection::None); let bf16_d = (bf16_q ^ bf16_c).count_ones(); rank_bf16.push((ci, bf16_d as f64)); @@ -1182,8 +1115,7 @@ mod tests { let k = 10; let full_top_k: Vec = rank_full.iter().take(k).map(|&(i, _)| i).collect(); let bundle_top_k: Vec = rank_bundle.iter().take(k * 2).map(|&(i, _)| i).collect(); - let merkle_top_k: Vec = - rank_merkle.iter().take(k * 5).map(|&(i, _)| i).collect(); + let merkle_top_k: Vec = rank_merkle.iter().take(k * 5).map(|&(i, _)| i).collect(); let bundle_recall = full_top_k .iter() @@ -1197,32 +1129,15 @@ mod tests { / k as f64; eprintln!("\n Cascade recall (preserving full top-{}):", k); - eprintln!( - " Bundle top-{} contains {:.0}% of full top-{}", - k * 2, - bundle_recall * 100.0, - k - ); - eprintln!( - " Merkle top-{} contains {:.0}% of full top-{}", - k * 5, - merkle_recall * 100.0, - k - ); + eprintln!(" Bundle top-{} contains {:.0}% of full top-{}", k * 2, bundle_recall * 100.0, k); + eprintln!(" Merkle top-{} contains {:.0}% of full top-{}", k * 5, merkle_recall * 100.0, k); // Assess overall cascade coherence - let cascade_ok = rho_23 > 0.30; // Bundle→Full must be positively correlated - eprintln!( - "\n Overall cascade coherence: {}", - if cascade_ok { "GO" } else { "NO-GO" } - ); + let cascade_ok = rho_23 > 0.30; // Bundle→Full must be positively correlated + eprintln!("\n Overall cascade coherence: {}", if cascade_ok { "GO" } else { "NO-GO" }); // At minimum, bundle→full must show 
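The cascade tests report Spearman ρ between the BF16 and bundle distance rankings. The crate's own correlation helper is not shown in this diff; for reference, a minimal rank-correlation sketch using the 1 − 6Σd²/(n(n²−1)) form, which ignores tie handling (tied values get arbitrary adjacent ranks here, whereas a production helper would average them).

fn spearman_rho(xs: &[f64], ys: &[f64]) -> f64 {
    fn ranks(v: &[f64]) -> Vec<f64> {
        let mut idx: Vec<usize> = (0..v.len()).collect();
        idx.sort_by(|&a, &b| v[a].partial_cmp(&v[b]).unwrap());
        let mut r = vec![0.0; v.len()];
        for (rank, &i) in idx.iter().enumerate() {
            r[i] = rank as f64;
        }
        r
    }
    let (rx, ry) = (ranks(xs), ranks(ys));
    let n = xs.len() as f64;
    let d2: f64 = rx.iter().zip(&ry).map(|(a, b)| (a - b) * (a - b)).sum();
    1.0 - 6.0 * d2 / (n * (n * n - 1.0))
}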
positive correlation - assert!( - rho_23 > 0.0, - "Bundle→Full rank correlation is negative: {:.4}", - rho_23 - ); + assert!(rho_23 > 0.0, "Bundle→Full rank correlation is negative: {:.4}", rho_23); } /// Helper: convert sorted ranking to an array indexed by candidate ID @@ -1253,12 +1168,7 @@ mod tests { .zip(recovered.iter()) .map(|(a, b)| (a - b).abs()) .sum::(); - assert!( - err < 1e-10, - "Rotation roundtrip error for atom {}: {}", - atom_idx, - err - ); + assert!(err < 1e-10, "Rotation roundtrip error for atom {}: {}", atom_idx, err); } } @@ -1274,10 +1184,6 @@ mod tests { // The f64 vector should be unit-normalized let norm: f64 = f64_v.iter().map(|x| x * x).sum::().sqrt(); - assert!( - (norm - 1.0).abs() < 0.01, - "f64 vector not normalized: {}", - norm - ); + assert!((norm - 1.0).abs() < 0.01, "f64 vector not normalized: {}", norm); } } diff --git a/src/hpc/tekamolo.rs b/src/hpc/tekamolo.rs index d1264ad8..ff66f633 100644 --- a/src/hpc/tekamolo.rs +++ b/src/hpc/tekamolo.rs @@ -70,26 +70,15 @@ pub struct CrystalizedSentence { fn verbs() -> HashSet<&'static str> { [ - "is", "are", "was", "were", "am", "be", "been", "being", - "have", "has", "had", "do", "does", "did", - "will", "would", "shall", "should", "can", "could", "may", "might", "must", - "go", "goes", "went", "come", "came", - "get", "gets", "got", "make", "makes", "made", - "take", "takes", "took", "give", "gives", "gave", - "say", "says", "said", "see", "sees", "saw", - "know", "knows", "knew", "think", "thinks", "thought", - "want", "wants", "wanted", - "run", "runs", "ran", "walk", "walks", "walked", - "sit", "sits", "sat", "stand", "stands", "stood", - "eat", "eats", "ate", "drink", "drinks", "drank", - "sleep", "sleeps", "slept", - "read", "reads", "write", "writes", "wrote", - "speak", "speaks", "spoke", "hear", "hears", "heard", - "feel", "feels", "felt", "live", "lives", "lived", - "die", "dies", "died", "love", "loves", "loved", - "hate", "hates", "hated", "play", "plays", "played", - "work", "works", "worked", "move", "moves", "moved", - "happen", "happens", "happened", "touch", "touched", + "is", "are", "was", "were", "am", "be", "been", "being", "have", "has", "had", "do", "does", "did", "will", + "would", "shall", "should", "can", "could", "may", "might", "must", "go", "goes", "went", "come", "came", + "get", "gets", "got", "make", "makes", "made", "take", "takes", "took", "give", "gives", "gave", "say", "says", + "said", "see", "sees", "saw", "know", "knows", "knew", "think", "thinks", "thought", "want", "wants", "wanted", + "run", "runs", "ran", "walk", "walks", "walked", "sit", "sits", "sat", "stand", "stands", "stood", "eat", + "eats", "ate", "drink", "drinks", "drank", "sleep", "sleeps", "slept", "read", "reads", "write", "writes", + "wrote", "speak", "speaks", "spoke", "hear", "hears", "heard", "feel", "feels", "felt", "live", "lives", + "lived", "die", "dies", "died", "love", "loves", "loved", "hate", "hates", "hated", "play", "plays", "played", + "work", "works", "worked", "move", "moves", "moved", "happen", "happens", "happened", "touch", "touched", "left", "flew", ] .iter() @@ -99,10 +88,9 @@ fn verbs() -> HashSet<&'static str> { fn temporal_keywords() -> HashSet<&'static str> { [ - "yesterday", "today", "tomorrow", "now", "then", "always", "never", - "before", "after", "when", "while", "during", "already", "soon", - "later", "recently", "once", "often", "sometimes", "usually", - "morning", "evening", "night", + "yesterday", "today", "tomorrow", "now", "then", "always", "never", "before", "after", 
"when", "while", + "during", "already", "soon", "later", "recently", "once", "often", "sometimes", "usually", "morning", + "evening", "night", ] .iter() .copied() @@ -111,8 +99,7 @@ fn temporal_keywords() -> HashSet<&'static str> { fn kausal_keywords() -> HashSet<&'static str> { [ - "because", "since", "therefore", "hence", "thus", "so", - "consequently", "due", "reason", "cause", "why", + "because", "since", "therefore", "hence", "thus", "so", "consequently", "due", "reason", "cause", "why", ] .iter() .copied() @@ -121,9 +108,8 @@ fn kausal_keywords() -> HashSet<&'static str> { fn modal_keywords() -> HashSet<&'static str> { [ - "quickly", "slowly", "carefully", "easily", "well", "badly", - "hard", "gently", "loudly", "quietly", "happily", "sadly", - "with", "without", "by", + "quickly", "slowly", "carefully", "easily", "well", "badly", "hard", "gently", "loudly", "quietly", "happily", + "sadly", "with", "without", "by", ] .iter() .copied() @@ -132,11 +118,8 @@ fn modal_keywords() -> HashSet<&'static str> { fn lokal_keywords() -> HashSet<&'static str> { [ - "here", "there", "above", "below", "near", "far", - "inside", "outside", "between", "behind", "front", - "under", "over", "in", "on", "at", "from", "to", - "into", "through", "across", "around", "along", - "up", "down", + "here", "there", "above", "below", "near", "far", "inside", "outside", "between", "behind", "front", "under", + "over", "in", "on", "at", "from", "to", "into", "through", "across", "around", "along", "up", "down", ] .iter() .copied() @@ -152,9 +135,8 @@ fn determiners() -> HashSet<&'static str> { fn lokal_prepositions() -> HashSet<&'static str> { [ - "in", "on", "at", "from", "to", "into", "through", "across", - "around", "along", "up", "down", "under", "over", "between", - "behind", "above", "below", + "in", "on", "at", "from", "to", "into", "through", "across", "around", "along", "up", "down", "under", "over", + "between", "behind", "above", "below", ] .iter() .copied() @@ -168,7 +150,7 @@ fn lokal_prepositions() -> HashSet<&'static str> { /// Strip common trailing punctuation from a token for matching purposes. fn normalize_token(tok: &str) -> String { tok.trim_end_matches(['.', ',', '!', '?', ';', ':']) - .to_lowercase() + .to_lowercase() } /// Classify a single (normalized) token into an adverbial slot, if any. @@ -269,7 +251,11 @@ pub fn tekamolo_parse(sentence: &str) -> Vec { } // 2d. Post-predicate tokens: classify adverbials and lokal phrases; rest is object. 
- let start = if let Some(pi) = pred_idx { pi + 1 } else { raw_tokens.len() }; + let start = if let Some(pi) = pred_idx { + pi + 1 + } else { + raw_tokens.len() + }; let mut obj_words: Vec<&str> = Vec::new(); let mut i = start; while i < raw_tokens.len() { @@ -447,10 +433,22 @@ mod tests { #[test] fn test_reorder_tekamolo() { let slots = vec![ - SlotEntry { slot: TekmoloSlot::Subject, text: "cat".into() }, - SlotEntry { slot: TekmoloSlot::Predicate, text: "sat".into() }, - SlotEntry { slot: TekmoloSlot::Lokal, text: "on the mat".into() }, - SlotEntry { slot: TekmoloSlot::Temporal, text: "yesterday".into() }, + SlotEntry { + slot: TekmoloSlot::Subject, + text: "cat".into(), + }, + SlotEntry { + slot: TekmoloSlot::Predicate, + text: "sat".into(), + }, + SlotEntry { + slot: TekmoloSlot::Lokal, + text: "on the mat".into(), + }, + SlotEntry { + slot: TekmoloSlot::Temporal, + text: "yesterday".into(), + }, ]; let reordered = reorder_tekamolo(&slots); assert_eq!(reordered[0].slot, TekmoloSlot::Temporal); @@ -486,9 +484,18 @@ mod tests { #[test] fn test_spo_extraction() { let slots = vec![ - SlotEntry { slot: TekmoloSlot::Subject, text: "dog".into() }, - SlotEntry { slot: TekmoloSlot::Predicate, text: "ran".into() }, - SlotEntry { slot: TekmoloSlot::Object, text: "fast".into() }, + SlotEntry { + slot: TekmoloSlot::Subject, + text: "dog".into(), + }, + SlotEntry { + slot: TekmoloSlot::Predicate, + text: "ran".into(), + }, + SlotEntry { + slot: TekmoloSlot::Object, + text: "fast".into(), + }, ]; let spo = spo_extract(&slots); assert_eq!(spo.subject.as_deref(), Some("dog")); diff --git a/src/hpc/udf_kernels.rs b/src/hpc/udf_kernels.rs index d656d0a6..30735ffa 100644 --- a/src/hpc/udf_kernels.rs +++ b/src/hpc/udf_kernels.rs @@ -88,24 +88,12 @@ pub struct SpoDistanceResult { /// assert_eq!(result.combined_dist, Some(0)); /// ``` pub fn udf_spo_distance( - s1: &[u8], - p1: &[u8], - o1: &[u8], - s2: &[u8], - p2: &[u8], - o2: &[u8], + s1: &[u8], p1: &[u8], o1: &[u8], s2: &[u8], p2: &[u8], o2: &[u8], ) -> Result { use super::fingerprint::Fingerprint; const PLANE_BYTES: usize = 2048; - for (name, slice) in [ - ("s1", s1), - ("p1", p1), - ("o1", o1), - ("s2", s2), - ("p2", p2), - ("o2", o2), - ] { + for (name, slice) in [("s1", s1), ("p1", p1), ("o1", o1), ("s2", s2), ("p2", p2), ("o2", o2)] { if slice.len() != PLANE_BYTES { return Err(if name.starts_with('s') { "udf_spo_distance: subject plane must be 2048 bytes" @@ -304,13 +292,13 @@ pub fn factorize_spo(node: &mut Node) -> [u64; 8] { let o_pop = popcount_raw(o_bits); [ - 0, // empty - s_pop, // S - p_pop, // P - o_pop, // O - s_pop.saturating_add(p_pop), // SP - p_pop.saturating_add(o_pop), // PO - s_pop.saturating_add(o_pop), // SO + 0, // empty + s_pop, // S + p_pop, // P + o_pop, // O + s_pop.saturating_add(p_pop), // SP + p_pop.saturating_add(o_pop), // PO + s_pop.saturating_add(o_pop), // SO s_pop.saturating_add(p_pop).saturating_add(o_pop), // SPO ] } @@ -451,11 +439,7 @@ pub fn causal_edge(node_a: &mut Node, node_b: &mut Node) -> CausalityDirection { // Vote: majority of the three dimensions decides direction let mut forward_count = 0i32; let mut backward_count = 0i32; - for dir in [ - decomposition.warmth_dir, - decomposition.social_dir, - decomposition.sacredness_dir, - ] { + for dir in [decomposition.warmth_dir, decomposition.social_dir, decomposition.sacredness_dir] { match dir { CausalityDirection::Forward => forward_count += 1, CausalityDirection::Backward => backward_count += 1, @@ -664,7 +648,7 @@ mod tests { assert!(terms[1] > 0); // S 
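udf_spo_distance validates that each of the six planes is exactly 2048 bytes (one Fingerprint<256>, i.e. 256 u64 words) before computing per-plane distances. The essence of one such plane comparison is an XOR-popcount over the words; the sketch below reads the bytes as little-endian u64s, which is an assumption about layout — the crate presumably goes through Fingerprint::hamming_distance instead.

fn plane_hamming(a: &[u8], b: &[u8]) -> Result<u32, &'static str> {
    const PLANE_BYTES: usize = 2048; // one 16 Kbit plane, as in udf_spo_distance
    if a.len() != PLANE_BYTES || b.len() != PLANE_BYTES {
        return Err("plane must be 2048 bytes");
    }
    let mut dist = 0u32;
    for (ca, cb) in a.chunks_exact(8).zip(b.chunks_exact(8)) {
        let wa = u64::from_le_bytes(ca.try_into().unwrap());
        let wb = u64::from_le_bytes(cb.try_into().unwrap());
        dist += (wa ^ wb).count_ones();
    }
    Ok(dist) // at most 16384 for a full plane
}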
should have bits assert!(terms[2] > 0); // P should have bits assert!(terms[3] > 0); // O should have bits - // Combination norms are sums + // Combination norms are sums assert_eq!(terms[4], terms[1] + terms[2]); // SP = S + P assert_eq!(terms[5], terms[2] + terms[3]); // PO = P + O assert_eq!(terms[6], terms[1] + terms[3]); // SO = S + O @@ -757,12 +741,7 @@ mod tests { let mut b = Node::random(99); let dir = causal_edge(&mut a, &mut b); // Should produce a definite direction for different nodes - assert!(matches!( - dir, - CausalityDirection::Forward - | CausalityDirection::Backward - | CausalityDirection::None - )); + assert!(matches!(dir, CausalityDirection::Forward | CausalityDirection::Backward | CausalityDirection::None)); } #[test] diff --git a/src/hpc/vml.rs b/src/hpc/vml.rs index fe8be67c..34c16545 100644 --- a/src/hpc/vml.rs +++ b/src/hpc/vml.rs @@ -119,11 +119,19 @@ pub fn vdabs(x: &[f64], out: &mut [f64]) { pub fn vsadd(a: &[f32], b: &[f32], out: &mut [f32]) { let n = a.len().min(b.len()).min(out.len()); let (a, b, out) = (&a[..n], &b[..n], &mut out[..n]); - for ((a_chunk, b_chunk), out_chunk) in a.chunks_exact(16).zip(b.chunks_exact(16)).zip(out.chunks_exact_mut(16)) { + for ((a_chunk, b_chunk), out_chunk) in a + .chunks_exact(16) + .zip(b.chunks_exact(16)) + .zip(out.chunks_exact_mut(16)) + { (F32x16::from_slice(a_chunk) + F32x16::from_slice(b_chunk)).copy_to_slice(out_chunk); } let tail_start = n - n % 16; - for ((&av, &bv), o) in a[tail_start..].iter().zip(b[tail_start..].iter()).zip(out[tail_start..].iter_mut()) { + for ((&av, &bv), o) in a[tail_start..] + .iter() + .zip(b[tail_start..].iter()) + .zip(out[tail_start..].iter_mut()) + { *o = av + bv; } } @@ -134,11 +142,19 @@ pub fn vsadd(a: &[f32], b: &[f32], out: &mut [f32]) { pub fn vsmul(a: &[f32], b: &[f32], out: &mut [f32]) { let n = a.len().min(b.len()).min(out.len()); let (a, b, out) = (&a[..n], &b[..n], &mut out[..n]); - for ((a_chunk, b_chunk), out_chunk) in a.chunks_exact(16).zip(b.chunks_exact(16)).zip(out.chunks_exact_mut(16)) { + for ((a_chunk, b_chunk), out_chunk) in a + .chunks_exact(16) + .zip(b.chunks_exact(16)) + .zip(out.chunks_exact_mut(16)) + { (F32x16::from_slice(a_chunk) * F32x16::from_slice(b_chunk)).copy_to_slice(out_chunk); } let tail_start = n - n % 16; - for ((&av, &bv), o) in a[tail_start..].iter().zip(b[tail_start..].iter()).zip(out[tail_start..].iter_mut()) { + for ((&av, &bv), o) in a[tail_start..] + .iter() + .zip(b[tail_start..].iter()) + .zip(out[tail_start..].iter_mut()) + { *o = av * bv; } } @@ -149,11 +165,19 @@ pub fn vsmul(a: &[f32], b: &[f32], out: &mut [f32]) { pub fn vsdiv(a: &[f32], b: &[f32], out: &mut [f32]) { let n = a.len().min(b.len()).min(out.len()); let (a, b, out) = (&a[..n], &b[..n], &mut out[..n]); - for ((a_chunk, b_chunk), out_chunk) in a.chunks_exact(16).zip(b.chunks_exact(16)).zip(out.chunks_exact_mut(16)) { + for ((a_chunk, b_chunk), out_chunk) in a + .chunks_exact(16) + .zip(b.chunks_exact(16)) + .zip(out.chunks_exact_mut(16)) + { (F32x16::from_slice(a_chunk) / F32x16::from_slice(b_chunk)).copy_to_slice(out_chunk); } let tail_start = n - n % 16; - for ((&av, &bv), o) in a[tail_start..].iter().zip(b[tail_start..].iter()).zip(out[tail_start..].iter_mut()) { + for ((&av, &bv), o) in a[tail_start..] 
+ .iter() + .zip(b[tail_start..].iter()) + .zip(out[tail_start..].iter_mut()) + { *o = av / bv; } } @@ -213,14 +237,22 @@ pub fn vscos(x: &[f32], out: &mut [f32]) { pub fn vspow(a: &[f32], b: &[f32], out: &mut [f32]) { let n = a.len().min(b.len()).min(out.len()); let (a, b, out) = (&a[..n], &b[..n], &mut out[..n]); - for ((a_chunk, b_chunk), out_chunk) in a.chunks_exact(16).zip(b.chunks_exact(16)).zip(out.chunks_exact_mut(16)) { + for ((a_chunk, b_chunk), out_chunk) in a + .chunks_exact(16) + .zip(b.chunks_exact(16)) + .zip(out.chunks_exact_mut(16)) + { let va = F32x16::from_slice(a_chunk); let vb = F32x16::from_slice(b_chunk); // a^b = exp(b * ln(a)) simd_exp_f32(vb * simd_ln_f32(va)).copy_to_slice(out_chunk); } let tail_start = n - n % 16; - for ((&av, &bv), o) in a[tail_start..].iter().zip(b[tail_start..].iter()).zip(out[tail_start..].iter_mut()) { + for ((&av, &bv), o) in a[tail_start..] + .iter() + .zip(b[tail_start..].iter()) + .zip(out[tail_start..].iter_mut()) + { *o = av.powf(bv); } } @@ -502,7 +534,8 @@ mod tests { // ── QUANTIZE: f64[17] → i16[17] (what Base17 stores) ── let fp_scale = 1000.0; - let coefficients_i16: Vec = coefficients_f64.iter() + let coefficients_i16: Vec = coefficients_f64 + .iter() .map(|&v| (v * fp_scale).round().clamp(-32768.0, 32767.0) as i16) .collect(); @@ -555,12 +588,12 @@ mod tests { // Compare: golden-step 17D projection vs random 17D projection // on synthetic weight-like data (approximate Gaussian). // Measures Spearman ρ of pairwise distances. - + let d = 256; // weight vector dimension (small for test speed) - let n = 50; // number of vectors to compare + let n = 50; // number of vectors to compare let base_dim = 17; let golden_step = 11; - + // Generate weight-like vectors (deterministic, Gaussian-ish) let vectors: Vec> = (0..n) .map(|i| { @@ -569,21 +602,24 @@ mod tests { .collect() }) .collect(); - + // Ground truth: pairwise L2 distances in full d-D space let mut gt_distances = Vec::new(); for i in 0..n { for j in (i + 1)..n { - let dist: f64 = vectors[i].iter().zip(&vectors[j]) + let dist: f64 = vectors[i] + .iter() + .zip(&vectors[j]) .map(|(a, b)| (a - b) * (a - b)) .sum::() .sqrt(); gt_distances.push(dist); } } - + // Golden-step projection: project each vector to 17D - let golden_projected: Vec> = vectors.iter() + let golden_projected: Vec> = vectors + .iter() .map(|v| { let n_octaves = (d + base_dim - 1) / base_dim; let mut sum = vec![0.0f64; base_dim]; @@ -597,12 +633,13 @@ mod tests { } } } - sum.iter().zip(&count) + sum.iter() + .zip(&count) .map(|(&s, &c)| if c > 0 { s / c as f64 } else { 0.0 }) .collect() }) .collect(); - + // Random projection: use a deterministic "random" 17×d matrix let random_matrix: Vec> = (0..base_dim) .map(|i| { @@ -611,23 +648,25 @@ mod tests { .collect() }) .collect(); - - let random_projected: Vec> = vectors.iter() + + let random_projected: Vec> = vectors + .iter() .map(|v| { - random_matrix.iter() - .map(|row| { - row.iter().zip(v).map(|(r, x)| r * x).sum::() - }) + random_matrix + .iter() + .map(|row| row.iter().zip(v).map(|(r, x)| r * x).sum::()) .collect() }) .collect(); - + // Compute pairwise distances in both projected spaces let golden_distances: Vec = { let mut dists = Vec::new(); for i in 0..n { for j in (i + 1)..n { - let dist: f64 = golden_projected[i].iter().zip(&golden_projected[j]) + let dist: f64 = golden_projected[i] + .iter() + .zip(&golden_projected[j]) .map(|(a, b)| (a - b) * (a - b)) .sum::() .sqrt(); @@ -636,12 +675,14 @@ mod tests { } dists }; - + let random_distances: Vec = { let 
mut dists = Vec::new(); for i in 0..n { for j in (i + 1)..n { - let dist: f64 = random_projected[i].iter().zip(&random_projected[j]) + let dist: f64 = random_projected[i] + .iter() + .zip(&random_projected[j]) .map(|(a, b)| (a - b) * (a - b)) .sum::() .sqrt(); @@ -650,7 +691,7 @@ mod tests { } dists }; - + // Compute Spearman ρ: rank correlation between GT and projected distances fn spearman_rho(a: &[f64], b: &[f64]) -> f64 { let n = a.len(); @@ -668,10 +709,12 @@ mod tests { var_a += da * da; var_b += db * db; } - if var_a < 1e-10 || var_b < 1e-10 { return 0.0; } + if var_a < 1e-10 || var_b < 1e-10 { + return 0.0; + } cov / (var_a * var_b).sqrt() } - + fn ranks(values: &[f64]) -> Vec { let mut indexed: Vec<(usize, f64)> = values.iter().copied().enumerate().collect(); indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); @@ -681,15 +724,15 @@ mod tests { } result } - + let rho_golden = spearman_rho(>_distances, &golden_distances); let rho_random = spearman_rho(>_distances, &random_distances); - + eprintln!("=== Projection Quality (Spearman ρ) ==="); eprintln!(" Golden-step 17D: ρ = {:.4}", rho_golden); eprintln!(" Random 17D: ρ = {:.4}", rho_random); eprintln!(" Δ (golden - random): {:.4}", rho_golden - rho_random); - + // Both should preserve SOME ranking (ρ > 0) assert!(rho_golden > 0.0, "golden-step ρ should be positive"); assert!(rho_random > 0.0, "random ρ should be positive"); @@ -720,28 +763,31 @@ mod tests { let v: Vec = (0..d) .map(|j| { let off = (i * d + j) * 4; - f32::from_le_bytes([float_data[off], float_data[off+1], float_data[off+2], float_data[off+3]]) as f64 + f32::from_le_bytes([float_data[off], float_data[off + 1], float_data[off + 2], float_data[off + 3]]) + as f64 }) .collect(); vectors.push(v); } - + let n = vectors.len(); eprintln!("Loaded {} vectors of dim {} from tiny-imagenet", n, d); assert!(n >= 50, "Need at least 50 vectors"); - + // Use first 100 for speed let n = n.min(100); let vectors = &vectors[..n]; - + let base_dim = 17; let golden_step = 11; // Ground truth: pairwise L2 distances let mut gt_distances = Vec::new(); for i in 0..n { - for j in (i+1)..n { - let dist: f64 = vectors[i].iter().zip(&vectors[j]) + for j in (i + 1)..n { + let dist: f64 = vectors[i] + .iter() + .zip(&vectors[j]) .map(|(a, b)| (a - b) * (a - b)) .sum::() .sqrt(); @@ -750,7 +796,8 @@ mod tests { } // Golden-step projection - let golden_projected: Vec> = vectors.iter() + let golden_projected: Vec> = vectors + .iter() .map(|v| { let n_octaves = (d + base_dim - 1) / base_dim; let mut sum = vec![0.0f64; base_dim]; @@ -758,28 +805,51 @@ mod tests { for octave in 0..n_octaves { for bi in 0..base_dim { let dim = octave * base_dim + ((bi * golden_step) % base_dim); - if dim < d { sum[bi] += v[dim]; count[bi] += 1; } + if dim < d { + sum[bi] += v[dim]; + count[bi] += 1; + } } } - sum.iter().zip(&count).map(|(&s, &c)| if c > 0 { s / c as f64 } else { 0.0 }).collect() + sum.iter() + .zip(&count) + .map(|(&s, &c)| if c > 0 { s / c as f64 } else { 0.0 }) + .collect() }) .collect(); // Random projection let random_matrix: Vec> = (0..base_dim) - .map(|i| (0..d).map(|j| ((i * 7919 + j * 104729) as f64 * 0.00001).sin()).collect()) + .map(|i| { + (0..d) + .map(|j| ((i * 7919 + j * 104729) as f64 * 0.00001).sin()) + .collect() + }) .collect(); - let random_projected: Vec> = vectors.iter() - .map(|v| random_matrix.iter().map(|row| row.iter().zip(v).map(|(r, x)| r * x).sum::()).collect()) + let random_projected: Vec> = vectors + .iter() + .map(|v| { + random_matrix + .iter() + .map(|row| 
row.iter().zip(v).map(|(r, x)| r * x).sum::()) + .collect() + }) .collect(); // Simple mean projection (average every 17 consecutive dims) - let mean_projected: Vec> = vectors.iter() + let mean_projected: Vec> = vectors + .iter() .map(|v| { - (0..base_dim).map(|bi| { - let chunk: Vec = (bi..d).step_by(base_dim).map(|i| v[i]).collect(); - if chunk.is_empty() { 0.0 } else { chunk.iter().sum::() / chunk.len() as f64 } - }).collect() + (0..base_dim) + .map(|bi| { + let chunk: Vec = (bi..d).step_by(base_dim).map(|i| v[i]).collect(); + if chunk.is_empty() { + 0.0 + } else { + chunk.iter().sum::() / chunk.len() as f64 + } + }) + .collect() }) .collect(); @@ -787,10 +857,17 @@ mod tests { fn pairwise_l2(proj: &[Vec]) -> Vec { let n = proj.len(); let mut dists = Vec::new(); - for i in 0..n { for j in (i+1)..n { - let d: f64 = proj[i].iter().zip(&proj[j]).map(|(a,b)| (a-b)*(a-b)).sum::().sqrt(); - dists.push(d); - }} + for i in 0..n { + for j in (i + 1)..n { + let d: f64 = proj[i] + .iter() + .zip(&proj[j]) + .map(|(a, b)| (a - b) * (a - b)) + .sum::() + .sqrt(); + dists.push(d); + } + } dists } @@ -804,19 +881,28 @@ mod tests { let mut idx: Vec<(usize, f64)> = v.iter().copied().enumerate().collect(); idx.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); let mut r = vec![0.0; v.len()]; - for (rank, (i, _)) in idx.into_iter().enumerate() { r[i] = rank as f64; } + for (rank, (i, _)) in idx.into_iter().enumerate() { + r[i] = rank as f64; + } r } - let ra = ranks(a); let rb = ranks(b); + let ra = ranks(a); + let rb = ranks(b); let n = a.len() as f64; let ma: f64 = ra.iter().sum::() / n; let mb: f64 = rb.iter().sum::() / n; let (mut cov, mut va, mut vb) = (0.0, 0.0, 0.0); for i in 0..a.len() { let (da, db) = (ra[i] - ma, rb[i] - mb); - cov += da * db; va += da * da; vb += db * db; + cov += da * db; + va += da * da; + vb += db * db; + } + if va < 1e-10 || vb < 1e-10 { + 0.0 + } else { + cov / (va * vb).sqrt() } - if va < 1e-10 || vb < 1e-10 { 0.0 } else { cov / (va * vb).sqrt() } } let rho_golden = spearman(>_distances, &golden_dists); @@ -851,7 +937,10 @@ mod tests { // Compare: 1/3 grid line sampling vs golden-step on tiny-imagenet let bytes = match std::fs::read("/tmp/tiny_imagenet_200.bin") { Ok(b) => b, - Err(_) => { eprintln!("SKIP: /tmp/tiny_imagenet_200.bin not found"); return; } + Err(_) => { + eprintln!("SKIP: /tmp/tiny_imagenet_200.bin not found"); + return; + } }; let d = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize; // 12288 @@ -869,21 +958,21 @@ mod tests { let v: Vec = (0..d) .map(|j| { let off = (i * d + j) * 4; - f32::from_le_bytes([float_data[off], float_data[off+1], float_data[off+2], float_data[off+3]]) as f64 + f32::from_le_bytes([float_data[off], float_data[off + 1], float_data[off + 2], float_data[off + 3]]) + as f64 }) .collect(); vectors.push(v); } // Helper: extract pixel at (row, col, channel) from flat vector - let pixel = |v: &[f64], r: usize, c: usize, ch: usize| -> f64 { - v[r * img_w * channels + c * channels + ch] - }; + let pixel = |v: &[f64], r: usize, c: usize, ch: usize| -> f64 { v[r * img_w * channels + c * channels + ch] }; // ── Projection 1: Golden-step 17D (baseline) ── let base_dim = 17; let golden_step = 11; - let golden_proj: Vec> = vectors.iter() + let golden_proj: Vec> = vectors + .iter() .map(|v| { let n_octaves = (d + base_dim - 1) / base_dim; let mut sum = vec![0.0f64; base_dim]; @@ -891,20 +980,27 @@ mod tests { for octave in 0..n_octaves { for bi in 0..base_dim { let dim = octave * base_dim + ((bi * golden_step) % base_dim); - if 
dim < d { sum[bi] += v[dim]; count[bi] += 1; } + if dim < d { + sum[bi] += v[dim]; + count[bi] += 1; + } } } - sum.iter().zip(&count).map(|(&s, &c)| if c > 0 { s / c as f64 } else { 0.0 }).collect() + sum.iter() + .zip(&count) + .map(|(&s, &c)| if c > 0 { s / c as f64 } else { 0.0 }) + .collect() }) .collect(); // ── Projection 2: 1/3 + 2/3 grid lines (4 lines × 64 × 3 = 768D) ── - let grid_lines_proj: Vec> = vectors.iter() + let grid_lines_proj: Vec> = vectors + .iter() .map(|v| { let mut features = Vec::with_capacity(768); // Horizontal lines at row 1/3 and 2/3 - let r1 = img_h / 3; // row 21 - let r2 = 2 * img_h / 3; // row 43 + let r1 = img_h / 3; // row 21 + let r2 = 2 * img_h / 3; // row 43 for &r in &[r1, r2] { for c in 0..img_w { for ch in 0..channels { @@ -927,19 +1023,24 @@ mod tests { .collect(); // ── Projection 3: 1/2 + 1/3 + 2/3 grid (6 lines × 64 × 3 = 1152D) ── - let full_grid_proj: Vec> = vectors.iter() + let full_grid_proj: Vec> = vectors + .iter() .map(|v| { let mut features = Vec::with_capacity(1152); // Horizontal: 1/3, 1/2, 2/3 for &r in &[img_h / 3, img_h / 2, 2 * img_h / 3] { for c in 0..img_w { - for ch in 0..channels { features.push(pixel(v, r, c, ch)); } + for ch in 0..channels { + features.push(pixel(v, r, c, ch)); + } } } // Vertical: 1/3, 1/2, 2/3 for &c in &[img_w / 3, img_w / 2, 2 * img_w / 3] { for r in 0..img_h { - for ch in 0..channels { features.push(pixel(v, r, c, ch)); } + for ch in 0..channels { + features.push(pixel(v, r, c, ch)); + } } } features @@ -947,12 +1048,15 @@ mod tests { .collect(); // ── Projection 4: 4 intersection points only (4 × 3 = 12D) ── - let intersections_proj: Vec> = vectors.iter() + let intersections_proj: Vec> = vectors + .iter() .map(|v| { let mut features = Vec::with_capacity(12); for &r in &[img_h / 3, 2 * img_h / 3] { for &c in &[img_w / 3, 2 * img_w / 3] { - for ch in 0..channels { features.push(pixel(v, r, c, ch)); } + for ch in 0..channels { + features.push(pixel(v, r, c, ch)); + } } } features @@ -961,18 +1065,32 @@ mod tests { // ── Ground truth pairwise distances ── let mut gt_dists = Vec::new(); - for i in 0..n { for j in (i+1)..n { - let d: f64 = vectors[i].iter().zip(&vectors[j]).map(|(a,b)| (a-b)*(a-b)).sum::().sqrt(); - gt_dists.push(d); - }} + for i in 0..n { + for j in (i + 1)..n { + let d: f64 = vectors[i] + .iter() + .zip(&vectors[j]) + .map(|(a, b)| (a - b) * (a - b)) + .sum::() + .sqrt(); + gt_dists.push(d); + } + } fn pairwise_l2(proj: &[Vec]) -> Vec { let n = proj.len(); let mut d = Vec::new(); - for i in 0..n { for j in (i+1)..n { - let dist: f64 = proj[i].iter().zip(&proj[j]).map(|(a,b)| (a-b)*(a-b)).sum::().sqrt(); - d.push(dist); - }} + for i in 0..n { + for j in (i + 1)..n { + let dist: f64 = proj[i] + .iter() + .zip(&proj[j]) + .map(|(a, b)| (a - b) * (a - b)) + .sum::() + .sqrt(); + d.push(dist); + } + } d } @@ -981,7 +1099,9 @@ mod tests { let mut idx: Vec<(usize, f64)> = v.iter().copied().enumerate().collect(); idx.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); let mut r = vec![0.0; v.len()]; - for (rank, (i, _)) in idx.into_iter().enumerate() { r[i] = rank as f64; } + for (rank, (i, _)) in idx.into_iter().enumerate() { + r[i] = rank as f64; + } r } let (ra, rb) = (ranks(a), ranks(b)); @@ -990,9 +1110,15 @@ mod tests { let (mut cov, mut va, mut vb) = (0.0, 0.0, 0.0); for i in 0..a.len() { let (da, db) = (ra[i] - ma, rb[i] - mb); - cov += da * db; va += da * da; vb += db * db; + cov += da * db; + va += da * da; + vb += db * db; + } + if va < 1e-10 || vb < 1e-10 { + 0.0 + } else { + cov / (va * 
vb).sqrt() } - if va < 1e-10 || vb < 1e-10 { 0.0 } else { cov / (va * vb).sqrt() } } let rho_golden = spearman(>_dists, &pairwise_l2(&golden_proj)); @@ -1022,7 +1148,10 @@ mod tests { // Test: can we identify which class an image belongs to via bundle similarity? let bytes = match std::fs::read("/tmp/tiny_imagenet_labeled.bin") { Ok(b) => b, - Err(_) => { eprintln!("SKIP: /tmp/tiny_imagenet_labeled.bin not found"); return; } + Err(_) => { + eprintln!("SKIP: /tmp/tiny_imagenet_labeled.bin not found"); + return; + } }; let n = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize; @@ -1033,7 +1162,7 @@ mod tests { let mut labels = Vec::with_capacity(n); for i in 0..n { let off = 12 + i * 4; - labels.push(u32::from_le_bytes([bytes[off], bytes[off+1], bytes[off+2], bytes[off+3]]) as usize); + labels.push(u32::from_le_bytes([bytes[off], bytes[off + 1], bytes[off + 2], bytes[off + 3]]) as usize); } // Read pixel vectors @@ -1043,7 +1172,7 @@ mod tests { let v: Vec = (0..d) .map(|j| { let off = pixel_start + (i * d + j) * 4; - f32::from_le_bytes([bytes[off], bytes[off+1], bytes[off+2], bytes[off+3]]) as f64 + f32::from_le_bytes([bytes[off], bytes[off + 1], bytes[off + 2], bytes[off + 3]]) as f64 }) .collect(); vectors.push(v); @@ -1055,18 +1184,25 @@ mod tests { let img_w = 64usize; let img_h = 64usize; let ch = 3usize; - let pixel = |v: &[f64], r: usize, c: usize, channel: usize| -> f64 { - v[r * img_w * ch + c * ch + channel] - }; + let pixel = |v: &[f64], r: usize, c: usize, channel: usize| -> f64 { v[r * img_w * ch + c * ch + channel] }; - let features: Vec> = vectors.iter() + let features: Vec> = vectors + .iter() .map(|v| { let mut f = Vec::with_capacity(768); for &r in &[img_h / 3, 2 * img_h / 3] { - for c in 0..img_w { for channel in 0..ch { f.push(pixel(v, r, c, channel)); } } + for c in 0..img_w { + for channel in 0..ch { + f.push(pixel(v, r, c, channel)); + } + } } for &c in &[img_w / 3, 2 * img_w / 3] { - for r in 0..img_h { for channel in 0..ch { f.push(pixel(v, r, c, channel)); } } + for r in 0..img_h { + for channel in 0..ch { + f.push(pixel(v, r, c, channel)); + } + } } f }) @@ -1098,8 +1234,12 @@ mod tests { let mut best_class = 0; let mut best_dist = f64::MAX; for c in 0..n_classes { - if class_counts[c] == 0 { continue; } - let dist: f64 = features[i].iter().zip(&heel_archetypes[c]) + if class_counts[c] == 0 { + continue; + } + let dist: f64 = features[i] + .iter() + .zip(&heel_archetypes[c]) .map(|(a, b)| (a - b) * (a - b)) .sum::() .sqrt(); @@ -1127,10 +1267,16 @@ mod tests { for oct in 0..n_oct { for bi in 0..base_dim { let dim = oct * base_dim + ((bi * golden_step) % base_dim); - if dim < fd { sum[bi] += v[dim]; cnt[bi] += 1; } + if dim < fd { + sum[bi] += v[dim]; + cnt[bi] += 1; + } } } - sum.iter().zip(&cnt).map(|(&s, &c)| if c > 0 { s / c as f64 } else { 0.0 }).collect() + sum.iter() + .zip(&cnt) + .map(|(&s, &c)| if c > 0 { s / c as f64 } else { 0.0 }) + .collect() }; let compressed_archetypes: Vec> = heel_archetypes.iter().map(|a| compress(a)).collect(); @@ -1141,14 +1287,23 @@ mod tests { let mut best_class = 0; let mut best_dist = f64::MAX; for c in 0..n_classes { - if class_counts[c] == 0 { continue; } - let dist: f64 = compressed_features[i].iter().zip(&compressed_archetypes[c]) + if class_counts[c] == 0 { + continue; + } + let dist: f64 = compressed_features[i] + .iter() + .zip(&compressed_archetypes[c]) .map(|(a, b)| (a - b) * (a - b)) .sum::() .sqrt(); - if dist < best_dist { best_dist = dist; best_class = c; } + if dist < best_dist { + best_dist 
= dist; + best_class = c; + } + } + if best_class == true_label { + correct_compressed += 1; } - if best_class == true_label { correct_compressed += 1; } } let accuracy_compressed = correct_compressed as f64 / total as f64; @@ -1157,12 +1312,18 @@ mod tests { for (i, _) in labels.iter().enumerate() { let mut min_dist = f64::MAX; for c in 0..n_classes { - if class_counts[c] == 0 { continue; } - let dist: f64 = features[i].iter().zip(&heel_archetypes[c]) + if class_counts[c] == 0 { + continue; + } + let dist: f64 = features[i] + .iter() + .zip(&heel_archetypes[c]) .map(|(a, b)| (a - b) * (a - b)) .sum::() .sqrt(); - if dist < min_dist { min_dist = dist; } + if dist < min_dist { + min_dist = dist; + } } max_distances.push((i, min_dist)); } @@ -1174,7 +1335,12 @@ mod tests { eprintln!(" Grid-line features (768D):"); eprintln!(" HEEL accuracy: {:.1}% ({}/{})", accuracy * 100.0, correct, total); eprintln!(" Golden-step compressed (17D = 34 bytes):"); - eprintln!(" Compressed accuracy: {:.1}% ({}/{})", accuracy_compressed * 100.0, correct_compressed, total); + eprintln!( + " Compressed accuracy: {:.1}% ({}/{})", + accuracy_compressed * 100.0, + correct_compressed, + total + ); eprintln!(" Accuracy loss from compression: {:.1}%", (accuracy - accuracy_compressed) * 100.0); eprintln!(" Top-10 outliers (CHAODA candidates):"); for (idx, dist) in max_distances.iter().take(5) { @@ -1193,10 +1359,13 @@ mod tests { // by unbinding one class archetype and checking residual against others. // // Bird/fence scenario: if unbind(image, bird) correlates with fence → both present. - + let bytes = match std::fs::read("/tmp/tiny_imagenet_labeled.bin") { Ok(b) => b, - Err(_) => { eprintln!("SKIP: /tmp/tiny_imagenet_labeled.bin not found"); return; } + Err(_) => { + eprintln!("SKIP: /tmp/tiny_imagenet_labeled.bin not found"); + return; + } }; let n = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize; @@ -1206,7 +1375,7 @@ mod tests { let mut labels = Vec::with_capacity(n); for i in 0..n { let off = 12 + i * 4; - labels.push(u32::from_le_bytes([bytes[off], bytes[off+1], bytes[off+2], bytes[off+3]]) as usize); + labels.push(u32::from_le_bytes([bytes[off], bytes[off + 1], bytes[off + 2], bytes[off + 3]]) as usize); } let pixel_start = 12 + n * 4; @@ -1215,21 +1384,31 @@ mod tests { let ch = 3usize; // Extract grid-line features (768D) - let features: Vec> = (0..n).map(|i| { - let v_start = pixel_start + i * d * 4; - let pixel = |r: usize, c: usize, channel: usize| -> f64 { - let off = v_start + (r * img_w * ch + c * ch + channel) * 4; - f32::from_le_bytes([bytes[off], bytes[off+1], bytes[off+2], bytes[off+3]]) as f64 - }; - let mut f = Vec::with_capacity(768); - for &r in &[img_h / 3, 2 * img_h / 3] { - for c in 0..img_w { for channel in 0..ch { f.push(pixel(r, c, channel)); } } - } - for &c in &[img_w / 3, 2 * img_w / 3] { - for r in 0..img_h { for channel in 0..ch { f.push(pixel(r, c, channel)); } } - } - f - }).collect(); + let features: Vec> = (0..n) + .map(|i| { + let v_start = pixel_start + i * d * 4; + let pixel = |r: usize, c: usize, channel: usize| -> f64 { + let off = v_start + (r * img_w * ch + c * ch + channel) * 4; + f32::from_le_bytes([bytes[off], bytes[off + 1], bytes[off + 2], bytes[off + 3]]) as f64 + }; + let mut f = Vec::with_capacity(768); + for &r in &[img_h / 3, 2 * img_h / 3] { + for c in 0..img_w { + for channel in 0..ch { + f.push(pixel(r, c, channel)); + } + } + } + for &c in &[img_w / 3, 2 * img_w / 3] { + for r in 0..img_h { + for channel in 0..ch { + f.push(pixel(r, c, 
channel)); + } + } + } + f + }) + .collect(); let feat_d = features[0].len(); @@ -1237,12 +1416,16 @@ mod tests { let mut archetypes: Vec> = vec![vec![0.0; feat_d]; n_classes]; let mut counts = vec![0usize; n_classes]; for (i, &label) in labels.iter().enumerate() { - for j in 0..feat_d { archetypes[label][j] += features[i][j]; } + for j in 0..feat_d { + archetypes[label][j] += features[i][j]; + } counts[label] += 1; } for c in 0..n_classes { if counts[c] > 0 { - for j in 0..feat_d { archetypes[c][j] /= counts[c] as f64; } + for j in 0..feat_d { + archetypes[c][j] /= counts[c] as f64; + } } } @@ -1251,20 +1434,28 @@ mod tests { let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum(); let mag_a: f64 = a.iter().map(|x| x * x).sum::().sqrt(); let mag_b: f64 = b.iter().map(|x| x * x).sum::().sqrt(); - if mag_a < 1e-10 || mag_b < 1e-10 { 0.0 } else { dot / (mag_a * mag_b) } + if mag_a < 1e-10 || mag_b < 1e-10 { + 0.0 + } else { + dot / (mag_a * mag_b) + } }; // ── HIP: within-class variance (how spread is each class?) ── let mut hip_variance = vec![0.0f64; n_classes]; for (i, &label) in labels.iter().enumerate() { - let dist: f64 = features[i].iter().zip(&archetypes[label]) + let dist: f64 = features[i] + .iter() + .zip(&archetypes[label]) .map(|(a, b)| (a - b) * (a - b)) .sum::() .sqrt(); hip_variance[label] += dist; } for c in 0..n_classes { - if counts[c] > 0 { hip_variance[c] /= counts[c] as f64; } + if counts[c] > 0 { + hip_variance[c] /= counts[c] as f64; + } } // ── Multi-object simulation: "subtract" one class, check residual ── @@ -1275,7 +1466,9 @@ mod tests { let mut multi_object_candidates = Vec::new(); for (i, &true_label) in labels.iter().enumerate() { // Subtract the true class archetype (simulates "removing" the primary object) - let residual: Vec = features[i].iter().zip(&archetypes[true_label]) + let residual: Vec = features[i] + .iter() + .zip(&archetypes[true_label]) .map(|(a, b)| a - b) .collect(); @@ -1283,7 +1476,9 @@ mod tests { let mut best_other_class = 0; let mut best_other_sim = f64::NEG_INFINITY; for c in 0..n_classes { - if c == true_label || counts[c] == 0 { continue; } + if c == true_label || counts[c] == 0 { + continue; + } let sim = cosine(&residual, &archetypes[c]); if sim > best_other_sim { best_other_sim = sim; @@ -1301,7 +1496,11 @@ mod tests { // features — what's left after removing the primary class IS the secondary class. 
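// Illustrative sketch (not part of this patch): the residual "unbind" check the
// multi-object simulation above performs — subtract the primary class archetype
// from the image features, then ask which *other* archetype the residual most
// resembles by cosine similarity. Function names are local to this sketch; the
// cosine and residual logic mirror the test code above.
fn cosine(a: &[f64], b: &[f64]) -> f64 {
    let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let mag_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
    let mag_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
    if mag_a < 1e-10 || mag_b < 1e-10 { 0.0 } else { dot / (mag_a * mag_b) }
}

// Returns (best_other_class, similarity) of the residual after removing the
// primary class archetype, skipping the primary class and empty classes.
fn residual_secondary_class(
    features: &[f64],
    archetypes: &[Vec<f64>],
    counts: &[usize],
    primary: usize,
) -> Option<(usize, f64)> {
    let residual: Vec<f64> = features
        .iter()
        .zip(&archetypes[primary])
        .map(|(a, b)| a - b)
        .collect();
    let mut best: Option<(usize, f64)> = None;
    for (c, arch) in archetypes.iter().enumerate() {
        if c == primary || counts[c] == 0 {
            continue;
        }
        let sim = cosine(&residual, arch);
        if best.map_or(true, |(_, s)| sim > s) {
            best = Some((c, sim));
        }
    }
    best
}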
let mut pair_counts: std::collections::HashMap<(usize, usize), usize> = std::collections::HashMap::new(); for &(_, primary, secondary, _) in &multi_object_candidates { - let key = if primary < secondary { (primary, secondary) } else { (secondary, primary) }; + let key = if primary < secondary { + (primary, secondary) + } else { + (secondary, primary) + }; *pair_counts.entry(key).or_insert(0) += 1; } @@ -1309,13 +1508,17 @@ mod tests { // (far from primary AND residual doesn't match secondary) let mut outliers = Vec::new(); for (i, &true_label) in labels.iter().enumerate() { - let primary_dist: f64 = features[i].iter().zip(&archetypes[true_label]) + let primary_dist: f64 = features[i] + .iter() + .zip(&archetypes[true_label]) .map(|(a, b)| (a - b) * (a - b)) .sum::() .sqrt(); - + // If far from own class AND not detected as multi-object - let is_multi = multi_object_candidates.iter().any(|&(idx, _, _, _)| idx == i); + let is_multi = multi_object_candidates + .iter() + .any(|&(idx, _, _, _)| idx == i); if primary_dist > hip_variance[true_label] * 2.0 && !is_multi { outliers.push((i, true_label, primary_dist)); } @@ -1332,8 +1535,13 @@ mod tests { } eprintln!(" CHAODA outliers (far from all archetypes): {}", outliers.len()); for (idx, label, dist) in outliers.iter().take(3) { - eprintln!(" image {} (class {}): dist={:.3} (>{:.3} threshold)", - idx, label, dist, hip_variance[*label] * 2.0); + eprintln!( + " image {} (class {}): dist={:.3} (>{:.3} threshold)", + idx, + label, + dist, + hip_variance[*label] * 2.0 + ); } eprintln!(" Per-class HIP spread (intra-class variance):"); for c in 0..n_classes { @@ -1351,10 +1559,13 @@ mod tests { // 1. Find energy centroid around each 1/3 intersection // 2. Extract detailed patch at centroid // 3. Classify patch → more precise than whole-image archetype - + let bytes = match std::fs::read("/tmp/tiny_imagenet_labeled.bin") { Ok(b) => b, - Err(_) => { eprintln!("SKIP: /tmp/tiny_imagenet_labeled.bin not found"); return; } + Err(_) => { + eprintln!("SKIP: /tmp/tiny_imagenet_labeled.bin not found"); + return; + } }; let n = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize; @@ -1364,7 +1575,7 @@ mod tests { let mut labels = Vec::with_capacity(n); for i in 0..n { let off = 12 + i * 4; - labels.push(u32::from_le_bytes([bytes[off], bytes[off+1], bytes[off+2], bytes[off+3]]) as usize); + labels.push(u32::from_le_bytes([bytes[off], bytes[off + 1], bytes[off + 2], bytes[off + 3]]) as usize); } let pixel_start = 12 + n * 4; @@ -1375,7 +1586,7 @@ mod tests { // ── Helper: read pixel from binary ── let pixel = |img_idx: usize, r: usize, c: usize, channel: usize| -> f64 { let off = pixel_start + (img_idx * d + r * img_w * ch + c * ch + channel) * 4; - f32::from_le_bytes([bytes[off], bytes[off+1], bytes[off+2], bytes[off+3]]) as f64 + f32::from_le_bytes([bytes[off], bytes[off + 1], bytes[off + 2], bytes[off + 3]]) as f64 }; // ── Helper: luminance at (r,c) ── @@ -1385,8 +1596,12 @@ mod tests { // ── Step 1: For each image, find energy centroid around each 1/3 intersection ── let patch_radius = 8usize; // 16×16 patch around each intersection - let intersections = [(img_h/3, img_w/3), (img_h/3, 2*img_w/3), - (2*img_h/3, img_w/3), (2*img_h/3, 2*img_w/3)]; + let intersections = [ + (img_h / 3, img_w / 3), + (img_h / 3, 2 * img_w / 3), + (2 * img_h / 3, img_w / 3), + (2 * img_h / 3, 2 * img_w / 3), + ]; struct FocusPoint { centroid_r: f64, @@ -1400,7 +1615,11 @@ mod tests { let mut focus_features: Vec> = Vec::with_capacity(n_use); for img_idx in 0..n_use 
{ - let mut best_focus = FocusPoint { centroid_r: 32.0, centroid_c: 32.0, energy: 0.0 }; + let mut best_focus = FocusPoint { + centroid_r: 32.0, + centroid_c: 32.0, + energy: 0.0, + }; for &(ir, ic) in &intersections { // Compute energy centroid within patch @@ -1417,9 +1636,9 @@ mod tests { for c in c_min..c_max { let e = luma(img_idx, r, c); // Use gradient magnitude as energy (edges = objects) - let grad = if r > 0 && r < img_h-1 && c > 0 && c < img_w-1 { - let dx = luma(img_idx, r, c+1) - luma(img_idx, r, c-1); - let dy = luma(img_idx, r+1, c) - luma(img_idx, r-1, c); + let grad = if r > 0 && r < img_h - 1 && c > 0 && c < img_w - 1 { + let dx = luma(img_idx, r, c + 1) - luma(img_idx, r, c - 1); + let dy = luma(img_idx, r + 1, c) - luma(img_idx, r - 1, c); (dx * dx + dy * dy).sqrt() } else { 0.0 @@ -1432,8 +1651,16 @@ mod tests { if total_energy > best_focus.energy { best_focus = FocusPoint { - centroid_r: if total_energy > 0.0 { weighted_r / total_energy } else { ir as f64 }, - centroid_c: if total_energy > 0.0 { weighted_c / total_energy } else { ic as f64 }, + centroid_r: if total_energy > 0.0 { + weighted_r / total_energy + } else { + ir as f64 + }, + centroid_c: if total_energy > 0.0 { + weighted_c / total_energy + } else { + ic as f64 + }, energy: total_energy, }; } @@ -1467,11 +1694,17 @@ mod tests { let mut focus_archetypes: Vec> = vec![vec![0.0; feat_d]; n_classes]; let mut counts = vec![0usize; n_classes]; for (i, &label) in labels[..n_use].iter().enumerate() { - for j in 0..feat_d { focus_archetypes[label][j] += focus_features[i][j]; } + for j in 0..feat_d { + focus_archetypes[label][j] += focus_features[i][j]; + } counts[label] += 1; } for c in 0..n_classes { - if counts[c] > 0 { for j in 0..feat_d { focus_archetypes[c][j] /= counts[c] as f64; } } + if counts[c] > 0 { + for j in 0..feat_d { + focus_archetypes[c][j] /= counts[c] as f64; + } + } } // ── Step 4: Classify by nearest centroid-patch archetype ── @@ -1480,12 +1713,23 @@ mod tests { let mut best_class = 0; let mut best_dist = f64::MAX; for c in 0..n_classes { - if counts[c] == 0 { continue; } - let dist: f64 = focus_features[i].iter().zip(&focus_archetypes[c]) - .map(|(a, b)| (a - b) * (a - b)).sum::().sqrt(); - if dist < best_dist { best_dist = dist; best_class = c; } + if counts[c] == 0 { + continue; + } + let dist: f64 = focus_features[i] + .iter() + .zip(&focus_archetypes[c]) + .map(|(a, b)| (a - b) * (a - b)) + .sum::() + .sqrt(); + if dist < best_dist { + best_dist = dist; + best_class = c; + } + } + if best_class == true_label { + correct_focus += 1; } - if best_class == true_label { correct_focus += 1; } } let accuracy_focus = correct_focus as f64 / n_use as f64; @@ -1500,10 +1744,16 @@ mod tests { for oct in 0..n_oct { for bi in 0..base_dim { let dim = oct * base_dim + ((bi * golden_step) % base_dim); - if dim < fd { sum[bi] += v[dim]; cnt[bi] += 1; } + if dim < fd { + sum[bi] += v[dim]; + cnt[bi] += 1; + } } } - sum.iter().zip(&cnt).map(|(&s, &c)| if c > 0 { s / c as f64 } else { 0.0 }).collect() + sum.iter() + .zip(&cnt) + .map(|(&s, &c)| if c > 0 { s / c as f64 } else { 0.0 }) + .collect() }; let compressed_arch: Vec> = focus_archetypes.iter().map(|a| compress(a)).collect(); @@ -1514,12 +1764,23 @@ mod tests { let mut best_class = 0; let mut best_dist = f64::MAX; for c in 0..n_classes { - if counts[c] == 0 { continue; } - let dist: f64 = compressed_feat[i].iter().zip(&compressed_arch[c]) - .map(|(a, b)| (a - b) * (a - b)).sum::().sqrt(); - if dist < best_dist { best_dist = dist; best_class = c; } + if 
counts[c] == 0 { + continue; + } + let dist: f64 = compressed_feat[i] + .iter() + .zip(&compressed_arch[c]) + .map(|(a, b)| (a - b) * (a - b)) + .sum::() + .sqrt(); + if dist < best_dist { + best_dist = dist; + best_class = c; + } + } + if best_class == true_label { + correct_compressed += 1; } - if best_class == true_label { correct_compressed += 1; } } let accuracy_compressed = correct_compressed as f64 / n_use as f64; @@ -1544,10 +1805,13 @@ mod tests { // Multiple scan strategies with NARS evidence revision. // Each scan is independent evidence. Revision increases confidence. // Stop when confidence > threshold (elevation cascade). - + let bytes = match std::fs::read("/tmp/tiny_imagenet_labeled.bin") { Ok(b) => b, - Err(_) => { eprintln!("SKIP: /tmp/tiny_imagenet_labeled.bin not found"); return; } + Err(_) => { + eprintln!("SKIP: /tmp/tiny_imagenet_labeled.bin not found"); + return; + } }; let n = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize; @@ -1556,15 +1820,17 @@ mod tests { let mut labels = Vec::with_capacity(n); for i in 0..n { let off = 12 + i * 4; - labels.push(u32::from_le_bytes([bytes[off], bytes[off+1], bytes[off+2], bytes[off+3]]) as usize); + labels.push(u32::from_le_bytes([bytes[off], bytes[off + 1], bytes[off + 2], bytes[off + 3]]) as usize); } let pixel_start = 12 + n * 4; - let img_w = 64usize; let img_h = 64usize; let ch = 3usize; + let img_w = 64usize; + let img_h = 64usize; + let ch = 3usize; let n_use = n.min(200); let pixel = |img: usize, r: usize, c: usize, channel: usize| -> f64 { let off = pixel_start + (img * d + r * img_w * ch + c * ch + channel) * 4; - f32::from_le_bytes([bytes[off], bytes[off+1], bytes[off+2], bytes[off+3]]) as f64 + f32::from_le_bytes([bytes[off], bytes[off + 1], bytes[off + 2], bytes[off + 3]]) as f64 }; let luma = |img: usize, r: usize, c: usize| -> f64 { 0.299 * pixel(img, r, c, 0) + 0.587 * pixel(img, r, c, 1) + 0.114 * pixel(img, r, c, 2) @@ -1581,17 +1847,22 @@ mod tests { } // ── Scan strategy: extract features from a region ── - fn extract_patch(pixel_fn: &dyn Fn(usize, usize, usize) -> f64, - r_center: usize, c_center: usize, radius: usize, - img_h: usize, img_w: usize, ch: usize) -> Vec { + fn extract_patch( + pixel_fn: &dyn Fn(usize, usize, usize) -> f64, r_center: usize, c_center: usize, radius: usize, + img_h: usize, img_w: usize, ch: usize, + ) -> Vec { let mut f = Vec::new(); let r0 = r_center.saturating_sub(radius); let r1 = (r_center + radius).min(img_h); let c0 = c_center.saturating_sub(radius); let c1 = (c_center + radius).min(img_w); - for r in r0..r1 { for c in c0..c1 { for channel in 0..ch { - f.push(pixel_fn(r, c, channel)); - }}} + for r in r0..r1 { + for c in c0..c1 { + for channel in 0..ch { + f.push(pixel_fn(r, c, channel)); + } + } + } f } @@ -1611,11 +1882,11 @@ mod tests { // Build per-class archetypes for each strategy, then score let intersections = [ - ("NW patch", img_h/3, img_w/3, 4usize), - ("NE patch", img_h/3, 2*img_w/3, 4), - ("SW patch", 2*img_h/3, img_w/3, 4), - ("SE patch", 2*img_h/3, 2*img_w/3, 4), - ("Center", img_h/2, img_w/2, 6), + ("NW patch", img_h / 3, img_w / 3, 4usize), + ("NE patch", img_h / 3, 2 * img_w / 3, 4), + ("SW patch", 2 * img_h / 3, img_w / 3, 4), + ("SE patch", 2 * img_h / 3, 2 * img_w / 3, 4), + ("Center", img_h / 2, img_w / 2, 6), ]; // For each strategy, build archetypes and classify @@ -1623,23 +1894,33 @@ mod tests { for &(name, cr, cc, radius) in &intersections { // Extract features for all images - let features: Vec> = (0..n_use).map(|img| { - let 
p = |r: usize, c: usize, channel: usize| pixel(img, r, c, channel); - extract_patch(&p, cr, cc, radius, img_h, img_w, ch) - }).collect(); + let features: Vec> = (0..n_use) + .map(|img| { + let p = |r: usize, c: usize, channel: usize| pixel(img, r, c, channel); + extract_patch(&p, cr, cc, radius, img_h, img_w, ch) + }) + .collect(); - if features[0].is_empty() { continue; } + if features[0].is_empty() { + continue; + } let fd = features[0].len(); // Build archetypes let mut arch = vec![vec![0.0; fd]; n_classes]; let mut cnt = vec![0usize; n_classes]; for (i, &l) in labels[..n_use].iter().enumerate() { - for j in 0..fd { arch[l][j] += features[i][j]; } + for j in 0..fd { + arch[l][j] += features[i][j]; + } cnt[l] += 1; } for c in 0..n_classes { - if cnt[c] > 0 { for j in 0..fd { arch[c][j] /= cnt[c] as f64; } } + if cnt[c] > 0 { + for j in 0..fd { + arch[c][j] /= cnt[c] as f64; + } + } } // Score each image @@ -1647,11 +1928,20 @@ mod tests { let mut best_c = 0; let mut best_sim = f64::NEG_INFINITY; for c in 0..n_classes { - if cnt[c] == 0 { continue; } - let dist: f64 = features[i].iter().zip(&arch[c]) - .map(|(a, b)| (a-b)*(a-b)).sum::().sqrt(); + if cnt[c] == 0 { + continue; + } + let dist: f64 = features[i] + .iter() + .zip(&arch[c]) + .map(|(a, b)| (a - b) * (a - b)) + .sum::() + .sqrt(); let sim = 1.0 / (1.0 + dist); // convert distance to similarity - if sim > best_sim { best_sim = sim; best_c = c; } + if sim > best_sim { + best_sim = sim; + best_c = c; + } } strategy_scores[i].push((best_c, best_sim)); } @@ -1666,14 +1956,16 @@ mod tests { // Single strategy accuracies for (s, &(pred_class, _)) in strategy_scores[i].iter().enumerate() { - if pred_class == true_label { correct_single[s] += 1; } + if pred_class == true_label { + correct_single[s] += 1; + } } // NARS revision: accumulate weighted evidence across all strategies. // Each scan contributes its similarity as evidence weight for the class it detected. // Confidence grows with number of agreeing scans (NARS: more evidence = more confident). 
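// Illustrative sketch (not part of this patch) of the revision rule the comment
// above describes and the code below implements: every scan contributes a
// (predicted class, similarity) pair, and the revised decision maximises
// avg_similarity × vote_fraction, so a class wins when several scans agree on
// it with high similarity. Names are local to this sketch.
fn revise_scans(scan_results: &[(usize, f64)], n_classes: usize) -> Option<usize> {
    let total_scans = scan_results.len() as f64;
    let mut evidence = vec![0.0f64; n_classes]; // summed similarity per class
    let mut votes = vec![0usize; n_classes]; // how many scans picked the class
    for &(class, similarity) in scan_results {
        evidence[class] += similarity;
        votes[class] += 1;
    }
    let mut best: Option<(usize, f64)> = None;
    for c in 0..n_classes {
        if votes[c] == 0 {
            continue;
        }
        let avg_sim = evidence[c] / votes[c] as f64; // how similar ("frequency")
        let vote_frac = votes[c] as f64 / total_scans; // how many agree ("confidence")
        let score = avg_sim * vote_frac;
        if best.map_or(true, |(_, s)| score > s) {
            best = Some((c, score));
        }
    }
    best.map(|(c, _)| c)
}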
let mut class_evidence: Vec = vec![0.0; n_classes]; // total similarity weight - let mut class_votes: Vec = vec![0; n_classes]; // vote count + let mut class_votes: Vec = vec![0; n_classes]; // vote count for &(pred_class, similarity) in &strategy_scores[i] { class_evidence[pred_class] += similarity; @@ -1686,14 +1978,21 @@ mod tests { let mut best_c = 0; let mut best_score = f64::NEG_INFINITY; for c in 0..n_classes { - if class_votes[c] == 0 { continue; } + if class_votes[c] == 0 { + continue; + } let avg_sim = class_evidence[c] / class_votes[c] as f64; let vote_frac = class_votes[c] as f64 / total_scans; // Combined: how similar (frequency) × how many agree (confidence) let score = avg_sim * vote_frac; - if score > best_score { best_score = score; best_c = c; } + if score > best_score { + best_score = score; + best_c = c; + } + } + if best_c == true_label { + correct_revised += 1; } - if best_c == true_label { correct_revised += 1; } } let revised_accuracy = correct_revised as f64 / n_use as f64; @@ -1707,8 +2006,12 @@ mod tests { eprintln!(" {}: {:.1}% ({}/{})", name, acc * 100.0, correct_single[s], n_use); } eprintln!(); - eprintln!(" NARS-revised (all strategies combined): {:.1}% ({}/{})", - revised_accuracy * 100.0, correct_revised, n_use); + eprintln!( + " NARS-revised (all strategies combined): {:.1}% ({}/{})", + revised_accuracy * 100.0, + correct_revised, + n_use + ); eprintln!(" Random baseline: {:.1}%", 100.0 / n_classes as f64); eprintln!(); let best_single = correct_single.iter().max().unwrap(); @@ -1717,19 +2020,25 @@ mod tests { eprintln!(" Improvement over best single scan: {:.1}%", improvement * 100.0); eprintln!(" This is NARS evidence accumulation — each scan adds confidence."); - assert!(revised_accuracy > best_single_acc, + assert!( + revised_accuracy > best_single_acc, "NARS revision should improve over best single: {:.1}% vs {:.1}%", - revised_accuracy * 100.0, best_single_acc * 100.0); + revised_accuracy * 100.0, + best_single_acc * 100.0 + ); } #[test] #[ignore] fn test_hotspot_8x8_grid_bundling() { // 8×8 grid of 8×8 cells. For each 1/3 intersection, find the 4 hottest // neighboring cells (by gradient energy), bundle their features. 
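// Illustrative sketch (not part of this patch) of the per-cell "heat" measure
// this test uses below: the sum of central-difference gradient magnitudes over
// one grid cell, so cells containing edges (objects) score higher than flat
// background. `luma` is assumed to return luminance at (row, col); the function
// name is local to this sketch.
fn cell_energy(
    luma: &dyn Fn(usize, usize) -> f64,
    grid_r: usize,
    grid_c: usize,
    cell_size: usize,
    img_h: usize,
    img_w: usize,
) -> f64 {
    let (r0, c0) = (grid_r * cell_size, grid_c * cell_size);
    let mut energy = 0.0;
    for r in r0..(r0 + cell_size).min(img_h) {
        for c in c0..(c0 + cell_size).min(img_w) {
            // Skip the image border, where central differences are undefined.
            if r > 0 && r < img_h - 1 && c > 0 && c < img_w - 1 {
                let dx = luma(r, c + 1) - luma(r, c - 1);
                let dy = luma(r + 1, c) - luma(r - 1, c);
                energy += (dx * dx + dy * dy).sqrt();
            }
        }
    }
    energy
}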
- + let bytes = match std::fs::read("/tmp/tiny_imagenet_labeled.bin") { Ok(b) => b, - Err(_) => { eprintln!("SKIP: /tmp/tiny_imagenet_labeled.bin not found"); return; } + Err(_) => { + eprintln!("SKIP: /tmp/tiny_imagenet_labeled.bin not found"); + return; + } }; let n = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize; @@ -1738,15 +2047,17 @@ mod tests { let mut labels = Vec::with_capacity(n); for i in 0..n { let off = 12 + i * 4; - labels.push(u32::from_le_bytes([bytes[off], bytes[off+1], bytes[off+2], bytes[off+3]]) as usize); + labels.push(u32::from_le_bytes([bytes[off], bytes[off + 1], bytes[off + 2], bytes[off + 3]]) as usize); } let pixel_start = 12 + n * 4; - let img_w = 64usize; let img_h = 64usize; let ch = 3usize; + let img_w = 64usize; + let img_h = 64usize; + let ch = 3usize; let n_use = n.min(200); let pixel = |img: usize, r: usize, c: usize, channel: usize| -> f64 { let off = pixel_start + (img * d + r * img_w * ch + c * ch + channel) * 4; - f32::from_le_bytes([bytes[off], bytes[off+1], bytes[off+2], bytes[off+3]]) as f64 + f32::from_le_bytes([bytes[off], bytes[off + 1], bytes[off + 2], bytes[off + 3]]) as f64 }; let luma = |img: usize, r: usize, c: usize| -> f64 { 0.299 * pixel(img, r, c, 0) + 0.587 * pixel(img, r, c, 1) + 0.114 * pixel(img, r, c, 2) @@ -1759,9 +2070,9 @@ mod tests { // 1/3 intersections in cell coordinates let intersections_cell = [ - (grid_h / 3, grid_w / 3), // ~(2,2) - (grid_h / 3, 2 * grid_w / 3), // ~(2,5) - (2 * grid_h / 3, grid_w / 3), // ~(5,2) + (grid_h / 3, grid_w / 3), // ~(2,2) + (grid_h / 3, 2 * grid_w / 3), // ~(2,5) + (2 * grid_h / 3, grid_w / 3), // ~(5,2) (2 * grid_h / 3, 2 * grid_w / 3), // ~(5,5) ]; @@ -1780,10 +2091,10 @@ mod tests { let c0 = gc * cell_size; for r in r0..(r0 + cell_size) { for c in c0..(c0 + cell_size) { - if r > 0 && r < img_h-1 && c > 0 && c < img_w-1 { - let dx = luma(img, r, c+1) - luma(img, r, c.saturating_sub(1)); - let dy = luma(img, r+1, c) - luma(img, r.saturating_sub(1), c); - energy += (dx*dx + dy*dy).sqrt(); + if r > 0 && r < img_h - 1 && c > 0 && c < img_w - 1 { + let dx = luma(img, r, c + 1) - luma(img, r, c.saturating_sub(1)); + let dy = luma(img, r + 1, c) - luma(img, r.saturating_sub(1), c); + energy += (dx * dx + dy * dy).sqrt(); } } } @@ -1821,7 +2132,9 @@ mod tests { } } // Normalize bundle (mean of 4 cells) - for v in bundle.iter_mut() { *v /= 4.0; } + for v in bundle.iter_mut() { + *v /= 4.0; + } img_features.extend_from_slice(&bundle); } @@ -1834,11 +2147,17 @@ mod tests { let mut archetypes = vec![vec![0.0; feat_d]; n_classes]; let mut counts = vec![0usize; n_classes]; for (i, &l) in labels[..n_use].iter().enumerate() { - for j in 0..feat_d { archetypes[l][j] += features[i][j]; } + for j in 0..feat_d { + archetypes[l][j] += features[i][j]; + } counts[l] += 1; } for c in 0..n_classes { - if counts[c] > 0 { for j in 0..feat_d { archetypes[c][j] /= counts[c] as f64; } } + if counts[c] > 0 { + for j in 0..feat_d { + archetypes[c][j] /= counts[c] as f64; + } + } } let mut correct = 0usize; @@ -1846,17 +2165,29 @@ mod tests { let mut best_c = 0; let mut best_d = f64::MAX; for c in 0..n_classes { - if counts[c] == 0 { continue; } - let dist: f64 = features[i].iter().zip(&archetypes[c]) - .map(|(a, b)| (a-b)*(a-b)).sum::().sqrt(); - if dist < best_d { best_d = dist; best_c = c; } + if counts[c] == 0 { + continue; + } + let dist: f64 = features[i] + .iter() + .zip(&archetypes[c]) + .map(|(a, b)| (a - b) * (a - b)) + .sum::() + .sqrt(); + if dist < best_d { + best_d = dist; + best_c = 
c; + } + } + if best_c == true_label { + correct += 1; } - if best_c == true_label { correct += 1; } } let accuracy = correct as f64 / n_use as f64; // Also: golden-step compressed - let base_dim = 17; let golden_step = 11; + let base_dim = 17; + let golden_step = 11; let compress = |v: &[f64]| -> Vec { let fd = v.len(); let n_oct = (fd + base_dim - 1) / base_dim; @@ -1865,22 +2196,41 @@ mod tests { for oct in 0..n_oct { for bi in 0..base_dim { let dim = oct * base_dim + ((bi * golden_step) % base_dim); - if dim < fd { sum[bi] += v[dim]; cnt[bi] += 1; } + if dim < fd { + sum[bi] += v[dim]; + cnt[bi] += 1; + } } } - sum.iter().zip(&cnt).map(|(&s, &c)| if c > 0 { s / c as f64 } else { 0.0 }).collect() + sum.iter() + .zip(&cnt) + .map(|(&s, &c)| if c > 0 { s / c as f64 } else { 0.0 }) + .collect() }; let c_arch: Vec> = archetypes.iter().map(|a| compress(a)).collect(); let c_feat: Vec> = features.iter().map(|f| compress(f)).collect(); let mut correct_c = 0; for (i, &tl) in labels[..n_use].iter().enumerate() { - let mut best_c = 0; let mut best_d = f64::MAX; + let mut best_c = 0; + let mut best_d = f64::MAX; for c in 0..n_classes { - if counts[c] == 0 { continue; } - let dist: f64 = c_feat[i].iter().zip(&c_arch[c]).map(|(a,b)|(a-b)*(a-b)).sum::().sqrt(); - if dist < best_d { best_d = dist; best_c = c; } + if counts[c] == 0 { + continue; + } + let dist: f64 = c_feat[i] + .iter() + .zip(&c_arch[c]) + .map(|(a, b)| (a - b) * (a - b)) + .sum::() + .sqrt(); + if dist < best_d { + best_d = dist; + best_c = c; + } + } + if best_c == tl { + correct_c += 1; } - if best_c == tl { correct_c += 1; } } let acc_c = correct_c as f64 / n_use as f64; @@ -1909,7 +2259,10 @@ mod tests { let bytes = match std::fs::read("/tmp/tiny_imagenet_labeled.bin") { Ok(b) => b, - Err(_) => { eprintln!("SKIP: /tmp/tiny_imagenet_labeled.bin not found"); return; } + Err(_) => { + eprintln!("SKIP: /tmp/tiny_imagenet_labeled.bin not found"); + return; + } }; let n = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize; @@ -1918,15 +2271,17 @@ mod tests { let mut labels = Vec::with_capacity(n); for i in 0..n { let off = 12 + i * 4; - labels.push(u32::from_le_bytes([bytes[off], bytes[off+1], bytes[off+2], bytes[off+3]]) as usize); + labels.push(u32::from_le_bytes([bytes[off], bytes[off + 1], bytes[off + 2], bytes[off + 3]]) as usize); } let pixel_start = 12 + n * 4; - let img_w = 64usize; let img_h = 64usize; let ch = 3usize; + let img_w = 64usize; + let img_h = 64usize; + let ch = 3usize; let n_use = n.min(200); let pixel = |img: usize, r: usize, c: usize, channel: usize| -> f64 { let off = pixel_start + (img * d + r * img_w * ch + c * ch + channel) * 4; - f32::from_le_bytes([bytes[off], bytes[off+1], bytes[off+2], bytes[off+3]]) as f64 + f32::from_le_bytes([bytes[off], bytes[off + 1], bytes[off + 2], bytes[off + 3]]) as f64 }; let luma = |img: usize, r: usize, c: usize| -> f64 { 0.299 * pixel(img, r, c, 0) + 0.587 * pixel(img, r, c, 1) + 0.114 * pixel(img, r, c, 2) @@ -1934,42 +2289,61 @@ mod tests { // ── LEAF: full centroid focus patch (432D, highest resolution) ── let focus_radius = 6usize; - let intersections = [(img_h/3, img_w/3), (img_h/3, 2*img_w/3), - (2*img_h/3, img_w/3), (2*img_h/3, 2*img_w/3)]; - - let leaf_features: Vec> = (0..n_use).map(|img| { - // Find highest-energy intersection - let mut best_r = img_h / 2; - let mut best_c = img_w / 2; - let mut best_energy = 0.0f64; - for &(ir, ic) in &intersections { - let mut energy = 0.0; - let r0 = ir.saturating_sub(8); let r1 = (ir+8).min(img_h); - let c0 = 
ic.saturating_sub(8); let c1 = (ic+8).min(img_w); - for r in r0..r1 { for c in c0..c1 { - if r > 0 && r < img_h-1 && c > 0 && c < img_w-1 { - let dx = luma(img, r, c+1) - luma(img, r, c.saturating_sub(1)); - let dy = luma(img, r+1, c) - luma(img, r.saturating_sub(1), c); - energy += (dx*dx + dy*dy).sqrt(); + let intersections = [ + (img_h / 3, img_w / 3), + (img_h / 3, 2 * img_w / 3), + (2 * img_h / 3, img_w / 3), + (2 * img_h / 3, 2 * img_w / 3), + ]; + + let leaf_features: Vec> = (0..n_use) + .map(|img| { + // Find highest-energy intersection + let mut best_r = img_h / 2; + let mut best_c = img_w / 2; + let mut best_energy = 0.0f64; + for &(ir, ic) in &intersections { + let mut energy = 0.0; + let r0 = ir.saturating_sub(8); + let r1 = (ir + 8).min(img_h); + let c0 = ic.saturating_sub(8); + let c1 = (ic + 8).min(img_w); + for r in r0..r1 { + for c in c0..c1 { + if r > 0 && r < img_h - 1 && c > 0 && c < img_w - 1 { + let dx = luma(img, r, c + 1) - luma(img, r, c.saturating_sub(1)); + let dy = luma(img, r + 1, c) - luma(img, r.saturating_sub(1), c); + energy += (dx * dx + dy * dy).sqrt(); + } + } } - }} - if energy > best_energy { best_energy = energy; best_r = ir; best_c = ic; } - } - // Extract patch - let mut f = Vec::with_capacity(432); - let r0 = best_r.saturating_sub(focus_radius); - let r1 = (best_r + focus_radius).min(img_h); - let c0 = best_c.saturating_sub(focus_radius); - let c1 = (best_c + focus_radius).min(img_w); - for r in r0..r1 { for c in c0..c1 { for channel in 0..ch { - f.push(pixel(img, r, c, channel)); - }}} - f.resize(432, 0.0); - f - }).collect(); + if energy > best_energy { + best_energy = energy; + best_r = ir; + best_c = ic; + } + } + // Extract patch + let mut f = Vec::with_capacity(432); + let r0 = best_r.saturating_sub(focus_radius); + let r1 = (best_r + focus_radius).min(img_h); + let c0 = best_c.saturating_sub(focus_radius); + let c1 = (best_c + focus_radius).min(img_w); + for r in r0..r1 { + for c in c0..c1 { + for channel in 0..ch { + f.push(pixel(img, r, c, channel)); + } + } + } + f.resize(432, 0.0); + f + }) + .collect(); // ── BRANCH: golden-step compress (432D → 17D = 34 bytes) ── - let base_dim = 17; let golden_step = 11; + let base_dim = 17; + let golden_step = 11; let compress17 = |v: &[f64]| -> Vec { let fd = v.len(); let n_oct = (fd + base_dim - 1) / base_dim; @@ -1978,28 +2352,45 @@ mod tests { for oct in 0..n_oct { for bi in 0..base_dim { let dim = oct * base_dim + ((bi * golden_step) % base_dim); - if dim < fd { sum[bi] += v[dim]; cnt[bi] += 1; } + if dim < fd { + sum[bi] += v[dim]; + cnt[bi] += 1; + } } } - sum.iter().zip(&cnt).map(|(&s, &c)| if c > 0 { s / c as f64 } else { 0.0 }).collect() + sum.iter() + .zip(&cnt) + .map(|(&s, &c)| if c > 0 { s / c as f64 } else { 0.0 }) + .collect() }; let branch_features: Vec> = leaf_features.iter().map(|f| compress17(f)).collect(); // ── HIP: quantize to i16 (17D × 2 bytes = 34 bytes, same size but integer) ── - let hip_features: Vec> = branch_features.iter().map(|f| { - f.iter().map(|&v| ((v * 1000.0).round().clamp(-32768.0, 32767.0) as i16) as f64 / 1000.0).collect() - }).collect(); + let hip_features: Vec> = branch_features + .iter() + .map(|f| { + f.iter() + .map(|&v| ((v * 1000.0).round().clamp(-32768.0, 32767.0) as i16) as f64 / 1000.0) + .collect() + }) + .collect(); // ── HEEL: scent byte — reduce to single energy + top category vote ── // Scent = which of the 17 dimensions has highest absolute value - let heel_features: Vec> = branch_features.iter().map(|f| { - let max_dim = f.iter().enumerate() 
- .max_by(|a, b| a.1.abs().partial_cmp(&b.1.abs()).unwrap()) - .map(|(i, _)| i).unwrap_or(0); - let energy: f64 = f.iter().map(|v| v * v).sum::().sqrt(); - // 2D scent: dominant dimension + energy level - vec![max_dim as f64, energy] - }).collect(); + let heel_features: Vec> = branch_features + .iter() + .map(|f| { + let max_dim = f + .iter() + .enumerate() + .max_by(|a, b| a.1.abs().partial_cmp(&b.1.abs()).unwrap()) + .map(|(i, _)| i) + .unwrap_or(0); + let energy: f64 = f.iter().map(|v| v * v).sum::().sqrt(); + // 2D scent: dominant dimension + energy level + vec![max_dim as f64, energy] + }) + .collect(); // ── Classify at each level ── fn classify(features: &[Vec], labels: &[usize], n_classes: usize) -> (f64, usize) { @@ -2008,22 +2399,40 @@ mod tests { let mut arch = vec![vec![0.0; fd]; n_classes]; let mut cnt = vec![0usize; n_classes]; for (i, &l) in labels.iter().enumerate() { - for j in 0..fd { arch[l][j] += features[i][j]; } + for j in 0..fd { + arch[l][j] += features[i][j]; + } cnt[l] += 1; } for c in 0..n_classes { - if cnt[c] > 0 { for j in 0..fd { arch[c][j] /= cnt[c] as f64; } } + if cnt[c] > 0 { + for j in 0..fd { + arch[c][j] /= cnt[c] as f64; + } + } } let mut correct = 0; for (i, &tl) in labels.iter().enumerate() { - let mut best_c = 0; let mut best_d = f64::MAX; + let mut best_c = 0; + let mut best_d = f64::MAX; for c in 0..n_classes { - if cnt[c] == 0 { continue; } - let dist: f64 = features[i].iter().zip(&arch[c]) - .map(|(a, b)| (a-b)*(a-b)).sum::().sqrt(); - if dist < best_d { best_d = dist; best_c = c; } + if cnt[c] == 0 { + continue; + } + let dist: f64 = features[i] + .iter() + .zip(&arch[c]) + .map(|(a, b)| (a - b) * (a - b)) + .sum::() + .sqrt(); + if dist < best_d { + best_d = dist; + best_c = c; + } + } + if best_c == tl { + correct += 1; } - if best_c == tl { correct += 1; } } (correct as f64 / n as f64, correct) } @@ -2032,11 +2441,17 @@ mod tests { fn pairwise_dists(feats: &[Vec]) -> Vec { let n = feats.len().min(50); // limit for speed let mut d = Vec::new(); - for i in 0..n { for j in (i+1)..n { - let dist: f64 = feats[i].iter().zip(&feats[j]) - .map(|(a,b)| (a-b)*(a-b)).sum::().sqrt(); - d.push(dist); - }} + for i in 0..n { + for j in (i + 1)..n { + let dist: f64 = feats[i] + .iter() + .zip(&feats[j]) + .map(|(a, b)| (a - b) * (a - b)) + .sum::() + .sqrt(); + d.push(dist); + } + } d } fn spearman(a: &[f64], b: &[f64]) -> f64 { @@ -2044,7 +2459,9 @@ mod tests { let mut idx: Vec<(usize, f64)> = v.iter().copied().enumerate().collect(); idx.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); let mut r = vec![0.0; v.len()]; - for (rank, (i, _)) in idx.into_iter().enumerate() { r[i] = rank as f64; } + for (rank, (i, _)) in idx.into_iter().enumerate() { + r[i] = rank as f64; + } r } let (ra, rb) = (ranks(a), ranks(b)); @@ -2053,9 +2470,15 @@ mod tests { let (mut cov, mut va, mut vb) = (0.0, 0.0, 0.0); for i in 0..a.len() { let (da, db) = (ra[i] - ma, rb[i] - mb); - cov += da * db; va += da * da; vb += db * db; + cov += da * db; + va += da * da; + vb += db * db; + } + if va < 1e-10 || vb < 1e-10 { + 0.0 + } else { + cov / (va * vb).sqrt() } - if va < 1e-10 || vb < 1e-10 { 0.0 } else { cov / (va * vb).sqrt() } } let leaf_dists = pairwise_dists(&leaf_features); @@ -2073,21 +2496,45 @@ mod tests { eprintln!(); eprintln!(" Level Dims Bytes Accuracy ρ vs LEAF ρ/byte"); eprintln!(" ───────── ───── ───── ───────── ───────── ──────"); - eprintln!(" LEAF 432D 864B {:.1}% ({}/{}) 1.0000 {:.6}", - acc_leaf*100.0, cor_leaf, n_use, 1.0/864.0); - eprintln!(" BRANCH 17D 34B {:.1}% 
({}/{}) {:.4} {:.6}", - acc_branch*100.0, cor_branch, n_use, rho_branch, rho_branch/34.0); - eprintln!(" HIP 17D 34B {:.1}% ({}/{}) {:.4} {:.6}", - acc_hip*100.0, cor_hip, n_use, rho_hip, rho_hip/34.0); - eprintln!(" HEEL 2D 2B {:.1}% ({}/{}) {:.4} {:.6}", - acc_heel*100.0, cor_heel, n_use, rho_heel, rho_heel/2.0); - eprintln!(" Random — 0B {:.1}%", 100.0/n_classes as f64); + eprintln!( + " LEAF 432D 864B {:.1}% ({}/{}) 1.0000 {:.6}", + acc_leaf * 100.0, + cor_leaf, + n_use, + 1.0 / 864.0 + ); + eprintln!( + " BRANCH 17D 34B {:.1}% ({}/{}) {:.4} {:.6}", + acc_branch * 100.0, + cor_branch, + n_use, + rho_branch, + rho_branch / 34.0 + ); + eprintln!( + " HIP 17D 34B {:.1}% ({}/{}) {:.4} {:.6}", + acc_hip * 100.0, + cor_hip, + n_use, + rho_hip, + rho_hip / 34.0 + ); + eprintln!( + " HEEL 2D 2B {:.1}% ({}/{}) {:.4} {:.6}", + acc_heel * 100.0, + cor_heel, + n_use, + rho_heel, + rho_heel / 2.0 + ); + eprintln!(" Random — 0B {:.1}%", 100.0 / n_classes as f64); eprintln!(); eprintln!(" Cascade rejection simulation:"); - eprintln!(" HEEL rejects: {:.0}% of wrong classes (scent screening)", - (1.0 - 1.0/n_classes as f64) * (1.0 - acc_heel) * 100.0); - eprintln!(" After HEEL→HIP: {:.0}% remaining need full BRANCH check", - (1.0 - acc_hip) * 100.0); + eprintln!( + " HEEL rejects: {:.0}% of wrong classes (scent screening)", + (1.0 - 1.0 / n_classes as f64) * (1.0 - acc_heel) * 100.0 + ); + eprintln!(" After HEEL→HIP: {:.0}% remaining need full BRANCH check", (1.0 - acc_hip) * 100.0); assert!(acc_leaf > acc_branch, "LEAF should beat BRANCH"); assert!(acc_branch >= acc_heel, "BRANCH should beat or match HEEL"); diff --git a/src/hpc/vnni_gemm.rs b/src/hpc/vnni_gemm.rs index 00fc6bbc..b156a8b2 100644 --- a/src/hpc/vnni_gemm.rs +++ b/src/hpc/vnni_gemm.rs @@ -43,14 +43,7 @@ use super::simd_caps::simd_caps; /// # Panics /// /// Panics if the slice lengths are inconsistent with the given dimensions. -pub fn int8_gemm_vnni( - a: &[u8], - b: &[i8], - c: &mut [i32], - m: usize, - n: usize, - k: usize, -) { +pub fn int8_gemm_vnni(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) { assert!(a.len() >= m * k, "a.len()={} < m*k={}", a.len(), m * k); assert!(b.len() >= k * n, "b.len()={} < k*n={}", b.len(), k * n); assert!(c.len() >= m * n, "c.len()={} < m*n={}", c.len(), m * n); @@ -98,14 +91,7 @@ pub fn has_vnni() -> bool { /// contains 4 bytes from consecutive rows. 
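// Illustrative sketch (not part of this patch): the plain scalar reference the
// tests below compare `int8_gemm_vnni` against — a row-major u8 × i8 → i32
// multiply with widening products and i32 accumulation, which is what the
// in-test `scalar_gemm` helper is expected to compute. The function name here
// is local to this sketch.
fn scalar_u8_i8_gemm(a: &[u8], b: &[i8], m: usize, n: usize, k: usize) -> Vec<i32> {
    assert!(a.len() >= m * k && b.len() >= k * n);
    let mut c = vec![0i32; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0i32;
            for p in 0..k {
                // Unsigned activation × signed weight, widened before the multiply.
                acc += a[i * k + p] as i32 * b[p * n + j] as i32;
            }
            c[i * n + j] = acc;
        }
    }
    c
}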
#[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f,avx512vnni,avx512bw")] -unsafe fn int8_gemm_vnni_avx512( - a: &[u8], - b: &[i8], - c: &mut [i32], - m: usize, - n: usize, - k: usize, -) { +unsafe fn int8_gemm_vnni_avx512(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) { use core::arch::x86_64::*; // Zero output @@ -229,18 +215,8 @@ mod tests { let n = 4; let k = 4; // Simple identity-like test - let a: Vec = vec![ - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - ]; - let b: Vec = vec![ - 1, 0, 0, 0, - 0, 1, 0, 0, - 0, 0, 1, 0, - 0, 0, 0, 1, - ]; + let a: Vec = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + let b: Vec = vec![1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1]; let expected = scalar_gemm(&a, &b, m, n, k); let mut c = vec![0i32; m * n]; int8_gemm_vnni(&a, &b, &mut c, m, n, k); @@ -252,18 +228,8 @@ mod tests { let m = 4; let n = 4; let k = 4; - let a: Vec = vec![ - 128, 64, 32, 16, - 255, 0, 128, 64, - 1, 2, 3, 4, - 200, 100, 50, 25, - ]; - let b: Vec = vec![ - 1, -1, 2, -2, - 3, -3, 4, -4, - 5, -5, 6, -6, - 7, -7, 8, -8, - ]; + let a: Vec = vec![128, 64, 32, 16, 255, 0, 128, 64, 1, 2, 3, 4, 200, 100, 50, 25]; + let b: Vec = vec![1, -1, 2, -2, 3, -3, 4, -4, 5, -5, 6, -6, 7, -7, 8, -8]; let expected = scalar_gemm(&a, &b, m, n, k); let mut c = vec![0i32; m * n]; int8_gemm_vnni(&a, &b, &mut c, m, n, k); @@ -276,7 +242,9 @@ mod tests { let n = 16; let k = 16; let a: Vec = (0..m * k).map(|i| (i % 251) as u8).collect(); - let b: Vec = (0..k * n).map(|i| ((i % 127) as i8).wrapping_sub(63)).collect(); + let b: Vec = (0..k * n) + .map(|i| ((i % 127) as i8).wrapping_sub(63)) + .collect(); let expected = scalar_gemm(&a, &b, m, n, k); let mut c = vec![0i32; m * n]; int8_gemm_vnni(&a, &b, &mut c, m, n, k); diff --git a/src/hpc/vsa.rs b/src/hpc/vsa.rs index 282342a0..0515ef1e 100644 --- a/src/hpc/vsa.rs +++ b/src/hpc/vsa.rs @@ -180,9 +180,7 @@ impl VsaVector { #[inline] pub fn as_bytes(&self) -> &[u8] { // SAFETY: [u64; N] is contiguous, u64→u8 cast is alignment-safe. - unsafe { - std::slice::from_raw_parts(self.words.as_ptr() as *const u8, VSA_WORDS * 8) - } + unsafe { std::slice::from_raw_parts(self.words.as_ptr() as *const u8, VSA_WORDS * 8) } } /// Population count: number of set bits (out of 16,384). 
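// Not part of the diff: an illustrative usage sketch of this module's VSA API, with
// signatures assumed from the tests below (`VsaVector::random(seed)`, `vsa_bundle(&[..])`,
// `vsa_similarity(&a, &b) -> f32`); module items assumed in scope.
fn vsa_usage_sketch() {
    let a = VsaVector::random(1);
    let b = VsaVector::random(2);
    // Bundling superimposes its inputs; the result stays measurably closer to the
    // vector that appears more often in the bundle.
    let bundled = vsa_bundle(&[a.clone(), a.clone(), a.clone(), b.clone()]);
    assert!(vsa_similarity(&bundled, &a) > vsa_similarity(&bundled, &b));
    // A vector compared with itself has similarity exactly 1.0.
    assert!((vsa_similarity(&a, &a) - 1.0).abs() < f32::EPSILON);
}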
@@ -537,12 +535,7 @@ mod tests { let bundled = vsa_bundle(&[a.clone(), a.clone(), a.clone(), b.clone()]); let sim_a = vsa_similarity(&bundled, &a); let sim_b = vsa_similarity(&bundled, &b); - assert!( - sim_a > sim_b, - "bundled should be closer to a (sim_a={}, sim_b={})", - sim_a, - sim_b - ); + assert!(sim_a > sim_b, "bundled should be closer to a (sim_a={}, sim_b={})", sim_a, sim_b); } #[test] @@ -575,11 +568,7 @@ mod tests { fn test_similarity_self_one() { let a = VsaVector::random(42); let sim = vsa_similarity(&a, &a); - assert!( - (sim - 1.0).abs() < f32::EPSILON, - "self-similarity should be 1.0, got {}", - sim - ); + assert!((sim - 1.0).abs() < f32::EPSILON, "self-similarity should be 1.0, got {}", sim); } #[test] @@ -593,11 +582,7 @@ mod tests { complement.words[VSA_WORDS - 1] &= TAIL_MASK; let sim = vsa_similarity(&a, &complement); - assert!( - sim.abs() < 0.01, - "complement similarity should be ~0.0, got {}", - sim - ); + assert!(sim.abs() < 0.01, "complement similarity should be ~0.0, got {}", sim); } #[test] @@ -645,10 +630,7 @@ mod tests { let codebook = vec![other1, other2, entry.clone()]; let found = vsa_clean(&noisy, &codebook).unwrap(); - assert_eq!( - found, &entry, - "clean should recover the closest codebook entry" - ); + assert_eq!(found, &entry, "clean should recover the closest codebook entry"); } #[test] diff --git a/src/hpc/zeck.rs b/src/hpc/zeck.rs index f2d6358e..0fbd4d10 100644 --- a/src/hpc/zeck.rs +++ b/src/hpc/zeck.rs @@ -90,10 +90,7 @@ pub fn zeckf64_from_distances(ds: u32, dp: u32, d_o: u32) -> u64 { /// Compute ZeckF64 encoding from raw fingerprint byte slices. /// /// Each triple is `(subject, predicate, object)` as `&[u8]` (16384-bit / 2048 bytes). -pub fn zeckf64( - a: (&[u8], &[u8], &[u8]), - b: (&[u8], &[u8], &[u8]), -) -> u64 { +pub fn zeckf64(a: (&[u8], &[u8], &[u8]), b: (&[u8], &[u8], &[u8])) -> u64 { let ds = hamming_distance_raw(a.0, b.0) as u32; let dp = hamming_distance_raw(a.1, b.1) as u32; let d_o = hamming_distance_raw(a.2, b.2) as u32; @@ -120,7 +117,11 @@ pub fn resolution(edge: u64, byte_n: u8) -> u8 { /// Set the sign (causality direction) bit. #[inline] pub fn set_sign(edge: u64, sign: bool) -> u64 { - if sign { edge | (1u64 << 7) } else { edge & !(1u64 << 7) } + if sign { + edge | (1u64 << 7) + } else { + edge & !(1u64 << 7) + } } /// Read the sign bit. @@ -185,11 +186,19 @@ pub fn is_legal_scent(byte0: u8) -> bool { let spo = (byte0 >> 6) & 1; // Lattice: pair bit implies both individual bits - if sp == 1 && (s == 0 || p == 0) { return false; } - if so == 1 && (s == 0 || o == 0) { return false; } - if po == 1 && (p == 0 || o == 0) { return false; } + if sp == 1 && (s == 0 || p == 0) { + return false; + } + if so == 1 && (s == 0 || o == 0) { + return false; + } + if po == 1 && (p == 0 || o == 0) { + return false; + } // Triple implies all three pairs - if spo == 1 && (sp == 0 || so == 0 || po == 0) { return false; } + if spo == 1 && (sp == 0 || so == 0 || po == 0) { + return false; + } true } @@ -206,7 +215,10 @@ pub fn zeckf64_batch(query: u64, edges: &[u64]) -> Vec { /// Batch scent-only distance: compute scent distance from `query` to each edge. pub fn zeckf64_scent_batch(query: u64, edges: &[u64]) -> Vec { - edges.iter().map(|&e| zeckf64_scent_distance(query, e)).collect() + edges + .iter() + .map(|&e| zeckf64_scent_distance(query, e)) + .collect() } /// Top-k nearest edges by ZeckF64 distance. 
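// Not part of the diff: an illustrative coarse-to-fine screening sketch using the
// ZeckF64 helpers above, with signatures as exercised by the tests further down.
fn zeck_screening_sketch() {
    // Pack the three per-plane Hamming distances (S, P, O) of a triple comparison
    // into one u64 edge code.
    let query = zeckf64_from_distances(1000, 2000, 3000);
    let edges = vec![
        zeckf64_from_distances(1000, 2000, 3000),   // identical
        zeckf64_from_distances(5000, 8000, 10_000), // far
        zeckf64_from_distances(1100, 2100, 3100),   // close
    ];
    // Cheap scent-only pass first, then exact ZeckF64 distances on what survives.
    let _scent = zeckf64_scent_batch(query, &edges);
    let dists = zeckf64_batch(query, &edges);
    assert_eq!(dists[0], 0); // identical code -> distance 0
    let (idx, _dists) = zeckf64_top_k(query, &edges, 2);
    assert_eq!(idx.len(), 2);
}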
@@ -254,10 +266,7 @@ pub fn zeckf64_scent_top_k(query: u64, edges: &[u64], k: usize) -> (Vec, /// `database` = flat buffer of concatenated S+P+O planes (6144 bytes per row). /// Returns a Vec of `num_rows` u64 ZeckF64 values. pub fn zeckf64_encode_batch( - query: (&[u8], &[u8], &[u8]), - database: &[u8], - num_rows: usize, - plane_bytes: usize, + query: (&[u8], &[u8], &[u8]), database: &[u8], num_rows: usize, plane_bytes: usize, ) -> Vec { let row_bytes = plane_bytes * 3; (0..num_rows) @@ -352,9 +361,9 @@ mod tests { fn test_zeckf64_batch() { let query = zeckf64_from_distances(1000, 2000, 3000); let edges = vec![ - zeckf64_from_distances(1000, 2000, 3000), // identical + zeckf64_from_distances(1000, 2000, 3000), // identical zeckf64_from_distances(5000, 8000, 10000), // far - zeckf64_from_distances(1100, 2100, 3100), // close + zeckf64_from_distances(1100, 2100, 3100), // close ]; let dists = zeckf64_batch(query, &edges); assert_eq!(dists[0], 0); // identical → 0 @@ -366,9 +375,9 @@ mod tests { let query = zeckf64_from_distances(1000, 2000, 3000); let edges = vec![ zeckf64_from_distances(5000, 8000, 10000), // far - zeckf64_from_distances(1000, 2000, 3000), // identical - zeckf64_from_distances(1100, 2100, 3100), // close - zeckf64_from_distances(8000, 8000, 8000), // very far + zeckf64_from_distances(1000, 2000, 3000), // identical + zeckf64_from_distances(1100, 2100, 3100), // close + zeckf64_from_distances(8000, 8000, 8000), // very far ]; let (indices, dists) = zeckf64_top_k(query, &edges, 2); assert_eq!(indices.len(), 2); @@ -381,8 +390,8 @@ mod tests { let query = zeckf64_from_distances(0, 0, 0); // all close let edges = vec![ zeckf64_from_distances(D_MAX, D_MAX, D_MAX), // none close - zeckf64_from_distances(0, 0, 0), // all close (match) - zeckf64_from_distances(100, 100, 100), // all close (close) + zeckf64_from_distances(0, 0, 0), // all close (match) + zeckf64_from_distances(100, 100, 100), // all close (close) ]; let (indices, _) = zeckf64_scent_top_k(query, &edges, 2); // Both edges 1 and 2 have all-close scent bytes diff --git a/src/impl_1d.rs b/src/impl_1d.rs index bd34ba2c..c63cbcd5 100644 --- a/src/impl_1d.rs +++ b/src/impl_1d.rs @@ -15,11 +15,11 @@ use crate::imp_prelude::*; use crate::low_level_util::AbortIfPanic; /// # Methods For 1-D Arrays -impl ArrayRef -{ +impl ArrayRef { /// Return an vector with the elements of the one-dimensional array. pub fn to_vec(&self) -> Vec - where A: Clone + where + A: Clone, { if let Some(slc) = self.as_slice() { slc.to_vec() @@ -30,8 +30,7 @@ impl ArrayRef /// Rotate the elements of the array by 1 element towards the front; /// the former first element becomes the last. - pub(crate) fn rotate1_front(&mut self) - { + pub(crate) fn rotate1_front(&mut self) { // use swapping to keep all elements initialized (as required by owned storage) let mut lane_iter = self.iter_mut(); let mut dst = if let Some(dst) = lane_iter.next() { dst } else { return }; diff --git a/src/impl_2d.rs b/src/impl_2d.rs index b6379e67..6b04121f 100644 --- a/src/impl_2d.rs +++ b/src/impl_2d.rs @@ -10,8 +10,7 @@ use crate::imp_prelude::*; /// # Methods For 2-D Arrays -impl ArrayRef -{ +impl ArrayRef { /// Return an array view of row `index`. /// /// **Panics** if `index` is out of bounds. 
@@ -22,8 +21,7 @@ impl ArrayRef /// assert_eq!(array.row(0), array![1., 2.]); /// ``` #[track_caller] - pub fn row(&self, index: Ix) -> ArrayView1<'_, A> - { + pub fn row(&self, index: Ix) -> ArrayView1<'_, A> { self.index_axis(Axis(0), index) } @@ -38,14 +36,12 @@ impl ArrayRef /// assert_eq!(array, array![[1., 5.], [3., 4.]]); /// ``` #[track_caller] - pub fn row_mut(&mut self, index: Ix) -> ArrayViewMut1<'_, A> - { + pub fn row_mut(&mut self, index: Ix) -> ArrayViewMut1<'_, A> { self.index_axis_mut(Axis(0), index) } } -impl LayoutRef -{ +impl LayoutRef { /// Return the number of rows (length of `Axis(0)`) in the two-dimensional array. /// /// ``` @@ -63,14 +59,12 @@ impl LayoutRef /// // get length of any particular axis with .len_of() /// assert_eq!(m, array.len_of(Axis(0))); /// ``` - pub fn nrows(&self) -> usize - { + pub fn nrows(&self) -> usize { self.len_of(Axis(0)) } } -impl ArrayRef -{ +impl ArrayRef { /// Return an array view of column `index`. /// /// **Panics** if `index` is out of bounds. @@ -81,8 +75,7 @@ impl ArrayRef /// assert_eq!(array.column(0), array![1., 3.]); /// ``` #[track_caller] - pub fn column(&self, index: Ix) -> ArrayView1<'_, A> - { + pub fn column(&self, index: Ix) -> ArrayView1<'_, A> { self.index_axis(Axis(1), index) } @@ -97,14 +90,12 @@ impl ArrayRef /// assert_eq!(array, array![[1., 2.], [5., 4.]]); /// ``` #[track_caller] - pub fn column_mut(&mut self, index: Ix) -> ArrayViewMut1<'_, A> - { + pub fn column_mut(&mut self, index: Ix) -> ArrayViewMut1<'_, A> { self.index_axis_mut(Axis(1), index) } } -impl LayoutRef -{ +impl LayoutRef { /// Return the number of columns (length of `Axis(1)`) in the two-dimensional array. /// /// ``` @@ -122,8 +113,7 @@ impl LayoutRef /// // get length of any particular axis with .len_of() /// assert_eq!(n, array.len_of(Axis(1))); /// ``` - pub fn ncols(&self) -> usize - { + pub fn ncols(&self) -> usize { self.len_of(Axis(1)) } @@ -142,15 +132,13 @@ impl LayoutRef /// let array = array![[1., 2., 5.], [3., 4., 6.]]; /// assert!(!array.is_square()); /// ``` - pub fn is_square(&self) -> bool - { + pub fn is_square(&self) -> bool { let (m, n) = self.dim(); m == n } } -impl ArrayBase -{ +impl ArrayBase { /// Return the number of rows (length of `Axis(0)`) in the two-dimensional array. /// /// ``` @@ -168,8 +156,7 @@ impl ArrayBase /// // get length of any particular axis with .len_of() /// assert_eq!(m, array.len_of(Axis(0))); /// ``` - pub fn nrows(&self) -> usize - { + pub fn nrows(&self) -> usize { self.as_layout_ref().nrows() } @@ -190,8 +177,7 @@ impl ArrayBase /// // get length of any particular axis with .len_of() /// assert_eq!(n, array.len_of(Axis(1))); /// ``` - pub fn ncols(&self) -> usize - { + pub fn ncols(&self) -> usize { self.as_layout_ref().ncols() } @@ -210,8 +196,7 @@ impl ArrayBase /// let array = array![[1., 2., 5.], [3., 4., 6.]]; /// assert!(!array.is_square()); /// ``` - pub fn is_square(&self) -> bool - { + pub fn is_square(&self) -> bool { self.as_layout_ref().is_square() } } diff --git a/src/impl_arc_array.rs b/src/impl_arc_array.rs index 619ae250..bdeb0031 100644 --- a/src/impl_arc_array.rs +++ b/src/impl_arc_array.rs @@ -18,12 +18,12 @@ use portable_atomic_util::Arc; /// /// ***See also all methods for [`ArrayBase`]*** impl ArcArray -where D: Dimension +where + D: Dimension, { /// Returns `true` iff the inner `Arc` is not shared. /// If you want to ensure the `Arc` is not concurrently cloned, you need to provide a `&mut self` to this function. 
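// Not part of the diff: an illustrative sketch of `is_unique` (assumes the ndarray
// prelude, `array!`, and the public `into_shared()` conversion are in scope).
fn arc_is_unique_sketch() {
    let a = array![1, 2, 3].into_shared();
    assert!(a.is_unique()); // sole owner of the Arc-backed storage
    let b = a.clone(); // cloning an ArcArray shares the allocation
    assert!(!a.is_unique());
    drop(b);
    assert!(a.is_unique()); // unique again once the clone is gone
}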
- pub fn is_unique(&self) -> bool - { + pub fn is_unique(&self) -> bool { // Only strong pointers are used in this crate. Arc::strong_count(&self.data.0) == 1 } diff --git a/src/impl_clone.rs b/src/impl_clone.rs index bef783bd..c61dc741 100644 --- a/src/impl_clone.rs +++ b/src/impl_clone.rs @@ -10,10 +10,8 @@ use crate::imp_prelude::*; use crate::ArrayPartsSized; use crate::RawDataClone; -impl Clone for ArrayBase -{ - fn clone(&self) -> ArrayBase - { +impl Clone for ArrayBase { + fn clone(&self) -> ArrayBase { // safe because `clone_with_ptr` promises to provide equivalent data and ptr unsafe { let (data, ptr) = self.data.clone_with_ptr(self.parts.ptr); @@ -27,8 +25,7 @@ impl Clone for ArrayBase /// `Array` implements `.clone_from()` to reuse an array's existing /// allocation. Semantically equivalent to `*self = other.clone()`, but /// potentially more efficient. - fn clone_from(&mut self, other: &Self) - { + fn clone_from(&mut self, other: &Self) { unsafe { self.parts.ptr = self.data.clone_from_with_ptr(&other.data, other.parts.ptr); self.parts.dim.clone_from(&other.parts.dim); diff --git a/src/impl_constructors.rs b/src/impl_constructors.rs index 7f71cca5..c966caf4 100644 --- a/src/impl_constructors.rs +++ b/src/impl_constructors.rs @@ -44,7 +44,8 @@ use rawpointer::PointerExt; /// /// ## Constructor methods for one-dimensional arrays. impl ArrayBase -where S: DataOwned +where + S: DataOwned, { /// Create a one-dimensional array from a vector (no copying needed). /// @@ -55,8 +56,7 @@ where S: DataOwned /// /// let array = Array::from_vec(vec![1., 2., 3., 4.]); /// ``` - pub fn from_vec(v: Vec) -> Self - { + pub fn from_vec(v: Vec) -> Self { if mem::size_of::() == 0 { assert!(v.len() <= isize::MAX as usize, "Length must fit in `isize`.",); } @@ -73,8 +73,7 @@ where S: DataOwned /// let array = Array::from_iter(0..10); /// ``` #[allow(clippy::should_implement_trait)] - pub fn from_iter>(iterable: I) -> Self - { + pub fn from_iter>(iterable: I) -> Self { Self::from_vec(iterable.into_iter().collect()) } @@ -117,7 +116,8 @@ where S: DataOwned /// ``` #[cfg(feature = "std")] pub fn range(start: A, end: A, step: A) -> Self - where A: Float + where + A: Float, { Self::from(to_vec(linspace::range(start, end, step))) } @@ -181,7 +181,8 @@ where S: DataOwned /// ``` #[cfg(feature = "std")] pub fn geomspace(start: A, end: A, n: usize) -> Option - where A: Float + where + A: Float, { Some(Self::from(to_vec(geomspace::geomspace(start, end, n)?))) } @@ -189,7 +190,8 @@ where S: DataOwned /// ## Constructor methods for two-dimensional arrays. impl ArrayBase -where S: DataOwned +where + S: DataOwned, { /// Create an identity matrix of size `n` (square 2D array). /// @@ -471,14 +473,14 @@ where /// ); /// ``` pub fn from_shape_vec(shape: Sh, v: Vec) -> Result - where Sh: Into> + where + Sh: Into>, { // eliminate the type parameter Sh as soon as possible Self::from_shape_vec_impl(shape.into(), v) } - fn from_shape_vec_impl(shape: StrideShape, v: Vec) -> Result - { + fn from_shape_vec_impl(shape: StrideShape, v: Vec) -> Result { let dim = shape.dim; let is_custom = shape.strides.is_custom(); dimension::can_index_slice_with_strides(&v, &dim, &shape.strides, dimension::CanIndexCheckMode::OwnedMutable)?; @@ -514,7 +516,8 @@ where /// 5. The strides must not allow any element to be referenced by two different /// indices. 
pub unsafe fn from_shape_vec_unchecked(shape: Sh, v: Vec) -> Self - where Sh: Into> + where + Sh: Into>, { let shape = shape.into(); let dim = shape.dim; @@ -522,8 +525,7 @@ where Self::from_vec_dim_stride_unchecked(dim, strides, v) } - unsafe fn from_vec_dim_stride_unchecked(dim: D, strides: D, mut v: Vec) -> Self - { + unsafe fn from_vec_dim_stride_unchecked(dim: D, strides: D, mut v: Vec) -> Self { // debug check for issues that indicates wrong use of this constructor debug_assert!(dimension::can_index_slice(&v, &dim, &strides, CanIndexCheckMode::OwnedMutable).is_ok()); @@ -596,7 +598,8 @@ where /// # let _ = shift_by_two; /// ``` pub fn uninit(shape: Sh) -> ArrayBase - where Sh: ShapeBuilder + where + Sh: ShapeBuilder, { unsafe { let shape = shape.into_shape_with_order(); diff --git a/src/impl_cow.rs b/src/impl_cow.rs index 1a28996d..3cda78e3 100644 --- a/src/impl_cow.rs +++ b/src/impl_cow.rs @@ -12,26 +12,25 @@ use crate::imp_prelude::*; /// /// ***See also all methods for [`ArrayBase`]*** impl CowArray<'_, A, D> -where D: Dimension +where + D: Dimension, { /// Returns `true` iff the array is the view (borrowed) variant. - pub fn is_view(&self) -> bool - { + pub fn is_view(&self) -> bool { self.data.is_view() } /// Returns `true` iff the array is the owned variant. - pub fn is_owned(&self) -> bool - { + pub fn is_owned(&self) -> bool { self.data.is_owned() } } impl<'a, A, D> From> for CowArray<'a, A, D> -where D: Dimension +where + D: Dimension, { - fn from(view: ArrayView<'a, A, D>) -> CowArray<'a, A, D> - { + fn from(view: ArrayView<'a, A, D>) -> CowArray<'a, A, D> { // safe because equivalent data unsafe { ArrayBase::from_data_ptr(CowRepr::View(view.data), view.parts.ptr) @@ -41,10 +40,10 @@ where D: Dimension } impl<'a, A, D> From> for CowArray<'a, A, D> -where D: Dimension +where + D: Dimension, { - fn from(array: Array) -> CowArray<'a, A, D> - { + fn from(array: Array) -> CowArray<'a, A, D> { // safe because equivalent data unsafe { ArrayBase::from_data_ptr(CowRepr::Owned(array.data), array.parts.ptr) @@ -54,7 +53,8 @@ where D: Dimension } impl<'a, A, Slice: ?Sized> From<&'a Slice> for CowArray<'a, A, Ix1> -where Slice: AsRef<[A]> +where + Slice: AsRef<[A]>, { /// Create a one-dimensional clone-on-write view of the data in `slice`. /// @@ -67,8 +67,7 @@ where Slice: AsRef<[A]> /// assert!(array.is_view()); /// assert_eq!(array, array![1., 2., 3., 4.]); /// ``` - fn from(slice: &'a Slice) -> Self - { + fn from(slice: &'a Slice) -> Self { Self::from(ArrayView1::from(slice)) } } @@ -79,8 +78,7 @@ where D: Dimension, { /// Create a read-only clone-on-write view of the array. - fn from(array: &'a ArrayBase) -> Self - { + fn from(array: &'a ArrayBase) -> Self { Self::from(array.view()) } } diff --git a/src/impl_dyn.rs b/src/impl_dyn.rs index 404f3d4b..03010fb7 100644 --- a/src/impl_dyn.rs +++ b/src/impl_dyn.rs @@ -10,8 +10,7 @@ use crate::imp_prelude::*; /// # Methods for Dynamic-Dimensional Arrays -impl LayoutRef -{ +impl LayoutRef { /// Insert new array axis of length 1 at `axis`, modifying the shape and /// strides in-place. 
/// @@ -28,8 +27,7 @@ impl LayoutRef /// assert_eq!(a.shape(), &[2, 1, 3]); /// ``` #[track_caller] - pub fn insert_axis_inplace(&mut self, axis: Axis) - { + pub fn insert_axis_inplace(&mut self, axis: Axis) { assert!(axis.index() <= self.ndim()); self.0.dim = self._dim().insert_axis(axis); self.0.strides = self._strides().insert_axis(axis); @@ -51,16 +49,14 @@ impl LayoutRef /// assert_eq!(a.shape(), &[2]); /// ``` #[track_caller] - pub fn index_axis_inplace(&mut self, axis: Axis, index: usize) - { + pub fn index_axis_inplace(&mut self, axis: Axis, index: usize) { self.collapse_axis(axis, index); self.0.dim = self._dim().remove_axis(axis); self.0.strides = self._strides().remove_axis(axis); } } -impl ArrayBase -{ +impl ArrayBase { /// Insert new array axis of length 1 at `axis`, modifying the shape and /// strides in-place. /// @@ -77,8 +73,7 @@ impl ArrayBase /// assert_eq!(a.shape(), &[2, 1, 3]); /// ``` #[track_caller] - pub fn insert_axis_inplace(&mut self, axis: Axis) - { + pub fn insert_axis_inplace(&mut self, axis: Axis) { self.as_mut().insert_axis_inplace(axis) } @@ -98,14 +93,14 @@ impl ArrayBase /// assert_eq!(a.shape(), &[2]); /// ``` #[track_caller] - pub fn index_axis_inplace(&mut self, axis: Axis, index: usize) - { + pub fn index_axis_inplace(&mut self, axis: Axis, index: usize) { self.as_mut().index_axis_inplace(axis, index) } } impl ArrayBase -where S: Data +where + S: Data, { /// Remove axes of length 1 and return the modified array. /// @@ -128,8 +123,7 @@ where S: Data /// assert_eq!(d.shape(), &[1]); /// ``` #[track_caller] - pub fn squeeze(self) -> Self - { + pub fn squeeze(self) -> Self { let mut out = self; for axis in (0..out.shape().len()).rev() { if out.shape()[axis] == 1 && out.shape().len() > 1 { @@ -141,13 +135,11 @@ where S: Data } #[cfg(test)] -mod tests -{ +mod tests { use crate::{arr1, arr2, arr3}; #[test] - fn test_squeeze() - { + fn test_squeeze() { let a = arr3(&[[[1, 2, 3]], [[4, 5, 6]]]).into_dyn(); assert_eq!(a.shape(), &[2, 1, 3]); diff --git a/src/impl_internal_constructors.rs b/src/impl_internal_constructors.rs index ef2964ff..055da082 100644 --- a/src/impl_internal_constructors.rs +++ b/src/impl_internal_constructors.rs @@ -12,7 +12,8 @@ use crate::{imp_prelude::*, ArrayPartsSized}; // internal "builder-like" methods impl ArrayBase -where S: RawData +where + S: RawData, { /// Create an (initially) empty one-dimensional array from the given data and array head /// pointer @@ -23,8 +24,7 @@ where S: RawData /// /// See ArrayView::from_shape_ptr for general pointer validity documentation. #[inline] - pub(crate) unsafe fn from_data_ptr(data: S, ptr: NonNull) -> Self - { + pub(crate) unsafe fn from_data_ptr(data: S, ptr: NonNull) -> Self { let array = ArrayBase { data, parts: ArrayPartsSized::new(ptr, Ix1(0), Ix1(1)), @@ -51,7 +51,8 @@ where /// for the array data. 
#[inline] pub(crate) unsafe fn with_strides_dim(self, strides: E, dim: E) -> ArrayBase - where E: Dimension + where + E: Dimension, { debug_assert_eq!(strides.ndim(), dim.ndim()); ArrayBase { diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 2170a8d9..4997da75 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -22,15 +22,8 @@ use crate::dimension::broadcast::co_broadcast; use crate::dimension::reshape_dim; use crate::dimension::IntoDimension; use crate::dimension::{ - abs_index, - axes_of, - do_slice, - merge_axes, - move_min_stride_axis_to_last, - offset_from_low_addr_ptr_to_logical_ptr, - size_of_shape_checked, - stride_offset, - Axes, + abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last, offset_from_low_addr_ptr_to_logical_ptr, + size_of_shape_checked, stride_offset, Axes, }; use crate::error::{self, from_kind, ErrorKind, ShapeError}; use crate::itertools::zip; @@ -45,31 +38,17 @@ use crate::RawRef; use crate::{arraytraits, DimMax}; use crate::iter::{ - AxisChunksIter, - AxisChunksIterMut, - AxisIter, - AxisIterMut, - AxisWindows, - ExactChunks, - ExactChunksMut, - IndexedIter, - IndexedIterMut, - Iter, - IterMut, - Lanes, - LanesMut, - Windows, + AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, AxisWindows, ExactChunks, ExactChunksMut, IndexedIter, + IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows, }; use crate::slice::{MultiSliceArg, SliceArg}; use crate::stacking::concatenate; use crate::{NdIndex, Slice, SliceInfoElem}; /// # Methods For All Array Types -impl LayoutRef -{ +impl LayoutRef { /// Return the total number of elements in the array. - pub fn len(&self) -> usize - { + pub fn len(&self) -> usize { self._dim().size() } @@ -80,28 +59,24 @@ impl LayoutRef /// /// ***Panics*** if the axis is out of bounds. #[track_caller] - pub fn len_of(&self, axis: Axis) -> usize - { + pub fn len_of(&self, axis: Axis) -> usize { self._dim()[axis.index()] } /// Return whether the array has any elements - pub fn is_empty(&self) -> bool - { + pub fn is_empty(&self) -> bool { self.len() == 0 } /// Return the number of dimensions (axes) in the array - pub fn ndim(&self) -> usize - { + pub fn ndim(&self) -> usize { self._dim().ndim() } /// Return the shape of the array in its “pattern” form, /// an integer in the one-dimensional case, tuple in the n-dimensional cases /// and so on. - pub fn dim(&self) -> D::Pattern - { + pub fn dim(&self) -> D::Pattern { self._dim().clone().into_pattern() } @@ -119,8 +94,7 @@ impl LayoutRef /// // Create an array of zeros that's the same shape and dimensionality as `a`. /// let b = Array::::zeros(a.raw_dim()); /// ``` - pub fn raw_dim(&self) -> D - { + pub fn raw_dim(&self) -> D { self._dim().clone() } @@ -148,14 +122,12 @@ impl LayoutRef /// let c = Array::zeros(a.raw_dim()); /// assert_eq!(a, c); /// ``` - pub fn shape(&self) -> &[usize] - { + pub fn shape(&self) -> &[usize] { self._dim().slice() } /// Return the strides of the array as a slice. - pub fn strides(&self) -> &[isize] - { + pub fn strides(&self) -> &[isize] { let s = self._strides().slice(); // reinterpret unsigned integer as signed unsafe { slice::from_raw_parts(s.as_ptr() as *const _, s.len()) } @@ -168,25 +140,21 @@ impl LayoutRef /// /// ***Panics*** if the axis is out of bounds. 
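// Not part of the diff: an illustrative layout sketch (assumes `Array2` and `Axis`
// from the ndarray prelude). For a standard-layout 2×3 array the row stride equals
// the number of columns and the column stride is 1.
fn layout_sketch() {
    let a = Array2::<f64>::zeros((2, 3));
    assert_eq!(a.shape(), &[2, 3]);
    assert_eq!(a.strides(), &[3, 1]);
    assert_eq!(a.stride_of(Axis(0)), 3);
    assert_eq!(a.len_of(Axis(1)), 3);
    assert_eq!(a.ndim(), 2);
}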
#[track_caller] - pub fn stride_of(&self, axis: Axis) -> isize - { + pub fn stride_of(&self, axis: Axis) -> isize { // strides are reinterpreted as isize self._strides()[axis.index()] as isize } } -impl ArrayRef -{ +impl ArrayRef { /// Return a read-only view of the array - pub fn view(&self) -> ArrayView<'_, A, D> - { + pub fn view(&self) -> ArrayView<'_, A, D> { // debug_assert!(self.pointer_is_inbounds()); unsafe { ArrayView::new(*self._ptr(), self._dim().clone(), self._strides().clone()) } } /// Return a read-write view of the array - pub fn view_mut(&mut self) -> ArrayViewMut<'_, A, D> - { + pub fn view_mut(&mut self) -> ArrayViewMut<'_, A, D> { unsafe { ArrayViewMut::new(*self._ptr(), self._dim().clone(), self._strides().clone()) } } @@ -197,8 +165,7 @@ impl ArrayRef /// /// The view acts "as if" the elements are temporarily in cells, and elements /// can be changed through shared references using the regular cell methods. - pub fn cell_view(&mut self) -> ArrayView<'_, MathCell, D> - { + pub fn cell_view(&mut self) -> ArrayView<'_, MathCell, D> { self.view_mut().into_cell_view() } @@ -233,7 +200,8 @@ impl ArrayRef /// # assert_eq!(arr, owned); /// ``` pub fn to_owned(&self) -> Array - where A: Clone + where + A: Clone, { if let Some(slc) = self.as_slice_memory_order() { unsafe { @@ -330,7 +298,8 @@ where /// assert_eq!(unique, array![[1., 2.], [3., 4.]]); /// ``` pub fn try_into_owned_nocopy(self) -> Result, Self> - where S: Data + where + S: Data, { S::try_into_owned_nocopy(self) } @@ -350,8 +319,7 @@ where } } -impl ArrayRef -{ +impl ArrayRef { /// Returns a reference to the first element of the array, or `None` if it /// is empty. /// @@ -367,8 +335,7 @@ impl ArrayRef /// let b = Array3::::zeros([3, 0, 5]); /// assert_eq!(b.first(), None); /// ``` - pub fn first(&self) -> Option<&A> - { + pub fn first(&self) -> Option<&A> { if self.is_empty() { None } else { @@ -391,8 +358,7 @@ impl ArrayRef /// let mut b = Array3::::zeros([3, 0, 5]); /// assert_eq!(b.first_mut(), None); /// ``` - pub fn first_mut(&mut self) -> Option<&mut A> - { + pub fn first_mut(&mut self) -> Option<&mut A> { if self.is_empty() { None } else { @@ -415,8 +381,7 @@ impl ArrayRef /// let b = Array3::::zeros([3, 0, 5]); /// assert_eq!(b.last(), None); /// ``` - pub fn last(&self) -> Option<&A> - { + pub fn last(&self) -> Option<&A> { if self.is_empty() { None } else { @@ -443,8 +408,7 @@ impl ArrayRef /// let mut b = Array3::::zeros([3, 0, 5]); /// assert_eq!(b.last_mut(), None); /// ``` - pub fn last_mut(&mut self) -> Option<&mut A> - { + pub fn last_mut(&mut self) -> Option<&mut A> { if self.is_empty() { None } else { @@ -462,8 +426,7 @@ impl ArrayRef /// is where the rightmost index is varying the fastest. /// /// Iterator element type is `&A`. - pub fn iter(&self) -> Iter<'_, A, D> - { + pub fn iter(&self) -> Iter<'_, A, D> { // debug_assert!(self.pointer_is_inbounds()); self.view().into_iter() } @@ -474,8 +437,7 @@ impl ArrayRef /// is where the rightmost index is varying the fastest. /// /// Iterator element type is `&mut A`. - pub fn iter_mut(&mut self) -> IterMut<'_, A, D> - { + pub fn iter_mut(&mut self) -> IterMut<'_, A, D> { self.view_mut().into_iter() } @@ -487,8 +449,7 @@ impl ArrayRef /// Iterator element type is `(D::Pattern, &A)`. 
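// Not part of the diff: an illustrative sketch of indexed iteration (assumes the
// `array!` macro from the ndarray prelude). Each item is (index pattern, &element),
// visited with the rightmost index varying fastest for standard layout.
fn indexed_iter_sketch() {
    let a = array![[1, 2], [3, 4]];
    let mut weighted = 0;
    for ((i, j), &v) in a.indexed_iter() {
        weighted += v * (i + j) as i32;
    }
    // 1*0 + 2*1 + 3*1 + 4*2 = 13
    assert_eq!(weighted, 13);
}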
/// /// See also [`Zip::indexed`] - pub fn indexed_iter(&self) -> IndexedIter<'_, A, D> - { + pub fn indexed_iter(&self) -> IndexedIter<'_, A, D> { IndexedIter::new(self.view().into_elements_base()) } @@ -498,8 +459,7 @@ impl ArrayRef /// is where the rightmost index is varying the fastest. /// /// Iterator element type is `(D::Pattern, &mut A)`. - pub fn indexed_iter_mut(&mut self) -> IndexedIterMut<'_, A, D> - { + pub fn indexed_iter_mut(&mut self) -> IndexedIterMut<'_, A, D> { IndexedIterMut::new(self.view_mut().into_elements_base()) } @@ -512,7 +472,8 @@ impl ArrayRef /// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) #[track_caller] pub fn slice(&self, info: I) -> ArrayView<'_, A, I::OutDim> - where I: SliceArg + where + I: SliceArg, { self.view().slice_move(info) } @@ -526,7 +487,8 @@ impl ArrayRef /// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) #[track_caller] pub fn slice_mut(&mut self, info: I) -> ArrayViewMut<'_, A, I::OutDim> - where I: SliceArg + where + I: SliceArg, { self.view_mut().slice_move(info) } @@ -556,7 +518,8 @@ impl ArrayRef /// ``` #[track_caller] pub fn multi_slice_mut<'a, M>(&'a mut self, info: M) -> M::Output - where M: MultiSliceArg<'a, A, D> + where + M: MultiSliceArg<'a, A, D>, { info.multi_slice_move(self.view_mut()) } @@ -576,7 +539,8 @@ where /// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) #[track_caller] pub fn slice_move(mut self, info: I) -> ArrayBase - where I: SliceArg + where + I: SliceArg, { assert_eq!(info.in_ndim(), self.ndim(), "The input dimension of `info` must match the array to be sliced.",); let out_ndim = info.out_ndim(); @@ -620,8 +584,7 @@ where } } -impl LayoutRef -{ +impl LayoutRef { /// Slice the array in place without changing the number of dimensions. /// /// In particular, if an axis is sliced with an index, the axis is @@ -644,7 +607,8 @@ impl LayoutRef /// - if `D` is `IxDyn` and `info` does not match the number of array axes #[track_caller] pub fn slice_collapse(&mut self, info: I) - where I: SliceArg + where + I: SliceArg, { assert_eq!(info.in_ndim(), self.ndim(), "The input dimension of `info` must match the array to be sliced.",); let mut axis = 0; @@ -664,16 +628,14 @@ impl LayoutRef } } -impl ArrayRef -{ +impl ArrayRef { /// Return a view of the array, sliced along the specified axis. /// /// **Panics** if an index is out of bounds or step size is zero.
/// **Panics** if `axis` is out of bounds. #[track_caller] #[must_use = "slice_axis returns an array view with the sliced result"] - pub fn slice_axis(&self, axis: Axis, indices: Slice) -> ArrayView<'_, A, D> - { + pub fn slice_axis(&self, axis: Axis, indices: Slice) -> ArrayView<'_, A, D> { let mut view = self.view(); view.slice_axis_inplace(axis, indices); view @@ -685,23 +647,20 @@ impl ArrayRef /// **Panics** if `axis` is out of bounds. #[track_caller] #[must_use = "slice_axis_mut returns an array view with the sliced result"] - pub fn slice_axis_mut(&mut self, axis: Axis, indices: Slice) -> ArrayViewMut<'_, A, D> - { + pub fn slice_axis_mut(&mut self, axis: Axis, indices: Slice) -> ArrayViewMut<'_, A, D> { let mut view_mut = self.view_mut(); view_mut.slice_axis_inplace(axis, indices); view_mut } } -impl LayoutRef -{ +impl LayoutRef { /// Slice the array in place along the specified axis. /// /// **Panics** if an index is out of bounds or step size is zero.
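// Not part of the diff: an illustrative sketch of axis slicing (assumes `array!`,
// `Axis`, and `Slice` from the ndarray prelude). Slicing an axis keeps the
// dimensionality and returns a view.
fn slice_axis_sketch() {
    let a = array![[1, 2, 3], [4, 5, 6]];
    let tail = a.slice_axis(Axis(1), Slice::from(1..));
    assert_eq!(tail, array![[2, 3], [5, 6]]);
    assert_eq!(tail.ndim(), 2);
}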
/// **Panics** if `axis` is out of bounds. #[track_caller] - pub fn slice_axis_inplace(&mut self, axis: Axis, indices: Slice) - { + pub fn slice_axis_inplace(&mut self, axis: Axis, indices: Slice) { let parts = &mut self.0; let offset = do_slice(&mut parts.dim.slice_mut()[axis.index()], &mut parts.strides.slice_mut()[axis.index()], indices); @@ -722,15 +681,13 @@ where /// **Panics** if an index is out of bounds or step size is zero.
/// **Panics** if `axis` is out of bounds. #[must_use = "slice_axis_move returns an array with the sliced result"] - pub fn slice_axis_move(mut self, axis: Axis, indices: Slice) -> Self - { + pub fn slice_axis_move(mut self, axis: Axis, indices: Slice) -> Self { self.slice_axis_inplace(axis, indices); self } } -impl ArrayRef -{ +impl ArrayRef { /// Return a view of a slice of the array, with a closure specifying the /// slice for each axis. /// @@ -740,7 +697,8 @@ impl ArrayRef /// **Panics** if an index is out of bounds or step size is zero. #[track_caller] pub fn slice_each_axis(&self, f: F) -> ArrayView<'_, A, D> - where F: FnMut(AxisDescription) -> Slice + where + F: FnMut(AxisDescription) -> Slice, { let mut view = self.view(); view.slice_each_axis_inplace(f); @@ -756,7 +714,8 @@ impl ArrayRef /// **Panics** if an index is out of bounds or step size is zero. #[track_caller] pub fn slice_each_axis_mut(&mut self, f: F) -> ArrayViewMut<'_, A, D> - where F: FnMut(AxisDescription) -> Slice + where + F: FnMut(AxisDescription) -> Slice, { let mut view = self.view_mut(); view.slice_each_axis_inplace(f); @@ -764,8 +723,7 @@ impl ArrayRef } } -impl LayoutRef -{ +impl LayoutRef { /// Slice the array in place, with a closure specifying the slice for each /// axis. /// @@ -775,7 +733,8 @@ impl LayoutRef /// **Panics** if an index is out of bounds or step size is zero. #[track_caller] pub fn slice_each_axis_inplace(&mut self, mut f: F) - where F: FnMut(AxisDescription) -> Slice + where + F: FnMut(AxisDescription) -> Slice, { for ax in 0..self.ndim() { self.slice_axis_inplace( @@ -790,8 +749,7 @@ impl LayoutRef } } -impl ArrayRef -{ +impl ArrayRef { /// Return a reference to the element at `index`, or return `None` /// if the index is out of bounds. /// @@ -811,14 +769,14 @@ impl ArrayRef /// ); /// ``` pub fn get(&self, index: I) -> Option<&A> - where I: NdIndex + where + I: NdIndex, { unsafe { self.get_ptr(index).map(|ptr| &*ptr) } } } -impl RawRef -{ +impl RawRef { /// Return a raw pointer to the element at `index`, or return `None` /// if the index is out of bounds. /// @@ -833,7 +791,8 @@ impl RawRef /// assert_eq!(unsafe { *p }, 2.); /// ``` pub fn get_ptr(&self, index: I) -> Option<*const A> - where I: NdIndex + where + I: NdIndex, { let ptr = self._ptr(); index @@ -842,19 +801,18 @@ impl RawRef } } -impl ArrayRef -{ +impl ArrayRef { /// Return a mutable reference to the element at `index`, or return `None` /// if the index is out of bounds. pub fn get_mut(&mut self, index: I) -> Option<&mut A> - where I: NdIndex + where + I: NdIndex, { unsafe { self.get_mut_ptr(index).map(|ptr| &mut *ptr) } } } -impl RawRef -{ +impl RawRef { /// Return a raw pointer to the element at `index`, or return `None` /// if the index is out of bounds. /// @@ -873,7 +831,8 @@ impl RawRef /// assert_eq!(a.get((0, 1)), Some(&5.)); /// ``` pub fn get_mut_ptr(&mut self, index: I) -> Option<*mut A> - where I: NdIndex + where + I: NdIndex, { // const and mut are separate to enforce &mutness as well as the // extra code in as_mut_ptr @@ -884,8 +843,7 @@ impl RawRef } } -impl ArrayRef -{ +impl ArrayRef { /// Perform *unchecked* array indexing. /// /// Return a reference to the element at `index`. @@ -897,7 +855,8 @@ impl ArrayRef /// The caller must ensure that the index is in-bounds. 
#[inline] pub unsafe fn uget(&self, index: I) -> &A - where I: NdIndex + where + I: NdIndex, { arraytraits::debug_bounds_check(self, &index); let off = index.index_unchecked(self._strides()); @@ -920,7 +879,8 @@ impl ArrayRef /// for `Array` and `ArrayViewMut`, but not for `ArcArray` or `CowArray`.) #[inline] pub unsafe fn uget_mut(&mut self, index: I) -> &mut A - where I: NdIndex + where + I: NdIndex, { // debug_assert!(self.data.is_unique()); arraytraits::debug_bounds_check(self, &index); @@ -935,7 +895,8 @@ impl ArrayRef /// ***Panics*** if an index is out of bounds. #[track_caller] pub fn swap(&mut self, index1: I, index2: I) - where I: NdIndex + where + I: NdIndex, { let ptr = self.as_mut_ptr(); let offset1 = index1.index_checked(self._dim(), self._strides()); @@ -966,7 +927,8 @@ impl ArrayRef /// 2. the data is uniquely held by the array. (This property is guaranteed /// for `Array` and `ArrayViewMut`, but not for `ArcArray` or `CowArray`.) pub unsafe fn uswap(&mut self, index1: I, index2: I) - where I: NdIndex + where + I: NdIndex, { // debug_assert!(self.data.is_unique()); arraytraits::debug_bounds_check(self, &index1); @@ -978,8 +940,7 @@ impl ArrayRef // `get` for zero-dimensional arrays // panics if dimension is not zero. otherwise an element is always present. - fn get_0d(&self) -> &A - { + fn get_0d(&self) -> &A { assert!(self.ndim() == 0); unsafe { &*self.as_ptr() } } @@ -1007,7 +968,8 @@ impl ArrayRef /// ``` #[track_caller] pub fn index_axis(&self, axis: Axis, index: usize) -> ArrayView<'_, A, D::Smaller> - where D: RemoveAxis + where + D: RemoveAxis, { self.view().index_axis_move(axis, index) } @@ -1038,7 +1000,8 @@ impl ArrayRef /// ``` #[track_caller] pub fn index_axis_mut(&mut self, axis: Axis, index: usize) -> ArrayViewMut<'_, A, D::Smaller> - where D: RemoveAxis + where + D: RemoveAxis, { self.view_mut().index_axis_move(axis, index) } @@ -1056,7 +1019,8 @@ where /// **Panics** if `axis` or `index` is out of bounds. #[track_caller] pub fn index_axis_move(mut self, axis: Axis, index: usize) -> ArrayBase - where D: RemoveAxis + where + D: RemoveAxis, { self.collapse_axis(axis, index); let dim = self.parts.dim.remove_axis(axis); @@ -1066,14 +1030,12 @@ where } } -impl LayoutRef -{ +impl LayoutRef { /// Selects `index` along the axis, collapsing the axis into length one. /// /// **Panics** if `axis` or `index` is out of bounds. #[track_caller] - pub fn collapse_axis(&mut self, axis: Axis, index: usize) - { + pub fn collapse_axis(&mut self, axis: Axis, index: usize) { let parts = &mut self.0; let offset = dimension::do_collapse_axis(&mut parts.dim, &parts.strides, axis.index(), index); self.0.ptr = unsafe { self._ptr().offset(offset) }; @@ -1081,8 +1043,7 @@ impl LayoutRef } } -impl ArrayRef -{ +impl ArrayRef { /// Along `axis`, select arbitrary subviews corresponding to `indices` /// and copy them into a new array. /// @@ -1167,8 +1128,7 @@ impl ArrayRef /// /* loop body */ /// } /// ``` - pub fn rows(&self) -> Lanes<'_, A, D::Smaller> - { + pub fn rows(&self) -> Lanes<'_, A, D::Smaller> { let mut n = self.ndim(); if n == 0 { n += 1; @@ -1180,8 +1140,7 @@ impl ArrayRef /// rows of the array and yields mutable array views. /// /// Iterator element is `ArrayView1
` (1D read-write array view). - pub fn rows_mut(&mut self) -> LanesMut<'_, A, D::Smaller> - { + pub fn rows_mut(&mut self) -> LanesMut<'_, A, D::Smaller> { let mut n = self.ndim(); if n == 0 { n += 1; @@ -1215,8 +1174,7 @@ impl ArrayRef /// /* loop body */ /// } /// ``` - pub fn columns(&self) -> Lanes<'_, A, D::Smaller> - { + pub fn columns(&self) -> Lanes<'_, A, D::Smaller> { Lanes::new(self.view(), Axis(0)) } @@ -1224,8 +1182,7 @@ impl ArrayRef /// columns of the array and yields mutable array views. /// /// Iterator element is `ArrayView1` (1D read-write array view). - pub fn columns_mut(&mut self) -> LanesMut<'_, A, D::Smaller> - { + pub fn columns_mut(&mut self) -> LanesMut<'_, A, D::Smaller> { LanesMut::new(self.view_mut(), Axis(0)) } @@ -1257,8 +1214,7 @@ impl ArrayRef /// // The first lane for axis 2 is [0, 1, 2] /// assert_eq!(inner2.into_iter().next().unwrap(), aview1(&[0, 1, 2])); /// ``` - pub fn lanes(&self, axis: Axis) -> Lanes<'_, A, D::Smaller> - { + pub fn lanes(&self, axis: Axis) -> Lanes<'_, A, D::Smaller> { Lanes::new(self.view(), axis) } @@ -1266,8 +1222,7 @@ impl ArrayRef /// pointing in the direction of `axis`. /// /// Iterator element is `ArrayViewMut1` (1D read-write array view). - pub fn lanes_mut(&mut self, axis: Axis) -> LanesMut<'_, A, D::Smaller> - { + pub fn lanes_mut(&mut self, axis: Axis) -> LanesMut<'_, A, D::Smaller> { LanesMut::new(self.view_mut(), axis) } @@ -1279,7 +1234,8 @@ impl ArrayRef /// Iterator element is `ArrayView` (read-only array view). #[allow(deprecated)] pub fn outer_iter(&self) -> AxisIter<'_, A, D::Smaller> - where D: RemoveAxis + where + D: RemoveAxis, { self.view().into_outer_iter() } @@ -1292,7 +1248,8 @@ impl ArrayRef /// Iterator element is `ArrayViewMut` (read-write array view). #[allow(deprecated)] pub fn outer_iter_mut(&mut self) -> AxisIterMut<'_, A, D::Smaller> - where D: RemoveAxis + where + D: RemoveAxis, { self.view_mut().into_outer_iter_mut() } @@ -1314,7 +1271,8 @@ impl ArrayRef /// #[track_caller] pub fn axis_iter(&self, axis: Axis) -> AxisIter<'_, A, D::Smaller> - where D: RemoveAxis + where + D: RemoveAxis, { AxisIter::new(self.view(), axis) } @@ -1328,7 +1286,8 @@ impl ArrayRef /// **Panics** if `axis` is out of bounds. #[track_caller] pub fn axis_iter_mut(&mut self, axis: Axis) -> AxisIterMut<'_, A, D::Smaller> - where D: RemoveAxis + where + D: RemoveAxis, { AxisIterMut::new(self.view_mut(), axis) } @@ -1360,8 +1319,7 @@ impl ArrayRef /// [[26, 27]]])); /// ``` #[track_caller] - pub fn axis_chunks_iter(&self, axis: Axis, size: usize) -> AxisChunksIter<'_, A, D> - { + pub fn axis_chunks_iter(&self, axis: Axis, size: usize) -> AxisChunksIter<'_, A, D> { AxisChunksIter::new(self.view(), axis, size) } @@ -1372,8 +1330,7 @@ impl ArrayRef /// /// **Panics** if `axis` is out of bounds or if `size` is zero. #[track_caller] - pub fn axis_chunks_iter_mut(&mut self, axis: Axis, size: usize) -> AxisChunksIterMut<'_, A, D> - { + pub fn axis_chunks_iter_mut(&mut self, axis: Axis, size: usize) -> AxisChunksIterMut<'_, A, D> { AxisChunksIterMut::new(self.view_mut(), axis, size) } @@ -1390,7 +1347,8 @@ impl ArrayRef /// number of array axes.) 
#[track_caller] pub fn exact_chunks(&self, chunk_size: E) -> ExactChunks<'_, A, D> - where E: IntoDimension + where + E: IntoDimension, { ExactChunks::new(self.view(), chunk_size) } @@ -1429,7 +1387,8 @@ impl ArrayRef /// ``` #[track_caller] pub fn exact_chunks_mut(&mut self, chunk_size: E) -> ExactChunksMut<'_, A, D> - where E: IntoDimension + where + E: IntoDimension, { ExactChunksMut::new(self.view_mut(), chunk_size) } @@ -1442,7 +1401,8 @@ impl ArrayRef /// This is essentially equivalent to [`ArrayRef::windows_with_stride()`] with unit stride. #[track_caller] pub fn windows(&self, window_size: E) -> Windows<'_, A, D> - where E: IntoDimension + where + E: IntoDimension, { Windows::new(self.view(), window_size) } @@ -1493,7 +1453,8 @@ impl ArrayRef /// ``` #[track_caller] pub fn windows_with_stride(&self, window_size: E, stride: E) -> Windows<'_, A, D> - where E: IntoDimension + where + E: IntoDimension, { Windows::new_with_stride(self.view(), window_size, stride) } @@ -1519,8 +1480,7 @@ impl ArrayRef /// assert_eq!(window.shape(), &[4, 3, 2]); /// } /// ``` - pub fn axis_windows(&self, axis: Axis, window_size: usize) -> AxisWindows<'_, A, D> - { + pub fn axis_windows(&self, axis: Axis, window_size: usize) -> AxisWindows<'_, A, D> { self.axis_windows_with_stride(axis, window_size, 1) } @@ -1529,9 +1489,9 @@ impl ArrayRef /// /// Note that a calling this method with a stride of 1 is equivalent to /// calling [`ArrayRef::axis_windows()`]. - pub fn axis_windows_with_stride(&self, axis: Axis, window_size: usize, stride_size: usize) - -> AxisWindows<'_, A, D> - { + pub fn axis_windows_with_stride( + &self, axis: Axis, window_size: usize, stride_size: usize, + ) -> AxisWindows<'_, A, D> { let axis_index = axis.index(); ndassert!( @@ -1542,10 +1502,7 @@ impl ArrayRef self.shape() ); - ndassert!( - stride_size >0, - "Stride size must be greater than zero" - ); + ndassert!(stride_size > 0, "Stride size must be greater than zero"); AxisWindows::new_with_stride(self.view(), axis, window_size, stride_size) } @@ -1554,14 +1511,12 @@ impl ArrayRef /// /// The diagonal is simply the sequence indexed by *(0, 0, .., 0)*, /// *(1, 1, ..., 1)* etc as long as all axes have elements. - pub fn diag(&self) -> ArrayView1<'_, A> - { + pub fn diag(&self) -> ArrayView1<'_, A> { self.view().into_diag() } /// Return a read-write view over the diagonal elements of the array. - pub fn diag_mut(&mut self) -> ArrayViewMut1<'_, A> - { + pub fn diag_mut(&mut self) -> ArrayViewMut1<'_, A> { self.view_mut().into_diag() } } @@ -1572,8 +1527,7 @@ where D: Dimension, { // Return (length, stride) for diagonal - fn diag_params(&self) -> (Ix, Ixs) - { + fn diag_params(&self) -> (Ix, Ixs) { /* empty shape has len 1 */ let len = self.parts.dim.slice().iter().cloned().min().unwrap_or(1); let stride = self.strides().iter().sum(); @@ -1581,8 +1535,7 @@ where } /// Return the diagonal as a one-dimensional array. - pub fn into_diag(self) -> ArrayBase - { + pub fn into_diag(self) -> ArrayBase { let (len, stride) = self.diag_params(); // safe because new len stride allows access to a subset of the current elements unsafe { self.with_strides_dim(Ix1(stride as Ix), Ix1(len)) } @@ -1594,7 +1547,8 @@ where /// /// This method is mostly only useful with unsafe code. fn try_ensure_unique(&mut self) - where S: RawDataMut + where + S: RawDataMut, { debug_assert!(self.pointer_is_inbounds()); S::try_ensure_unique(self); @@ -1605,7 +1559,8 @@ where /// /// This method is mostly only useful with unsafe code. 
pub(crate) fn ensure_unique(&mut self) - where S: DataMut + where + S: DataMut, { debug_assert!(self.pointer_is_inbounds()); S::ensure_unique(self); @@ -1613,27 +1568,23 @@ where } } -impl LayoutRef -{ +impl LayoutRef { /// Return `true` if the array data is laid out in contiguous “C order” in /// memory (where the last index is the most rapidly varying). /// /// Return `false` otherwise, i.e. the array is possibly not /// contiguous in memory, it has custom strides, etc. - pub fn is_standard_layout(&self) -> bool - { + pub fn is_standard_layout(&self) -> bool { dimension::is_layout_c(self._dim(), self._strides()) } /// Return true if the array is known to be contiguous. - pub(crate) fn is_contiguous(&self) -> bool - { + pub(crate) fn is_contiguous(&self) -> bool { D::is_contiguous(self._dim(), self._strides()) } } -impl ArrayRef -{ +impl ArrayRef { /// Return a standard-layout array containing the data, cloning if /// necessary. /// @@ -1657,7 +1608,8 @@ impl ArrayRef /// assert!(cow_owned.is_standard_layout()); /// ``` pub fn as_standard_layout(&self) -> CowArray<'_, A, D> - where A: Clone + where + A: Clone, { if self.is_standard_layout() { CowArray::from(self.view()) @@ -1675,8 +1627,7 @@ impl ArrayRef } } -impl RawRef -{ +impl RawRef { /// Return a pointer to the first element in the array. /// /// Raw access to array elements needs to follow the strided indexing @@ -1687,15 +1638,13 @@ impl RawRef /// /// where *d* is `self.ndim()`. #[inline(always)] - pub fn as_ptr(&self) -> *const A - { + pub fn as_ptr(&self) -> *const A { self._ptr().as_ptr() as *const A } /// Return a mutable pointer to the first element in the array reference. #[inline(always)] - pub fn as_mut_ptr(&mut self) -> *mut A - { + pub fn as_mut_ptr(&mut self) -> *mut A { self._ptr().as_ptr() } } @@ -1717,26 +1666,24 @@ where /// the data may change the strides. #[inline(always)] pub fn as_mut_ptr(&mut self) -> *mut A - where S: RawDataMut + where + S: RawDataMut, { self.try_ensure_unique(); // for ArcArray self.parts.ptr.as_ptr() } } -impl RawRef -{ +impl RawRef { /// Return a raw view of the array. #[inline] - pub fn raw_view(&self) -> RawArrayView - { + pub fn raw_view(&self) -> RawArrayView { unsafe { RawArrayView::new(*self._ptr(), self._dim().clone(), self._strides().clone()) } } /// Return a raw mutable view of the array. #[inline] - pub fn raw_view_mut(&mut self) -> RawArrayViewMut - { + pub fn raw_view_mut(&mut self) -> RawArrayViewMut { unsafe { RawArrayViewMut::new(*self._ptr(), self._dim().clone(), self._strides().clone()) } } } @@ -1752,7 +1699,8 @@ where /// data is guaranteed to be uniquely held on return. #[inline] pub fn raw_view_mut(&mut self) -> RawArrayViewMut - where S: RawDataMut + where + S: RawDataMut, { self.try_ensure_unique(); // for ArcArray unsafe { RawArrayViewMut::new(self.parts.ptr, self.parts.dim.clone(), self.parts.strides.clone()) } @@ -1763,7 +1711,8 @@ where /// Safety: The caller must ensure that the owned array is unshared when this is called #[inline] pub(crate) unsafe fn raw_view_mut_unchecked(&mut self) -> RawArrayViewMut - where S: DataOwned + where + S: DataOwned, { RawArrayViewMut::new(*self._ptr(), self._dim().clone(), self._strides().clone()) } @@ -1771,7 +1720,8 @@ where /// Return the array’s data as a slice, if it is contiguous and in standard order. /// Return `None` otherwise. 
pub fn as_slice_mut(&mut self) -> Option<&mut [A]> - where S: DataMut + where + S: DataMut, { if self.is_standard_layout() { self.ensure_unique(); @@ -1788,7 +1738,8 @@ where /// method unshares the data if necessary, but it preserves the existing /// strides. pub fn as_slice_memory_order_mut(&mut self) -> Option<&mut [A]> - where S: DataMut + where + S: DataMut, { self.try_as_slice_memory_order_mut().ok() } @@ -1796,7 +1747,8 @@ where /// Return the array’s data as a slice if it is contiguous, otherwise /// return `self` in the `Err` variant. pub(crate) fn try_as_slice_memory_order_mut(&mut self) -> Result<&mut [A], &mut Self> - where S: DataMut + where + S: DataMut, { if self.is_contiguous() { self.ensure_unique(); @@ -1808,15 +1760,13 @@ where } } -impl ArrayRef -{ +impl ArrayRef { /// Return the array’s data as a slice, if it is contiguous and in standard order. /// Return `None` otherwise. /// /// If this function returns `Some(_)`, then the element order in the slice /// corresponds to the logical order of the array’s elements. - pub fn as_slice(&self) -> Option<&[A]> - { + pub fn as_slice(&self) -> Option<&[A]> { if self.is_standard_layout() { unsafe { Some(slice::from_raw_parts(self._ptr().as_ptr(), self.len())) } } else { @@ -1826,8 +1776,7 @@ impl ArrayRef /// Return the array’s data as a slice, if it is contiguous and in standard order. /// Return `None` otherwise. - pub fn as_slice_mut(&mut self) -> Option<&mut [A]> - { + pub fn as_slice_mut(&mut self) -> Option<&mut [A]> { if self.is_standard_layout() { unsafe { Some(slice::from_raw_parts_mut(self._ptr().as_ptr(), self.len())) } } else { @@ -1840,8 +1789,7 @@ impl ArrayRef /// /// If this function returns `Some(_)`, then the elements in the slice /// have whatever order the elements have in memory. - pub fn as_slice_memory_order(&self) -> Option<&[A]> - { + pub fn as_slice_memory_order(&self) -> Option<&[A]> { if self.is_contiguous() { let offset = offset_from_low_addr_ptr_to_logical_ptr(self._dim(), self._strides()); unsafe { Some(slice::from_raw_parts(self._ptr().sub(offset).as_ptr(), self.len())) } @@ -1856,15 +1804,13 @@ impl ArrayRef /// In the contiguous case, in order to return a unique reference, this /// method unshares the data if necessary, but it preserves the existing /// strides. - pub fn as_slice_memory_order_mut(&mut self) -> Option<&mut [A]> - { + pub fn as_slice_memory_order_mut(&mut self) -> Option<&mut [A]> { self.try_as_slice_memory_order_mut().ok() } /// Return the array’s data as a slice if it is contiguous, otherwise /// return `self` in the `Err` variant. 
- pub(crate) fn try_as_slice_memory_order_mut(&mut self) -> Result<&mut [A], &mut Self> - { + pub(crate) fn try_as_slice_memory_order_mut(&mut self) -> Result<&mut [A], &mut Self> { if self.is_contiguous() { let offset = offset_from_low_addr_ptr_to_logical_ptr(self._dim(), self._strides()); unsafe { Ok(slice::from_raw_parts_mut(self._ptr().sub(offset).as_ptr(), self.len())) } @@ -2027,14 +1973,16 @@ where /// ); /// ``` pub fn into_shape_with_order(self, shape: E) -> Result, ShapeError> - where E: ShapeArg + where + E: ShapeArg, { let (shape, order) = shape.into_shape_and_order(); self.into_shape_with_order_impl(shape, order.unwrap_or(Order::RowMajor)) } fn into_shape_with_order_impl(self, shape: E, order: Order) -> Result, ShapeError> - where E: Dimension + where + E: Dimension, { let shape = shape.into_dimension(); if size_of_shape_checked(&shape) != Ok(self.parts.dim.size()) { @@ -2045,10 +1993,12 @@ where unsafe { // safe because arrays are contiguous and len is unchanged match order { - Order::RowMajor if self.is_standard_layout() => - Ok(self.with_strides_dim(shape.default_strides(), shape)), - Order::ColumnMajor if self.raw_view().reversed_axes().is_standard_layout() => - Ok(self.with_strides_dim(shape.fortran_strides(), shape)), + Order::RowMajor if self.is_standard_layout() => { + Ok(self.with_strides_dim(shape.default_strides(), shape)) + } + Order::ColumnMajor if self.raw_view().reversed_axes().is_standard_layout() => { + Ok(self.with_strides_dim(shape.fortran_strides(), shape)) + } _otherwise => Err(error::from_kind(error::ErrorKind::IncompatibleLayout)), } } @@ -2079,7 +2029,8 @@ where /// ``` #[deprecated(note = "Use `.into_shape_with_order()` or `.to_shape()`", since = "0.16.0")] pub fn into_shape(self, shape: E) -> Result, ShapeError> - where E: IntoDimension + where + E: IntoDimension, { let shape = shape.into_dimension(); if size_of_shape_checked(&shape) != Ok(self.parts.dim.size()) { @@ -2213,8 +2164,7 @@ where } } -impl ArrayRef -{ +impl ArrayRef { /// Flatten the array to a one-dimensional array. /// /// The array is returned as a `CowArray`; a view if possible, otherwise an owned array. @@ -2227,7 +2177,8 @@ impl ArrayRef /// assert_eq!(flattened, arr1(&[1, 2, 3, 4, 5, 6, 7, 8])); /// ``` pub fn flatten(&self) -> CowArray<'_, A, Ix1> - where A: Clone + where + A: Clone, { self.flatten_with_order(Order::RowMajor) } @@ -2248,7 +2199,8 @@ impl ArrayRef /// assert_eq!(flattened, arr1(&[1, 3, 5, 7, 2, 4, 6, 8])); /// ``` pub fn flatten_with_order(&self, order: Order) -> CowArray<'_, A, Ix1> - where A: Clone + where + A: Clone, { self.to_shape((self.len(), order)).unwrap() } @@ -2289,8 +2241,7 @@ where /// let array: ArrayD = arr2(&[[1, 2], /// [3, 4]]).into_dyn(); /// ``` - pub fn into_dyn(self) -> ArrayBase - { + pub fn into_dyn(self) -> ArrayBase { // safe because new dims equivalent unsafe { ArrayBase::from_data_ptr(self.data, self.parts.ptr) @@ -2315,7 +2266,8 @@ where /// assert!(array.into_dimensionality::().is_ok()); /// ``` pub fn into_dimensionality(self) -> Result, ShapeError> - where D2: Dimension + where + D2: Dimension, { unsafe { if D::NDIM == D2::NDIM { @@ -2337,8 +2289,7 @@ where } } -impl ArrayRef -{ +impl ArrayRef { /// Act like a larger size and/or shape array by *broadcasting* /// into a larger shape, if possible. 
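// Not part of the diff: an illustrative sketch of dimensionality conversion (assumes
// `ArrayD`, `Array2`, `IxDyn`, and `Ix2` from the ndarray prelude).
fn dimensionality_sketch() {
    let a = ArrayD::<f64>::zeros(IxDyn(&[2, 3]));
    let b: Array2<f64> = a.into_dimensionality::<Ix2>().unwrap();
    assert_eq!(b.shape(), &[2, 3]);
    // Going back to dynamic dimensionality is always possible.
    assert_eq!(b.into_dyn().ndim(), 2);
}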
/// @@ -2369,7 +2320,8 @@ impl ArrayRef /// ); /// ``` pub fn broadcast(&self, dim: E) -> Option> - where E: IntoDimension + where + E: IntoDimension, { /// Return new stride when trying to grow `from` into shape `to` /// @@ -2379,8 +2331,7 @@ impl ArrayRef /// /// **Note:** Cannot be used for mutable iterators, since repeating /// elements would create aliasing pointers. - fn upcast(to: &D, from: &E, stride: &E) -> Option - { + fn upcast(to: &D, from: &E, stride: &E) -> Option { // Make sure the product of non-zero axis lengths does not exceed // `isize::MAX`. This is the only safety check we need to perform // because all the other constraints of `ArrayBase` are guaranteed @@ -2465,8 +2416,7 @@ impl ArrayRef } } -impl LayoutRef -{ +impl LayoutRef { /// Swap axes `ax` and `bx`. /// /// This does not move any data, it just adjusts the array’s dimensions @@ -2484,8 +2434,7 @@ impl LayoutRef /// ); /// ``` #[track_caller] - pub fn swap_axes(&mut self, ax: usize, bx: usize) - { + pub fn swap_axes(&mut self, ax: usize, bx: usize) { self.0.dim.slice_mut().swap(ax, bx); self.0.strides.slice_mut().swap(ax, bx); } @@ -2520,7 +2469,8 @@ where /// ``` #[track_caller] pub fn permuted_axes(self, axes: T) -> ArrayBase - where T: IntoDimension + where + T: IntoDimension, { let axes = axes.into_dimension(); // Ensure that each axis is used exactly once. @@ -2571,7 +2521,8 @@ where /// ``` #[track_caller] pub fn permute_axes(&mut self, axes: T) - where T: IntoDimension + where + T: IntoDimension, { let axes = axes.into_dimension(); // Ensure that each axis is used exactly once. @@ -2616,8 +2567,7 @@ where /// /// Transposition reverses the order of the axes (dimensions and strides) /// while retaining the same data. - pub fn reversed_axes(mut self) -> ArrayBase - { + pub fn reversed_axes(mut self) -> ArrayBase { self.parts.dim.slice_mut().reverse(); self.parts.strides.slice_mut().reverse(); self @@ -2627,31 +2577,26 @@ where /// /// This does not move any data, it just adjusts the array's dimensions /// and strides. - pub fn reverse_axes(&mut self) - { + pub fn reverse_axes(&mut self) { self.parts.dim.slice_mut().reverse(); self.parts.strides.slice_mut().reverse(); } } -impl ArrayRef -{ +impl ArrayRef { /// Return a transposed view of the array. /// /// This is a shorthand for `self.view().reversed_axes()`. /// /// See also the more general methods `.reversed_axes()` and `.swap_axes()`. - pub fn t(&self) -> ArrayView<'_, A, D> - { + pub fn t(&self) -> ArrayView<'_, A, D> { self.view().reversed_axes() } } -impl LayoutRef -{ +impl LayoutRef { /// Return an iterator over the length and stride of each axis. - pub fn axes(&self) -> Axes<'_, D> - { + pub fn axes(&self) -> Axes<'_, D> { axes_of(self._dim(), self._strides()) } @@ -2664,8 +2609,7 @@ impl LayoutRef /// Return the axis with the greatest stride (by absolute value), /// preferring axes with len > 1. - pub fn max_stride_axis(&self) -> Axis - { + pub fn max_stride_axis(&self) -> Axis { self._dim().max_stride_axis(self._strides()) } @@ -2673,8 +2617,7 @@ impl LayoutRef /// /// ***Panics*** if the axis is out of bounds. #[track_caller] - pub fn invert_axis(&mut self, axis: Axis) - { + pub fn invert_axis(&mut self, axis: Axis) { unsafe { let s = self._strides().axis(axis) as Ixs; let m = self._dim().axis(axis); @@ -2721,8 +2664,7 @@ impl LayoutRef /// /// ***Panics*** if an axis is out of bounds. 
#[track_caller] - pub fn merge_axes(&mut self, take: Axis, into: Axis) -> bool - { + pub fn merge_axes(&mut self, take: Axis, into: Axis) -> bool { let parts = &mut self.0; merge_axes(&mut parts.dim, &mut parts.strides, take, into) } @@ -2755,8 +2697,7 @@ where /// /// ***Panics*** if the axis is out of bounds. #[track_caller] - pub fn insert_axis(self, axis: Axis) -> ArrayBase - { + pub fn insert_axis(self, axis: Axis) -> ArrayBase { assert!(axis.index() <= self.ndim()); // safe because a new axis of length one does not affect memory layout unsafe { @@ -2774,19 +2715,18 @@ where /// **Panics** if the axis is out of bounds or its length is zero. #[track_caller] pub fn remove_axis(self, axis: Axis) -> ArrayBase - where D: RemoveAxis + where + D: RemoveAxis, { self.index_axis_move(axis, 0) } - pub(crate) fn pointer_is_inbounds(&self) -> bool - { + pub(crate) fn pointer_is_inbounds(&self) -> bool { self.data._is_pointer_inbounds(self.as_ptr()) } } -impl ArrayRef -{ +impl ArrayRef { /// Perform an elementwise assignment to `self` from `rhs`. /// /// If their shapes disagree, `rhs` is broadcast to the shape of `self`. @@ -2794,7 +2734,8 @@ impl ArrayRef /// **Panics** if broadcasting isn’t possible. #[track_caller] pub fn assign(&mut self, rhs: &ArrayRef) - where A: Clone + where + A: Clone, { self.zip_mut_with(rhs, |x, y| x.clone_from(y)); } @@ -2817,7 +2758,8 @@ impl ArrayRef /// Perform an elementwise assignment to `self` from element `x`. pub fn fill(&mut self, x: A) - where A: Clone + where + A: Clone, { self.map_inplace(move |elt| elt.clone_from(&x)); } @@ -2866,7 +2808,8 @@ impl ArrayRef } fn zip_mut_with_elem(&mut self, rhs_elem: &B, mut f: F) - where F: FnMut(&mut A, &B) + where + F: FnMut(&mut A, &B), { self.map_inplace(move |elt| f(elt, rhs_elem)); } @@ -3058,8 +3001,7 @@ where } } -impl ArrayRef -{ +impl ArrayRef { /// Modify the array in place by calling `f` by mutable reference on each element. /// /// Elements are visited in arbitrary order. @@ -3199,8 +3141,7 @@ impl ArrayRef /// /// ***Panics*** if `axis` is out of bounds
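The axis insertion/removal pair touched above is easiest to see from the caller's side; a small sketch using only the public API (nothing here is specific to this diff):

use ndarray::{arr2, Axis};

fn main() {
    let a = arr2(&[[1, 2], [3, 4]]);      // shape (2, 2)
    let b = a.insert_axis(Axis(1));       // shape (2, 1, 2), no data moved
    assert_eq!(b.shape(), &[2, 1, 2]);
    let c = b.remove_axis(Axis(1));       // back to shape (2, 2)
    assert_eq!(c, arr2(&[[1, 2], [3, 4]]));
}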
/// ***Panics*** if not `index < self.len_of(axis)`. - pub fn remove_index(&mut self, axis: Axis, index: usize) - { + pub fn remove_index(&mut self, axis: Axis, index: usize) { assert!(index < self.len_of(axis), "index {} must be less than length of Axis({})", index, axis.index()); let (_, mut tail) = self.view_mut().split_at(axis, index); // shift elements to the front @@ -3210,8 +3151,7 @@ impl ArrayRef } } -impl ArrayRef -{ +impl ArrayRef { /// Iterates over pairs of consecutive elements along the axis. /// /// The first argument to the closure is an element, and the second @@ -3241,7 +3181,8 @@ impl ArrayRef /// ); /// ``` pub fn accumulate_axis_inplace(&mut self, axis: Axis, mut f: F) - where F: FnMut(&A, &mut A) + where + F: FnMut(&A, &mut A), { if self.len_of(axis) <= 1 { return; @@ -3357,8 +3298,7 @@ impl ArrayRef /// **Panics** if the size of A and B are different. #[track_caller] #[inline] -unsafe fn unlimited_transmute(data: A) -> B -{ +unsafe fn unlimited_transmute(data: A) -> B { // safe when sizes are equal and caller guarantees that representations are equal assert_eq!(size_of::
(), size_of::()); let old_data = ManuallyDrop::new(data); @@ -3368,23 +3308,20 @@ unsafe fn unlimited_transmute(data: A) -> B type DimMaxOf = >::Output; #[cfg(test)] -mod tests -{ +mod tests { use super::*; use crate::arr3; use defmac::defmac; #[test] - fn test_flatten() - { + fn test_flatten() { let array = arr3(&[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); let flattened = array.flatten(); assert_eq!(flattened, arr1(&[1, 2, 3, 4, 5, 6, 7, 8])); } #[test] - fn test_flatten_with_order() - { + fn test_flatten_with_order() { let array = arr2(&[[1, 2], [3, 4], [5, 6], [7, 8]]); let flattened = array.flatten_with_order(Order::RowMajor); assert_eq!(flattened, arr1(&[1, 2, 3, 4, 5, 6, 7, 8])); @@ -3393,16 +3330,14 @@ mod tests } #[test] - fn test_into_flat() - { + fn test_into_flat() { let array = arr3(&[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); let flattened = array.into_flat(); assert_eq!(flattened, arr1(&[1, 2, 3, 4, 5, 6, 7, 8])); } #[test] - fn test_first_last() - { + fn test_first_last() { let first = 2; let last = 3; @@ -3442,8 +3377,7 @@ mod tests } #[test] - fn test_partition_1d() - { + fn test_partition_1d() { // Test partitioning a 1D array let array = arr1(&[3, 1, 4, 1, 5, 9, 2, 6]); let result = array.partition(3, Axis(0)); @@ -3453,8 +3387,7 @@ mod tests } #[test] - fn test_partition_2d() - { + fn test_partition_2d() { // Test partitioning a 2D array along both axes let array = arr2(&[[3, 1, 4], [1, 5, 9], [2, 6, 5]]); @@ -3474,8 +3407,7 @@ mod tests } #[test] - fn test_partition_3d() - { + fn test_partition_3d() { // Test partitioning a 3D array let array = arr3(&[[[3, 1], [4, 1]], [[5, 9], [2, 6]]]); @@ -3490,8 +3422,7 @@ mod tests #[test] #[should_panic] - fn test_partition_invalid_kth() - { + fn test_partition_invalid_kth() { let a = array![1, 2, 3, 4]; // This should panic because kth=4 is out of bounds let _ = a.partition(4, Axis(0)); @@ -3499,16 +3430,14 @@ mod tests #[test] #[should_panic] - fn test_partition_invalid_axis() - { + fn test_partition_invalid_axis() { let a = array![1, 2, 3, 4]; // This should panic because axis=1 is out of bounds for a 1D array let _ = a.partition(0, Axis(1)); } #[test] - fn test_partition_contiguous_or_not() - { + fn test_partition_contiguous_or_not() { // Test contiguous case (C-order) let a = array![[7, 1, 5], [2, 6, 0], [3, 4, 8]]; @@ -3567,8 +3496,7 @@ mod tests } #[test] - fn test_partition_empty() - { + fn test_partition_empty() { // Test 1D empty array let empty1d = Array1::::zeros(0); let result1d = empty1d.partition(0, Axis(0)); diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 53f49cc4..b133a651 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -404,25 +404,21 @@ impl<'a, D> $trt<&'a ArrayRef<$scalar, D>> for $scalar ); } -mod arithmetic_ops -{ +mod arithmetic_ops { use super::*; use crate::imp_prelude::*; use std::ops::*; - fn clone_opf(f: impl Fn(A, B) -> C) -> impl FnMut(&A, &B) -> C - { + fn clone_opf(f: impl Fn(A, B) -> C) -> impl FnMut(&A, &B) -> C { move |x, y| f(x.clone(), y.clone()) } - fn clone_iopf(f: impl Fn(A, B) -> A) -> impl FnMut(&mut A, &B) - { + fn clone_iopf(f: impl Fn(A, B) -> A) -> impl FnMut(&mut A, &B) { move |x, y| *x = f(x.clone(), y.clone()) } - fn clone_iopf_rev(f: impl Fn(A, B) -> B) -> impl FnMut(&mut B, &A) - { + fn clone_iopf_rev(f: impl Fn(A, B) -> B) -> impl FnMut(&mut B, &A) { move |x, y| *x = f(y.clone(), x.clone()) } @@ -499,8 +495,7 @@ mod arithmetic_ops type Output = Self; /// Perform an elementwise negation of `self` and return the result. 
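A quick illustration of the unary-operator impls being reformatted here (a sketch that relies only on the owned and by-reference `Neg` impls this file provides):

use ndarray::arr1;

fn main() {
    let a = arr1(&[1.0_f64, -2.0, 3.0]);
    let b = -a.clone();        // owned value: negated in place
    let c = -&b;               // reference: allocates a new Array
    assert_eq!(b, arr1(&[-1.0, 2.0, -3.0]));
    assert_eq!(c, a);
}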
- fn neg(mut self) -> Self - { + fn neg(mut self) -> Self { self.map_inplace(|elt| { *elt = -elt.clone(); }); @@ -518,8 +513,7 @@ mod arithmetic_ops /// Perform an elementwise negation of reference `self` and return the /// result as a new `Array`. - fn neg(self) -> Array - { + fn neg(self) -> Array { (&**self).neg() } } @@ -533,8 +527,7 @@ mod arithmetic_ops /// Perform an elementwise negation of reference `self` and return the /// result as a new `Array`. - fn neg(self) -> Array - { + fn neg(self) -> Array { self.map(Neg::neg) } } @@ -548,8 +541,7 @@ mod arithmetic_ops type Output = Self; /// Perform an elementwise unary not of `self` and return the result. - fn not(mut self) -> Self - { + fn not(mut self) -> Self { self.map_inplace(|elt| { *elt = !elt.clone(); }); @@ -567,8 +559,7 @@ mod arithmetic_ops /// Perform an elementwise unary not of reference `self` and return the /// result as a new `Array`. - fn not(self) -> Array - { + fn not(self) -> Array { (&**self).not() } } @@ -582,15 +573,13 @@ mod arithmetic_ops /// Perform an elementwise unary not of reference `self` and return the /// result as a new `Array`. - fn not(self) -> Array - { + fn not(self) -> Array { self.map(Not::not) } } } -mod assign_ops -{ +mod assign_ops { use super::*; use crate::imp_prelude::*; @@ -661,54 +650,14 @@ mod assign_ops }; } - impl_assign_op!( - AddAssign, - add_assign, - "Perform `self += rhs` as elementwise addition (in place).\n" - ); - impl_assign_op!( - SubAssign, - sub_assign, - "Perform `self -= rhs` as elementwise subtraction (in place).\n" - ); - impl_assign_op!( - MulAssign, - mul_assign, - "Perform `self *= rhs` as elementwise multiplication (in place).\n" - ); - impl_assign_op!( - DivAssign, - div_assign, - "Perform `self /= rhs` as elementwise division (in place).\n" - ); - impl_assign_op!( - RemAssign, - rem_assign, - "Perform `self %= rhs` as elementwise remainder (in place).\n" - ); - impl_assign_op!( - BitAndAssign, - bitand_assign, - "Perform `self &= rhs` as elementwise bit and (in place).\n" - ); - impl_assign_op!( - BitOrAssign, - bitor_assign, - "Perform `self |= rhs` as elementwise bit or (in place).\n" - ); - impl_assign_op!( - BitXorAssign, - bitxor_assign, - "Perform `self ^= rhs` as elementwise bit xor (in place).\n" - ); - impl_assign_op!( - ShlAssign, - shl_assign, - "Perform `self <<= rhs` as elementwise left shift (in place).\n" - ); - impl_assign_op!( - ShrAssign, - shr_assign, - "Perform `self >>= rhs` as elementwise right shift (in place).\n" - ); + impl_assign_op!(AddAssign, add_assign, "Perform `self += rhs` as elementwise addition (in place).\n"); + impl_assign_op!(SubAssign, sub_assign, "Perform `self -= rhs` as elementwise subtraction (in place).\n"); + impl_assign_op!(MulAssign, mul_assign, "Perform `self *= rhs` as elementwise multiplication (in place).\n"); + impl_assign_op!(DivAssign, div_assign, "Perform `self /= rhs` as elementwise division (in place).\n"); + impl_assign_op!(RemAssign, rem_assign, "Perform `self %= rhs` as elementwise remainder (in place).\n"); + impl_assign_op!(BitAndAssign, bitand_assign, "Perform `self &= rhs` as elementwise bit and (in place).\n"); + impl_assign_op!(BitOrAssign, bitor_assign, "Perform `self |= rhs` as elementwise bit or (in place).\n"); + impl_assign_op!(BitXorAssign, bitxor_assign, "Perform `self ^= rhs` as elementwise bit xor (in place).\n"); + impl_assign_op!(ShlAssign, shl_assign, "Perform `self <<= rhs` as elementwise left shift (in place).\n"); + impl_assign_op!(ShrAssign, shr_assign, "Perform `self >>= rhs` as 
elementwise right shift (in place).\n"); } diff --git a/src/impl_owned_array.rs b/src/impl_owned_array.rs index fb06f965..73165500 100644 --- a/src/impl_owned_array.rs +++ b/src/impl_owned_array.rs @@ -19,8 +19,7 @@ use crate::Zip; /// Methods specific to `Array0`. /// /// ***See also all methods for [`ArrayBase`]*** -impl Array -{ +impl Array { /// Returns the single element in the array without cloning it. /// /// ``` @@ -34,8 +33,7 @@ impl Array /// let scalar: Foo = array.into_scalar(); /// assert_eq!(scalar, Foo); /// ``` - pub fn into_scalar(self) -> A - { + pub fn into_scalar(self) -> A { let size = mem::size_of::(); if size == 0 { // Any index in the `Vec` is fine since all elements are identical. @@ -59,12 +57,12 @@ impl Array /// /// ***See also all methods for [`ArrayBase`]*** impl Array -where D: Dimension +where + D: Dimension, { /// Returns the offset (in units of `A`) from the start of the allocation /// to the first element, or `None` if the array is empty. - fn offset_from_alloc_to_logical_ptr(&self) -> Option - { + fn offset_from_alloc_to_logical_ptr(&self) -> Option { if self.is_empty() { return None; } @@ -140,8 +138,7 @@ where D: Dimension /// } /// } /// ``` - pub fn into_raw_vec_and_offset(self) -> (Vec, Option) - { + pub fn into_raw_vec_and_offset(self) -> (Vec, Option) { let offset = self.offset_from_alloc_to_logical_ptr(); (self.data.into_vec(), offset) } @@ -153,8 +150,7 @@ where D: Dimension /// array can be located at an offset. Because of this, prefer to use /// `.into_raw_vec_and_offset()` instead. #[deprecated(note = "Use .into_raw_vec_and_offset() instead", since = "0.16.0")] - pub fn into_raw_vec(self) -> Vec - { + pub fn into_raw_vec(self) -> Vec { self.into_raw_vec_and_offset().0 } } @@ -162,8 +158,7 @@ where D: Dimension /// Methods specific to `Array2`. /// /// ***See also all methods for [`ArrayBase`]*** -impl Array -{ +impl Array { /// Append a row to an array /// /// The elements from `row` are cloned and added as a new row in the array. @@ -204,7 +199,8 @@ impl Array /// [-1., -2., -3., -4.]]); /// ``` pub fn push_row(&mut self, row: ArrayView) -> Result<(), ShapeError> - where A: Clone + where + A: Clone, { self.append(Axis(0), row.insert_axis(Axis(0))) } @@ -249,7 +245,8 @@ impl Array /// [2., -2.]]); /// ``` pub fn push_column(&mut self, column: ArrayView) -> Result<(), ShapeError> - where A: Clone + where + A: Clone, { self.append(Axis(1), column.insert_axis(Axis(1))) } @@ -272,8 +269,7 @@ impl Array /// a.reserve_rows(1000).unwrap(); /// assert!(a.into_raw_vec().capacity() >= 4*1002); /// ``` - pub fn reserve_rows(&mut self, additional: usize) -> Result<(), ShapeError> - { + pub fn reserve_rows(&mut self, additional: usize) -> Result<(), ShapeError> { self.reserve(Axis(0), additional) } @@ -295,14 +291,14 @@ impl Array /// a.reserve_columns(1000).unwrap(); /// assert!(a.into_raw_vec().capacity() >= 2*1002); /// ``` - pub fn reserve_columns(&mut self, additional: usize) -> Result<(), ShapeError> - { + pub fn reserve_columns(&mut self, additional: usize) -> Result<(), ShapeError> { self.reserve(Axis(1), additional) } } impl Array -where D: Dimension +where + D: Dimension, { /// Move all elements from self into `new_array`, which must be of the same shape but /// can have a different memory layout. The destination is overwritten completely. 
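The row/column append methods touched above, exercised end to end (a sketch; only public ndarray items are used):

use ndarray::{aview1, Array2};

fn main() -> Result<(), ndarray::ShapeError> {
    let mut a = Array2::<f64>::zeros((0, 3));      // 0 rows, 3 columns
    a.push_row(aview1(&[1., 2., 3.]))?;
    a.push_row(aview1(&[4., 5., 6.]))?;
    a.push_column(aview1(&[7., 8.]))?;             // length must equal nrows
    assert_eq!(a.shape(), &[2, 4]);
    assert_eq!(a[[1, 3]], 8.);
    Ok(())
}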
@@ -338,8 +334,7 @@ where D: Dimension } } - fn move_into_needs_drop(mut self, new_array: ArrayViewMut) - { + fn move_into_needs_drop(mut self, new_array: ArrayViewMut) { // Simple case where `A` has a destructor: just swap values between self and new_array. // Afterwards, `self` drops full of initialized values and dropping works as usual. // This avoids moving out of owned values in `self` while at the same time managing @@ -384,8 +379,7 @@ where D: Dimension self.move_into_impl(new_array.into()) } - fn move_into_impl(mut self, new_array: ArrayViewMut, D>) - { + fn move_into_impl(mut self, new_array: ArrayViewMut, D>) { unsafe { // Safety: copy_to_nonoverlapping cannot panic let guard = AbortIfPanic(&"move_into: moving out of owned value"); @@ -410,8 +404,7 @@ where D: Dimension /// # Safety /// /// This is a panic critical section since `self` is already moved-from. - fn drop_unreachable_elements(mut self) -> OwnedRepr - { + fn drop_unreachable_elements(mut self) -> OwnedRepr { let self_len = self.len(); // "deconstruct" self; the owned repr releases ownership of all elements and we @@ -431,8 +424,7 @@ where D: Dimension #[inline(never)] #[cold] - fn drop_unreachable_elements_slow(mut self) -> OwnedRepr - { + fn drop_unreachable_elements_slow(mut self) -> OwnedRepr { // "deconstruct" self; the owned repr releases ownership of all elements and we // carry on with raw view methods let data_len = self.data.len(); @@ -453,8 +445,7 @@ where D: Dimension /// Create an empty array with an all-zeros shape /// /// ***Panics*** if D is zero-dimensional, because it can't be empty - pub(crate) fn empty() -> Array - { + pub(crate) fn empty() -> Array { assert_ne!(D::NDIM, Some(0)); let ndim = D::NDIM.unwrap_or(1); Array::from_shape_simple_fn(D::zeros(ndim), || unreachable!()) @@ -462,8 +453,7 @@ where D: Dimension /// Create new_array with the right layout for appending to `growing_axis` #[cold] - fn change_to_contig_append_layout(&mut self, growing_axis: Axis) - { + fn change_to_contig_append_layout(&mut self, growing_axis: Axis) { let ndim = self.ndim(); let mut dim = self.raw_dim(); @@ -744,25 +734,25 @@ where D: Dimension if tail_view.ndim() > 1 { sort_axes_in_default_order_tandem(&mut tail_view, &mut array); - debug_assert!(tail_view.is_standard_layout(), - "not std layout dim: {:?}, strides: {:?}", - tail_view.shape(), RawArrayViewMut::strides(&tail_view)); + debug_assert!( + tail_view.is_standard_layout(), + "not std layout dim: {:?}, strides: {:?}", + tail_view.shape(), + RawArrayViewMut::strides(&tail_view) + ); } // Keep track of currently filled length of `self.data` and update it // on scope exit (panic or loop finish). This "indirect" way to // write the length is used to help the compiler, the len store to self.data may // otherwise be mistaken to alias with other stores in the loop. 
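The move/drop machinery in this hunk backs the public `append`; a small usage sketch for orientation (not part of the diff, public API only):

use ndarray::{array, Array2, Axis};

fn main() -> Result<(), ndarray::ShapeError> {
    let mut a = Array2::<i32>::zeros((0, 2));
    a.append(Axis(0), array![[1, 2], [3, 4]].view())?;
    a.append(Axis(0), array![[5, 6]].view())?;
    assert_eq!(a, array![[1, 2], [3, 4], [5, 6]]);
    Ok(())
}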
- struct SetLenOnDrop<'a, A: 'a> - { + struct SetLenOnDrop<'a, A: 'a> { len: usize, data: &'a mut OwnedRepr, } - impl Drop for SetLenOnDrop<'_, A> - { - fn drop(&mut self) - { + impl Drop for SetLenOnDrop<'_, A> { + fn drop(&mut self) { unsafe { self.data.set_len(self.len); } @@ -821,7 +811,8 @@ where D: Dimension /// ``` /// pub fn reserve(&mut self, axis: Axis, additional: usize) -> Result<(), ShapeError> - where D: RemoveAxis + where + D: RemoveAxis, { debug_assert!(axis.index() < self.ndim()); let self_dim = self.raw_dim(); @@ -867,9 +858,9 @@ where D: Dimension /// /// This is an internal function for use by move_into and IntoIter only, safety invariants may need /// to be upheld across the calls from those implementations. -pub(crate) unsafe fn drop_unreachable_raw( - mut self_: RawArrayViewMut, data_ptr: NonNull, data_len: usize, -) where D: Dimension +pub(crate) unsafe fn drop_unreachable_raw(mut self_: RawArrayViewMut, data_ptr: NonNull, data_len: usize) +where + D: Dimension, { let self_len = self_.len(); @@ -933,8 +924,7 @@ pub(crate) unsafe fn drop_unreachable_raw( dropped_elements += 1; } - assert_eq!(data_len, dropped_elements + self_len, - "Internal error: inconsistency in move_into"); + assert_eq!(data_len, dropped_elements + self_len, "Internal error: inconsistency in move_into"); } /// Sort axes to standard order, i.e Axis(0) has biggest stride and Axis(n - 1) least stride @@ -952,7 +942,8 @@ where } fn sort_axes1_impl(adim: &mut D, astrides: &mut D) -where D: Dimension +where + D: Dimension, { debug_assert!(adim.ndim() > 1); debug_assert_eq!(adim.ndim(), astrides.ndim()); @@ -992,7 +983,8 @@ where } fn sort_axes2_impl(adim: &mut D, astrides: &mut D, bdim: &mut D, bstrides: &mut D) -where D: Dimension +where + D: Dimension, { debug_assert!(adim.ndim() > 1); debug_assert_eq!(adim.ndim(), bdim.ndim()); diff --git a/src/impl_raw_views.rs b/src/impl_raw_views.rs index 2423b934..d18b43cb 100644 --- a/src/impl_raw_views.rs +++ b/src/impl_raw_views.rs @@ -9,21 +9,20 @@ use crate::is_aligned; use crate::shape_builder::{StrideShape, Strides}; impl RawArrayView -where D: Dimension +where + D: Dimension, { /// Create a new `RawArrayView`. /// /// Unsafe because caller is responsible for ensuring that the array will /// meet all of the invariants of the `ArrayBase` type. #[inline] - pub(crate) unsafe fn new(ptr: NonNull, dim: D, strides: D) -> Self - { + pub(crate) unsafe fn new(ptr: NonNull, dim: D, strides: D) -> Self { RawArrayView::from_data_ptr(RawViewRepr::new(), ptr).with_strides_dim(strides, dim) } #[inline] - unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self - { + unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self { Self::new(nonnull_debug_checked_from_ptr(ptr as *mut A), dim, strides) } @@ -69,7 +68,8 @@ where D: Dimension /// [`.offset()`]: https://doc.rust-lang.org/stable/std/primitive.pointer.html#method.offset #[inline] pub unsafe fn from_shape_ptr(shape: Sh, ptr: *const A) -> Self - where Sh: Into> + where + Sh: Into>, { let shape = shape.into(); let dim = shape.dim; @@ -95,12 +95,8 @@ where D: Dimension /// data is valid, ensure that the pointer is aligned, and choose the /// correct lifetime. #[inline] - pub unsafe fn deref_into_view<'a>(self) -> ArrayView<'a, A, D> - { - debug_assert!( - is_aligned(self.parts.ptr.as_ptr()), - "The pointer must be aligned." 
- ); + pub unsafe fn deref_into_view<'a>(self) -> ArrayView<'a, A, D> { + debug_assert!(is_aligned(self.parts.ptr.as_ptr()), "The pointer must be aligned."); ArrayView::new(self.parts.ptr, self.parts.dim, self.parts.strides) } @@ -110,8 +106,7 @@ where D: Dimension /// **Panics** if `axis` or `index` is out of bounds. #[track_caller] #[inline] - pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) - { + pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) { assert!(index <= self.len_of(axis)); let left_ptr = self.parts.ptr.as_ptr(); let right_ptr = if index == self.len_of(axis) { @@ -145,31 +140,23 @@ where D: Dimension /// casts are safe, access through the produced raw view is only possible /// in an unsafe block or function. #[track_caller] - pub fn cast(self) -> RawArrayView - { - assert_eq!( - mem::size_of::(), - mem::size_of::(), - "size mismatch in raw view cast" - ); + pub fn cast(self) -> RawArrayView { + assert_eq!(mem::size_of::(), mem::size_of::(), "size mismatch in raw view cast"); let ptr = self.parts.ptr.cast::(); unsafe { RawArrayView::new(ptr, self.parts.dim, self.parts.strides) } } } impl RawArrayView, D> -where D: Dimension +where + D: Dimension, { /// Splits the view into views of the real and imaginary components of the /// elements. - pub fn split_complex(self) -> Complex> - { + pub fn split_complex(self) -> Complex> { // Check that the size and alignment of `Complex` are as expected. // These assertions should always pass, for arbitrary `T`. - assert_eq!( - mem::size_of::>(), - mem::size_of::().checked_mul(2).unwrap() - ); + assert_eq!(mem::size_of::>(), mem::size_of::().checked_mul(2).unwrap()); assert_eq!(mem::align_of::>(), mem::align_of::()); let dim = self.parts.dim.clone(); @@ -225,21 +212,20 @@ where D: Dimension } impl RawArrayViewMut -where D: Dimension +where + D: Dimension, { /// Create a new `RawArrayViewMut`. /// /// Unsafe because caller is responsible for ensuring that the array will /// meet all of the invariants of the `ArrayBase` type. #[inline] - pub(crate) unsafe fn new(ptr: NonNull, dim: D, strides: D) -> Self - { + pub(crate) unsafe fn new(ptr: NonNull, dim: D, strides: D) -> Self { RawArrayViewMut::from_data_ptr(RawViewRepr::new(), ptr).with_strides_dim(strides, dim) } #[inline] - unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self - { + unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self { Self::new(nonnull_debug_checked_from_ptr(ptr), dim, strides) } @@ -285,7 +271,8 @@ where D: Dimension /// [`.offset()`]: https://doc.rust-lang.org/stable/std/primitive.pointer.html#method.offset #[inline] pub unsafe fn from_shape_ptr(shape: Sh, ptr: *mut A) -> Self - where Sh: Into> + where + Sh: Into>, { let shape = shape.into(); let dim = shape.dim; @@ -294,8 +281,10 @@ where D: Dimension if let Strides::Custom(strides) = &shape.strides { dimension::strides_non_negative(strides).unwrap(); dimension::max_abs_offset_check_overflow::(&dim, strides).unwrap(); - assert!(!dimension::dim_stride_overlap(&dim, strides), - "The strides must not allow any element to be referenced by two different indices"); + assert!( + !dimension::dim_stride_overlap(&dim, strides), + "The strides must not allow any element to be referenced by two different indices" + ); } else { dimension::size_of_shape_checked(&dim).unwrap(); } @@ -306,8 +295,7 @@ where D: Dimension /// Converts to a non-mutable `RawArrayView`. 
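For readers unfamiliar with the raw-view layer being reformatted here, a minimal round trip through it (a sketch; every safety obligation stated in the docs above falls on the caller):

use ndarray::RawArrayView;

fn main() {
    let data = [1_u32, 2, 3, 4, 5, 6];
    // Caller promises: valid, aligned pointer and a shape that fits the allocation.
    let raw = unsafe { RawArrayView::from_shape_ptr((2, 3), data.as_ptr()) };
    // Only dereference into a safe view while `data` is still alive.
    let view = unsafe { raw.deref_into_view() };
    assert_eq!(view[[1, 2]], 6);
}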
#[inline] - pub(crate) fn into_raw_view(self) -> RawArrayView - { + pub(crate) fn into_raw_view(self) -> RawArrayView { unsafe { RawArrayView::new(self.parts.ptr, self.parts.dim, self.parts.strides) } } @@ -320,12 +308,8 @@ where D: Dimension /// data is valid, ensure that the pointer is aligned, and choose the /// correct lifetime. #[inline] - pub unsafe fn deref_into_view<'a>(self) -> ArrayView<'a, A, D> - { - debug_assert!( - is_aligned(self.parts.ptr.as_ptr()), - "The pointer must be aligned." - ); + pub unsafe fn deref_into_view<'a>(self) -> ArrayView<'a, A, D> { + debug_assert!(is_aligned(self.parts.ptr.as_ptr()), "The pointer must be aligned."); ArrayView::new(self.parts.ptr, self.parts.dim, self.parts.strides) } @@ -338,12 +322,8 @@ where D: Dimension /// data is valid, ensure that the pointer is aligned, and choose the /// correct lifetime. #[inline] - pub unsafe fn deref_into_view_mut<'a>(self) -> ArrayViewMut<'a, A, D> - { - debug_assert!( - is_aligned(self.parts.ptr.as_ptr()), - "The pointer must be aligned." - ); + pub unsafe fn deref_into_view_mut<'a>(self) -> ArrayViewMut<'a, A, D> { + debug_assert!(is_aligned(self.parts.ptr.as_ptr()), "The pointer must be aligned."); ArrayViewMut::new(self.parts.ptr, self.parts.dim, self.parts.strides) } @@ -353,8 +333,7 @@ where D: Dimension /// **Panics** if `axis` or `index` is out of bounds. #[track_caller] #[inline] - pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) - { + pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) { let (left, right) = self.into_raw_view().split_at(axis, index); unsafe { ( @@ -375,25 +354,20 @@ where D: Dimension /// casts are safe, access through the produced raw view is only possible /// in an unsafe block or function. #[track_caller] - pub fn cast(self) -> RawArrayViewMut - { - assert_eq!( - mem::size_of::(), - mem::size_of::(), - "size mismatch in raw view cast" - ); + pub fn cast(self) -> RawArrayViewMut { + assert_eq!(mem::size_of::(), mem::size_of::(), "size mismatch in raw view cast"); let ptr = self.parts.ptr.cast::(); unsafe { RawArrayViewMut::new(ptr, self.parts.dim, self.parts.strides) } } } impl RawArrayViewMut, D> -where D: Dimension +where + D: Dimension, { /// Splits the view into views of the real and imaginary components of the /// elements. 
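The complex split above is also exposed on the safe view types; a small sketch of what callers get (assumes the `num-complex` version that ndarray itself depends on):

use ndarray::arr1;
use num_complex::Complex64;

fn main() {
    let a = arr1(&[Complex64::new(1., 2.), Complex64::new(3., 4.)]);
    let parts = a.view().split_complex();
    // `parts.re` and `parts.im` are plain f64 views over the same allocation.
    assert_eq!(parts.re[0], 1.);
    assert_eq!(parts.im[1], 4.);
}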
- pub fn split_complex(self) -> Complex> - { + pub fn split_complex(self) -> Complex> { let Complex { re, im } = self.into_raw_view().split_complex(); unsafe { Complex { diff --git a/src/impl_ref_types.rs b/src/impl_ref_types.rs index 108ac68b..078e0794 100644 --- a/src/impl_ref_types.rs +++ b/src/impl_ref_types.rs @@ -36,28 +36,18 @@ use core::{ }; use crate::{ - Array, - ArrayBase, - ArrayPartsSized, - ArrayPartsUnsized, - ArrayRef, - Data, - DataMut, - Dimension, - LayoutRef, - RawData, - RawDataMut, - RawRef, + Array, ArrayBase, ArrayPartsSized, ArrayPartsUnsized, ArrayRef, Data, DataMut, Dimension, LayoutRef, RawData, + RawDataMut, RawRef, }; // D1: &ArrayBase -> &ArrayRef when data is safe to read impl Deref for ArrayBase -where S: Data +where + S: Data, { type Target = ArrayRef; - fn deref(&self) -> &Self::Target - { + fn deref(&self) -> &Self::Target { // SAFETY: // - The pointer is aligned because neither type uses repr(align) // - It is "dereferencable" because it comes from a reference @@ -75,8 +65,7 @@ where S: DataMut, D: Dimension, { - fn deref_mut(&mut self) -> &mut Self::Target - { + fn deref_mut(&mut self) -> &mut Self::Target { self.ensure_unique(); // SAFETY: // - The pointer is aligned because neither type uses repr(align) @@ -90,12 +79,10 @@ where } // D3: &ArrayRef -> &RawRef -impl Deref for ArrayRef -{ +impl Deref for ArrayRef { type Target = RawRef; - fn deref(&self) -> &Self::Target - { + fn deref(&self) -> &Self::Target { // SAFETY: // - The pointer is aligned because neither type uses repr(align) // - It is "dereferencable" because it comes from a reference @@ -106,10 +93,8 @@ impl Deref for ArrayRef } // D4: &mut ArrayRef -> &mut RawRef -impl DerefMut for ArrayRef -{ - fn deref_mut(&mut self) -> &mut Self::Target - { +impl DerefMut for ArrayRef { + fn deref_mut(&mut self) -> &mut Self::Target { // SAFETY: // - The pointer is aligned because neither type uses repr(align) // - It is "dereferencable" because it comes from a reference @@ -120,31 +105,27 @@ impl DerefMut for ArrayRef } // D5: &RawRef -> &LayoutRef -impl Deref for RawRef -{ +impl Deref for RawRef { type Target = LayoutRef; - fn deref(&self) -> &Self::Target - { + fn deref(&self) -> &Self::Target { &self.0 } } // D5: &mut RawRef -> &mut LayoutRef -impl DerefMut for RawRef -{ - fn deref_mut(&mut self) -> &mut Self::Target - { +impl DerefMut for RawRef { + fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } // A1: &ArrayBase -AR-> &RawRef impl AsRef> for ArrayBase -where S: RawData +where + S: RawData, { - fn as_ref(&self) -> &RawRef - { + fn as_ref(&self) -> &RawRef { // SAFETY: // - The pointer is aligned because neither type uses repr(align) // - It is "dereferencable" because it comes from a reference @@ -158,10 +139,10 @@ where S: RawData // A2: &mut ArrayBase -AM-> &mut RawRef impl AsMut> for ArrayBase -where S: RawDataMut +where + S: RawDataMut, { - fn as_mut(&mut self) -> &mut RawRef - { + fn as_mut(&mut self) -> &mut RawRef { // SAFETY: // - The pointer is aligned because neither type uses repr(align) // - It is "dereferencable" because it comes from a reference @@ -175,10 +156,10 @@ where S: RawDataMut // A3: &ArrayBase -AR-> &LayoutRef impl AsRef> for ArrayBase -where S: RawData +where + S: RawData, { - fn as_ref(&self) -> &LayoutRef - { + fn as_ref(&self) -> &LayoutRef { let parts: &ArrayPartsUnsized = &self.parts; let ptr = (parts as *const ArrayPartsUnsized) as *const LayoutRef; unsafe { &*ptr } @@ -187,10 +168,10 @@ where S: RawData // A3: &mut ArrayBase -AM-> &mut LayoutRef 
impl AsMut> for ArrayBase -where S: RawData +where + S: RawData, { - fn as_mut(&mut self) -> &mut LayoutRef - { + fn as_mut(&mut self) -> &mut LayoutRef { let parts: &mut ArrayPartsUnsized = &mut self.parts; let ptr = (parts as *mut ArrayPartsUnsized) as *mut LayoutRef; unsafe { &mut *ptr } @@ -198,91 +179,71 @@ where S: RawData } // A4: &ArrayRef -AR-> &RawRef -impl AsRef> for ArrayRef -{ - fn as_ref(&self) -> &RawRef - { +impl AsRef> for ArrayRef { + fn as_ref(&self) -> &RawRef { self } } // A4: &mut ArrayRef -AM-> &mut RawRef -impl AsMut> for ArrayRef -{ - fn as_mut(&mut self) -> &mut RawRef - { +impl AsMut> for ArrayRef { + fn as_mut(&mut self) -> &mut RawRef { self } } // A4: &ArrayRef -AR-> &LayoutRef -impl AsRef> for ArrayRef -{ - fn as_ref(&self) -> &LayoutRef - { +impl AsRef> for ArrayRef { + fn as_ref(&self) -> &LayoutRef { self } } // A4: &mut ArrayRef -AM-> &mut LayoutRef -impl AsMut> for ArrayRef -{ - fn as_mut(&mut self) -> &mut LayoutRef - { +impl AsMut> for ArrayRef { + fn as_mut(&mut self) -> &mut LayoutRef { self } } // A5: &RawRef -AR-> &LayoutRef -impl AsRef> for RawRef -{ - fn as_ref(&self) -> &LayoutRef - { +impl AsRef> for RawRef { + fn as_ref(&self) -> &LayoutRef { self } } // A5: &mut RawRef -AM-> &mut LayoutRef -impl AsMut> for RawRef -{ - fn as_mut(&mut self) -> &mut LayoutRef - { +impl AsMut> for RawRef { + fn as_mut(&mut self) -> &mut LayoutRef { self } } // A6: &RawRef -AR-> &RawRef -impl AsRef> for RawRef -{ - fn as_ref(&self) -> &RawRef - { +impl AsRef> for RawRef { + fn as_ref(&self) -> &RawRef { self } } // A6: &mut RawRef -AM-> &mut RawRef -impl AsMut> for RawRef -{ - fn as_mut(&mut self) -> &mut RawRef - { +impl AsMut> for RawRef { + fn as_mut(&mut self) -> &mut RawRef { self } } // A6: &LayoutRef -AR-> &LayoutRef -impl AsRef> for LayoutRef -{ - fn as_ref(&self) -> &LayoutRef - { +impl AsRef> for LayoutRef { + fn as_ref(&self) -> &LayoutRef { self } } // A6: &mut LayoutRef -AR-> &mut LayoutRef -impl AsMut> for LayoutRef -{ - fn as_mut(&mut self) -> &mut LayoutRef - { +impl AsMut> for LayoutRef { + fn as_mut(&mut self) -> &mut LayoutRef { self } } @@ -294,10 +255,8 @@ impl AsMut> for LayoutRef /// impossible to read the data behind the pointer from a LayoutRef (this /// is a safety invariant that *must* be maintained), and therefore we can /// Clone and Copy as desired. 
-impl Clone for ArrayPartsSized -{ - fn clone(&self) -> Self - { +impl Clone for ArrayPartsSized { + fn clone(&self) -> Self { Self { dim: self.dim.clone(), strides: self.strides.clone(), @@ -310,28 +269,28 @@ impl Clone for ArrayPartsSized impl Copy for ArrayPartsSized {} impl Borrow> for ArrayBase -where S: RawData +where + S: RawData, { - fn borrow(&self) -> &RawRef - { + fn borrow(&self) -> &RawRef { self.as_ref() } } impl BorrowMut> for ArrayBase -where S: RawDataMut +where + S: RawDataMut, { - fn borrow_mut(&mut self) -> &mut RawRef - { + fn borrow_mut(&mut self) -> &mut RawRef { self.as_mut() } } impl Borrow> for ArrayBase -where S: Data +where + S: Data, { - fn borrow(&self) -> &ArrayRef - { + fn borrow(&self) -> &ArrayRef { self } } @@ -341,8 +300,7 @@ where S: DataMut, D: Dimension, { - fn borrow_mut(&mut self) -> &mut ArrayRef - { + fn borrow_mut(&mut self) -> &mut ArrayRef { self } } @@ -354,42 +312,39 @@ where { type Owned = Array; - fn to_owned(&self) -> Self::Owned - { + fn to_owned(&self) -> Self::Owned { self.to_owned() } - fn clone_into(&self, target: &mut Array) - { + fn clone_into(&self, target: &mut Array) { target.zip_mut_with(self, |tgt, src| tgt.clone_from(src)); } } /// Shortcuts for the various as_ref calls impl ArrayBase -where S: RawData +where + S: RawData, { /// Cheaply convert a reference to the array to an &LayoutRef - pub fn as_layout_ref(&self) -> &LayoutRef - { + pub fn as_layout_ref(&self) -> &LayoutRef { self.as_ref() } /// Cheaply and mutably convert a reference to the array to an &LayoutRef - pub fn as_layout_ref_mut(&mut self) -> &mut LayoutRef - { + pub fn as_layout_ref_mut(&mut self) -> &mut LayoutRef { self.as_mut() } /// Cheaply convert a reference to the array to an &RawRef - pub fn as_raw_ref(&self) -> &RawRef - { + pub fn as_raw_ref(&self) -> &RawRef { self.as_ref() } /// Cheaply and mutably convert a reference to the array to an &RawRef pub fn as_raw_ref_mut(&mut self) -> &mut RawRef - where S: RawDataMut + where + S: RawDataMut, { self.as_mut() } diff --git a/src/impl_special_element_types.rs b/src/impl_special_element_types.rs index 8b525e31..411398d2 100644 --- a/src/impl_special_element_types.rs +++ b/src/impl_special_element_types.rs @@ -32,8 +32,7 @@ where /// Note that for owned and shared ownership arrays, the promise must include all of the /// array's storage; it is for example possible to slice these in place, but that must /// only be done after all elements have been initialized. - pub unsafe fn assume_init(self) -> ArrayBase<>::Output, D> - { + pub unsafe fn assume_init(self) -> ArrayBase<>::Output, D> { let ArrayBase { data, parts: diff --git a/src/impl_views/constructors.rs b/src/impl_views/constructors.rs index dcf6527e..d4d2a280 100644 --- a/src/impl_views/constructors.rs +++ b/src/impl_views/constructors.rs @@ -17,7 +17,8 @@ use crate::{is_aligned, StrideShape}; /// Methods for read-only array views. impl<'a, A, D> ArrayView<'a, A, D> -where D: Dimension +where + D: Dimension, { /// Create a read-only array view borrowing its data from a slice. 
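A minimal use of the slice-backed view constructor documented above (sketch only; nothing beyond the public `ArrayView::from_shape` is assumed):

use ndarray::ArrayView;

fn main() -> Result<(), ndarray::ShapeError> {
    let data = [0., 1., 2., 3., 4., 5.];
    // Row-major 2 x 3 view borrowing `data`; no copy is made.
    let v = ArrayView::from_shape((2, 3), &data)?;
    assert_eq!(v[[1, 2]], 5.);
    Ok(())
}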
/// @@ -45,14 +46,14 @@ where D: Dimension /// assert!(a.strides() == &[1, 4, 2]); /// ``` pub fn from_shape(shape: Sh, xs: &'a [A]) -> Result - where Sh: Into> + where + Sh: Into>, { // eliminate the type parameter Sh as soon as possible Self::from_shape_impl(shape.into(), xs) } - fn from_shape_impl(shape: StrideShape, xs: &'a [A]) -> Result - { + fn from_shape_impl(shape: StrideShape, xs: &'a [A]) -> Result { let dim = shape.dim; dimension::can_index_slice_with_strides(xs, &dim, &shape.strides, CanIndexCheckMode::ReadOnly)?; let strides = shape.strides.strides_for_dim(&dim); @@ -112,7 +113,8 @@ where D: Dimension /// [`.offset()`]: https://doc.rust-lang.org/stable/std/primitive.pointer.html#method.offset #[inline] pub unsafe fn from_shape_ptr(shape: Sh, ptr: *const A) -> Self - where Sh: Into> + where + Sh: Into>, { RawArrayView::from_shape_ptr(shape, ptr).deref_into_view() } @@ -120,7 +122,8 @@ where D: Dimension /// Methods for read-write array views. impl<'a, A, D> ArrayViewMut<'a, A, D> -where D: Dimension +where + D: Dimension, { /// Create a read-write array view borrowing its data from a slice. /// @@ -148,14 +151,14 @@ where D: Dimension /// assert!(a.strides() == &[1, 4, 2]); /// ``` pub fn from_shape(shape: Sh, xs: &'a mut [A]) -> Result - where Sh: Into> + where + Sh: Into>, { // eliminate the type parameter Sh as soon as possible Self::from_shape_impl(shape.into(), xs) } - fn from_shape_impl(shape: StrideShape, xs: &'a mut [A]) -> Result - { + fn from_shape_impl(shape: StrideShape, xs: &'a mut [A]) -> Result { let dim = shape.dim; dimension::can_index_slice_with_strides(xs, &dim, &shape.strides, CanIndexCheckMode::OwnedMutable)?; let strides = shape.strides.strides_for_dim(&dim); @@ -215,7 +218,8 @@ where D: Dimension /// [`.offset()`]: https://doc.rust-lang.org/stable/std/primitive.pointer.html#method.offset #[inline] pub unsafe fn from_shape_ptr(shape: Sh, ptr: *mut A) -> Self - where Sh: Into> + where + Sh: Into>, { RawArrayViewMut::from_shape_ptr(shape, ptr).deref_into_view_mut() } @@ -223,7 +227,8 @@ where D: Dimension /// Convert the view into an `ArrayViewMut<'b, A, D>` where `'b` is a lifetime /// outlived by `'a'`. pub fn reborrow<'b>(self) -> ArrayViewMut<'b, A, D> - where 'a: 'b + where + 'a: 'b, { unsafe { ArrayViewMut::new(self.parts.ptr, self.parts.dim, self.parts.strides) } } @@ -231,14 +236,14 @@ where D: Dimension /// Private array view methods impl ArrayView<'_, A, D> -where D: Dimension +where + D: Dimension, { /// Create a new `ArrayView` /// /// Unsafe because: `ptr` must be valid for the given dimension and strides. #[inline(always)] - pub(crate) unsafe fn new(ptr: NonNull, dim: D, strides: D) -> Self - { + pub(crate) unsafe fn new(ptr: NonNull, dim: D, strides: D) -> Self { if cfg!(debug_assertions) { assert!(is_aligned(ptr.as_ptr()), "The pointer must be aligned."); dimension::max_abs_offset_check_overflow::(&dim, &strides).unwrap(); @@ -248,21 +253,20 @@ where D: Dimension /// Unsafe because: `ptr` must be valid for the given dimension and strides. #[inline] - pub(crate) unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self - { + pub(crate) unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self { Self::new(nonnull_debug_checked_from_ptr(ptr as *mut A), dim, strides) } } impl ArrayViewMut<'_, A, D> -where D: Dimension +where + D: Dimension, { /// Create a new `ArrayView` /// /// Unsafe because: `ptr` must be valid for the given dimension and strides. 
#[inline(always)] - pub(crate) unsafe fn new(ptr: NonNull, dim: D, strides: D) -> Self - { + pub(crate) unsafe fn new(ptr: NonNull, dim: D, strides: D) -> Self { if cfg!(debug_assertions) { assert!(is_aligned(ptr.as_ptr()), "The pointer must be aligned."); dimension::max_abs_offset_check_overflow::(&dim, &strides).unwrap(); @@ -274,8 +278,7 @@ where D: Dimension /// /// Unsafe because: `ptr` must be valid for the given dimension and strides. #[inline(always)] - pub(crate) unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self - { + pub(crate) unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self { Self::new(nonnull_debug_checked_from_ptr(ptr), dim, strides) } } diff --git a/src/impl_views/conversions.rs b/src/impl_views/conversions.rs index 54d7ed20..c9ab37e6 100644 --- a/src/impl_views/conversions.rs +++ b/src/impl_views/conversions.rs @@ -22,12 +22,14 @@ use crate::IndexLonger; /// Methods for read-only array views. impl<'a, A, D> ArrayView<'a, A, D> -where D: Dimension +where + D: Dimension, { /// Convert the view into an `ArrayView<'b, A, D>` where `'b` is a lifetime /// outlived by `'a'`. pub fn reborrow<'b>(self) -> ArrayView<'b, A, D> - where 'a: 'b + where + 'a: 'b, { unsafe { ArrayView::new(self.parts.ptr, self.parts.dim, self.parts.strides) } } @@ -37,8 +39,7 @@ where D: Dimension /// /// Note that while the method is similar to [`ArrayRef::as_slice()`], this method transfers /// the view's lifetime to the slice, so it is a bit more powerful. - pub fn to_slice(&self) -> Option<&'a [A]> - { + pub fn to_slice(&self) -> Option<&'a [A]> { if self.is_standard_layout() { unsafe { Some(slice::from_raw_parts(self.parts.ptr.as_ptr(), self.len())) } } else { @@ -52,8 +53,7 @@ where D: Dimension /// Note that while the method is similar to /// [`ArrayRef::as_slice_memory_order()`], this method transfers the view's /// lifetime to the slice, so it is a bit more powerful. - pub fn to_slice_memory_order(&self) -> Option<&'a [A]> - { + pub fn to_slice_memory_order(&self) -> Option<&'a [A]> { if self.is_contiguous() { let offset = offset_from_low_addr_ptr_to_logical_ptr(&self.parts.dim, &self.parts.strides); unsafe { Some(slice::from_raw_parts(self.parts.ptr.sub(offset).as_ptr(), self.len())) } @@ -64,8 +64,7 @@ where D: Dimension /// Converts to a raw array view. #[inline] - pub(crate) fn into_raw_view(self) -> RawArrayView - { + pub(crate) fn into_raw_view(self) -> RawArrayView { unsafe { RawArrayView::new(self.parts.ptr, self.parts.dim, self.parts.strides) } } } @@ -73,8 +72,7 @@ where D: Dimension /// Methods specific to `ArrayView0`. /// /// ***See also all methods for [`ArrayView`] and [`ArrayBase`]*** -impl<'a, A> ArrayView<'a, A, Ix0> -{ +impl<'a, A> ArrayView<'a, A, Ix0> { /// Consume the view and return a reference to the single element in the array. /// /// The lifetime of the returned reference matches the lifetime of the data @@ -92,8 +90,7 @@ impl<'a, A> ArrayView<'a, A, Ix0> /// let scalar: &Foo = view.into_scalar(); /// assert_eq!(scalar, &Foo); /// ``` - pub fn into_scalar(self) -> &'a A - { + pub fn into_scalar(self) -> &'a A { self.index(Ix0()) } } @@ -101,8 +98,7 @@ impl<'a, A> ArrayView<'a, A, Ix0> /// Methods specific to `ArrayViewMut0`. /// /// ***See also all methods for [`ArrayViewMut`] and [`ArrayBase`]*** -impl<'a, A> ArrayViewMut<'a, A, Ix0> -{ +impl<'a, A> ArrayViewMut<'a, A, Ix0> { /// Consume the mutable view and return a mutable reference to the single element in the array. 
/// /// The lifetime of the returned reference matches the lifetime of the data @@ -118,23 +114,22 @@ impl<'a, A> ArrayViewMut<'a, A, Ix0> /// assert_eq!(scalar, &7.); /// assert_eq!(array[()], 7.); /// ``` - pub fn into_scalar(self) -> &'a mut A - { + pub fn into_scalar(self) -> &'a mut A { self.index(Ix0()) } } /// Methods for read-write array views. impl<'a, A, D> ArrayViewMut<'a, A, D> -where D: Dimension +where + D: Dimension, { /// Return the array’s data as a slice, if it is contiguous and in standard order. /// Return `None` otherwise. /// /// Note that while this is similar to [`ArrayBase::as_slice_mut()`], this method transfers the /// view's lifetime to the slice. - pub fn into_slice(self) -> Option<&'a mut [A]> - { + pub fn into_slice(self) -> Option<&'a mut [A]> { self.try_into_slice().ok() } @@ -144,8 +139,7 @@ where D: Dimension /// Note that while this is similar to /// [`ArrayBase::as_slice_memory_order_mut()`], this method transfers the /// view's lifetime to the slice. - pub fn into_slice_memory_order(self) -> Option<&'a mut [A]> - { + pub fn into_slice_memory_order(self) -> Option<&'a mut [A]> { self.try_into_slice_memory_order().ok() } @@ -155,8 +149,7 @@ where D: Dimension /// /// The view acts "as if" the elements are temporarily in cells, and elements /// can be changed through shared references using the regular cell methods. - pub fn into_cell_view(self) -> ArrayView<'a, MathCell, D> - { + pub fn into_cell_view(self) -> ArrayView<'a, MathCell, D> { // safety: valid because // A and MathCell have the same representation // &'a mut T is interchangeable with &'a Cell -- see method Cell::from_mut in std @@ -180,8 +173,7 @@ where D: Dimension /// This method allows writing uninitialized data into the view, which could leave any /// original array that we borrow from in an inconsistent state. This is not allowed /// when using the resulting array view. - pub(crate) unsafe fn into_maybe_uninit(self) -> ArrayViewMut<'a, MaybeUninit, D> - { + pub(crate) unsafe fn into_maybe_uninit(self) -> ArrayViewMut<'a, MaybeUninit, D> { // Safe because: A and MaybeUninit have the same representation; // and we can go from initialized to (maybe) not unconditionally in terms of // representation. However, the user must be careful to not write uninit elements @@ -194,38 +186,37 @@ where D: Dimension /// Private raw array view methods impl RawArrayView -where D: Dimension +where + D: Dimension, { #[inline] - pub(crate) fn into_base_iter(self) -> Baseiter - { + pub(crate) fn into_base_iter(self) -> Baseiter { unsafe { Baseiter::new(self.parts.ptr, self.parts.dim, self.parts.strides) } } } impl RawArrayViewMut -where D: Dimension +where + D: Dimension, { #[inline] - pub(crate) fn into_base_iter(self) -> Baseiter - { + pub(crate) fn into_base_iter(self) -> Baseiter { unsafe { Baseiter::new(self.parts.ptr, self.parts.dim, self.parts.strides) } } } /// Methods for iterating over array views. impl<'a, A, D> ArrayView<'a, A, D> -where D: Dimension +where + D: Dimension, { #[inline] - pub(crate) fn into_base_iter(self) -> Baseiter - { + pub(crate) fn into_base_iter(self) -> Baseiter { unsafe { Baseiter::new(self.parts.ptr, self.parts.dim, self.parts.strides) } } #[inline] - pub(crate) fn into_elements_base(self) -> ElementsBase<'a, A, D> - { + pub(crate) fn into_elements_base(self) -> ElementsBase<'a, A, D> { ElementsBase::new(self) } @@ -234,7 +225,8 @@ where D: Dimension /// Unlike [ArrayRef::outer_iter], this methods preserves the lifetime of the data, /// not the view itself. 
pub fn into_outer_iter(self) -> iter::AxisIter<'a, A, D::Smaller> - where D: RemoveAxis + where + D: RemoveAxis, { AxisIter::new(self, Axis(0)) } @@ -243,8 +235,7 @@ where D: Dimension /// /// Unlike [ArrayRef::indexed_iter], this methods preserves the lifetime of the data, /// not the view itself. - pub fn into_indexed_iter(self) -> iter::IndexedIter<'a, A, D> - { + pub fn into_indexed_iter(self) -> iter::IndexedIter<'a, A, D> { iter::IndexedIter::new(self.into_elements_base()) } @@ -253,7 +244,8 @@ where D: Dimension /// Unlike [ArrayRef::axis_iter], this methods preserves the lifetime of the data, /// not the view itself. pub fn into_axis_iter(self, axis: Axis) -> iter::AxisIter<'a, A, D::Smaller> - where D: RemoveAxis + where + D: RemoveAxis, { AxisIter::new(self, axis) } @@ -263,7 +255,8 @@ where D: Dimension /// Unlike [`ArrayRef::axis_chunks_iter`], this methods preserves the lifetime of the data, /// not the view itself. pub fn into_axis_chunks_iter(self, axis: Axis, chunk_size: usize) -> iter::AxisChunksIter<'a, A, D> - where D: RemoveAxis + where + D: RemoveAxis, { iter::AxisChunksIter::new(self, axis, chunk_size) } @@ -271,36 +264,32 @@ where D: Dimension /// Methods for iterating over mutable array views. impl<'a, A, D> ArrayViewMut<'a, A, D> -where D: Dimension +where + D: Dimension, { // Convert into a read-only view - pub(crate) fn into_view(self) -> ArrayView<'a, A, D> - { + pub(crate) fn into_view(self) -> ArrayView<'a, A, D> { unsafe { ArrayView::new(self.parts.ptr, self.parts.dim, self.parts.strides) } } /// Converts to a mutable raw array view. - pub(crate) fn into_raw_view_mut(self) -> RawArrayViewMut - { + pub(crate) fn into_raw_view_mut(self) -> RawArrayViewMut { unsafe { RawArrayViewMut::new(self.parts.ptr, self.parts.dim, self.parts.strides) } } #[inline] - pub(crate) fn into_base_iter(self) -> Baseiter - { + pub(crate) fn into_base_iter(self) -> Baseiter { unsafe { Baseiter::new(self.parts.ptr, self.parts.dim, self.parts.strides) } } #[inline] - pub(crate) fn into_elements_base(self) -> ElementsBaseMut<'a, A, D> - { + pub(crate) fn into_elements_base(self) -> ElementsBaseMut<'a, A, D> { ElementsBaseMut::new(self) } /// Return the array’s data as a slice, if it is contiguous and in standard order. /// Otherwise return self in the Err branch of the result. - pub(crate) fn try_into_slice(self) -> Result<&'a mut [A], Self> - { + pub(crate) fn try_into_slice(self) -> Result<&'a mut [A], Self> { if self.is_standard_layout() { unsafe { Ok(slice::from_raw_parts_mut(self.parts.ptr.as_ptr(), self.len())) } } else { @@ -310,8 +299,7 @@ where D: Dimension /// Return the array’s data as a slice, if it is contiguous. /// Otherwise return self in the Err branch of the result. - fn try_into_slice_memory_order(self) -> Result<&'a mut [A], Self> - { + fn try_into_slice_memory_order(self) -> Result<&'a mut [A], Self> { if self.is_contiguous() { let offset = offset_from_low_addr_ptr_to_logical_ptr(&self.parts.dim, &self.parts.strides); unsafe { Ok(slice::from_raw_parts_mut(self.parts.ptr.sub(offset).as_ptr(), self.len())) } @@ -325,7 +313,8 @@ where D: Dimension /// Unlike [ArrayRef::outer_iter], this methods preserves the lifetime of the data, /// not the view itself. pub fn into_outer_iter(self) -> iter::AxisIter<'a, A, D::Smaller> - where D: RemoveAxis + where + D: RemoveAxis, { AxisIter::new(self.into_view(), Axis(0)) } @@ -334,8 +323,7 @@ where D: Dimension /// /// Unlike [ArrayRef::indexed_iter], this methods preserves the lifetime of the data, /// not the view itself. 
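The lifetime-preserving iterator conversions in this hunk are easiest to see with a consumed view (a sketch; `into_outer_iter` is the public method reformatted above):

use ndarray::{arr2, ArrayView1};

fn main() {
    let a = arr2(&[[1, 2], [3, 4]]);
    // The view is consumed, but each yielded row still borrows `a` itself.
    let rows: Vec<ArrayView1<'_, i32>> = a.view().into_outer_iter().collect();
    assert_eq!(rows[1][1], 4);
}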
- pub fn into_indexed_iter(self) -> iter::IndexedIter<'a, A, D> - { + pub fn into_indexed_iter(self) -> iter::IndexedIter<'a, A, D> { iter::IndexedIter::new(self.into_view().into_elements_base()) } @@ -344,7 +332,8 @@ where D: Dimension /// Unlike [ArrayRef::axis_iter], this methods preserves the lifetime of the data, /// not the view itself. pub fn into_axis_iter(self, axis: Axis) -> iter::AxisIter<'a, A, D::Smaller> - where D: RemoveAxis + where + D: RemoveAxis, { AxisIter::new(self.into_view(), axis) } @@ -354,7 +343,8 @@ where D: Dimension /// Unlike [`ArrayRef::axis_chunks_iter`], this methods preserves the lifetime of the data, /// not the view itself. pub fn into_axis_chunks_iter(self, axis: Axis, chunk_size: usize) -> iter::AxisChunksIter<'a, A, D> - where D: RemoveAxis + where + D: RemoveAxis, { iter::AxisChunksIter::new(self.into_view(), axis, chunk_size) } @@ -364,7 +354,8 @@ where D: Dimension /// Unlike [ArrayRef::outer_iter_mut], this methods preserves the lifetime of the data, /// not the view itself. pub fn into_outer_iter_mut(self) -> iter::AxisIterMut<'a, A, D::Smaller> - where D: RemoveAxis + where + D: RemoveAxis, { AxisIterMut::new(self, Axis(0)) } @@ -373,8 +364,7 @@ where D: Dimension /// /// Unlike [ArrayRef::indexed_iter_mut], this methods preserves the lifetime of the data, /// not the view itself. - pub fn into_indexed_iter_mut(self) -> iter::IndexedIterMut<'a, A, D> - { + pub fn into_indexed_iter_mut(self) -> iter::IndexedIterMut<'a, A, D> { iter::IndexedIterMut::new(self.into_elements_base()) } @@ -383,7 +373,8 @@ where D: Dimension /// Unlike [ArrayRef::axis_iter_mut], this methods preserves the lifetime of the data, /// not the view itself. pub fn into_axis_iter_mut(self, axis: Axis) -> iter::AxisIterMut<'a, A, D::Smaller> - where D: RemoveAxis + where + D: RemoveAxis, { AxisIterMut::new(self, axis) } @@ -393,7 +384,8 @@ where D: Dimension /// Unlike [`ArrayRef::axis_chunks_iter_mut`], this methods preserves the lifetime of the data, /// not the view itself. pub fn into_axis_chunks_iter_mut(self, axis: Axis, chunk_size: usize) -> iter::AxisChunksIterMut<'a, A, D> - where D: RemoveAxis + where + D: RemoveAxis, { iter::AxisChunksIterMut::new(self, axis, chunk_size) } diff --git a/src/impl_views/indexing.rs b/src/impl_views/indexing.rs index feadbd29..3a7b13e5 100644 --- a/src/impl_views/indexing.rs +++ b/src/impl_views/indexing.rs @@ -46,8 +46,7 @@ use crate::NdIndex; /// assert_eq!(long_life_ref, &0.); /// /// ``` -pub trait IndexLonger -{ +pub trait IndexLonger { /// The type of the reference to the element that is produced, including /// its lifetime. type Output; @@ -120,14 +119,12 @@ where /// /// **Panics** if index is out of bounds. #[track_caller] - fn index(self, index: I) -> &'a A - { + fn index(self, index: I) -> &'a A { debug_bounds_check!(self, index); unsafe { &*self.get_ptr(index).unwrap_or_else(|| array_out_of_bounds()) } } - fn get(self, index: I) -> Option<&'a A> - { + fn get(self, index: I) -> Option<&'a A> { unsafe { self.get_ptr(index).map(|ptr| &*ptr) } } @@ -142,8 +139,7 @@ where /// [1]: ArrayRef::uget /// /// **Note:** only unchecked for non-debug builds of ndarray. - unsafe fn uget(self, index: I) -> &'a A - { + unsafe fn uget(self, index: I) -> &'a A { debug_bounds_check!(self, index); &*self .as_ptr() @@ -171,8 +167,7 @@ where /// /// **Panics** if index is out of bounds. 
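The `IndexLonger` trait reformatted above exists precisely so that the returned reference can outlive the view it came from; a short sketch using only the public trait:

use ndarray::{arr1, IndexLonger};

fn main() {
    let a = arr1(&[10, 20, 30]);
    let long_ref: &i32 = {
        let v = a.view();
        v.index(1)          // consumes the view; the borrow is against `a`'s data
    };
    assert_eq!(*long_ref, 20);
}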
#[track_caller] - fn index(mut self, index: I) -> &'a mut A - { + fn index(mut self, index: I) -> &'a mut A { debug_bounds_check!(self, index); unsafe { match self.get_mut_ptr(index) { @@ -190,8 +185,7 @@ where /// /// [1]: ArrayRef::get_mut /// - fn get(mut self, index: I) -> Option<&'a mut A> - { + fn get(mut self, index: I) -> Option<&'a mut A> { debug_bounds_check!(self, index); unsafe { match self.get_mut_ptr(index) { @@ -210,8 +204,7 @@ where /// [1]: ArrayRef::uget_mut /// /// **Note:** only unchecked for non-debug builds of ndarray. - unsafe fn uget(mut self, index: I) -> &'a mut A - { + unsafe fn uget(mut self, index: I) -> &'a mut A { debug_bounds_check!(self, index); &mut *self .as_mut_ptr() diff --git a/src/impl_views/splitting.rs b/src/impl_views/splitting.rs index 42b12b15..da029386 100644 --- a/src/impl_views/splitting.rs +++ b/src/impl_views/splitting.rs @@ -12,7 +12,8 @@ use num_complex::Complex; /// Methods for read-only array views. impl ArrayView<'_, A, D> -where D: Dimension +where + D: Dimension, { /// Split the array view along `axis` and return one view strictly before the /// split and one view after the split. @@ -89,8 +90,7 @@ where D: Dimension /// ``` #[track_caller] #[inline] - pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) - { + pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) { unsafe { let (left, right) = self.into_raw_view().split_at(axis, index); (left.deref_into_view(), right.deref_into_view()) @@ -99,7 +99,8 @@ where D: Dimension } impl<'a, T, D> ArrayView<'a, Complex, D> -where D: Dimension +where + D: Dimension, { /// Splits the view into views of the real and imaginary components of the /// elements. @@ -117,8 +118,7 @@ where D: Dimension /// assert_eq!(re, array![[1., 3.], [5., 7.], [9., 11.]]); /// assert_eq!(im, array![[2., 4.], [6., 8.], [10., 12.]]); /// ``` - pub fn split_complex(self) -> Complex> - { + pub fn split_complex(self) -> Complex> { unsafe { let Complex { re, im } = self.into_raw_view().split_complex(); Complex { @@ -131,7 +131,8 @@ where D: Dimension /// Methods for read-write array views. impl<'a, A, D> ArrayViewMut<'a, A, D> -where D: Dimension +where + D: Dimension, { /// Split the array view along `axis` and return one mutable view strictly /// before the split and one mutable view after the split. @@ -139,8 +140,7 @@ where D: Dimension /// **Panics** if `axis` or `index` is out of bounds. #[track_caller] #[inline] - pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) - { + pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) { unsafe { let (left, right) = self.into_raw_view_mut().split_at(axis, index); (left.deref_into_view_mut(), right.deref_into_view_mut()) @@ -166,14 +166,16 @@ where D: Dimension /// * if `D` is `IxDyn` and `info` does not match the number of array axes #[track_caller] pub fn multi_slice_move(self, info: M) -> M::Output - where M: MultiSliceArg<'a, A, D> + where + M: MultiSliceArg<'a, A, D>, { info.multi_slice_move(self) } } impl<'a, T, D> ArrayViewMut<'a, Complex, D> -where D: Dimension +where + D: Dimension, { /// Splits the view into views of the real and imaginary components of the /// elements. 
@@ -198,8 +200,7 @@ where D: Dimension /// assert_eq!(arr[[0, 1]], Complex64::new(13., 4.)); /// assert_eq!(arr[[2, 0]], Complex64::new(9., 14.)); /// ``` - pub fn split_complex(self) -> Complex> - { + pub fn split_complex(self) -> Complex> { unsafe { let Complex { re, im } = self.into_raw_view_mut().split_complex(); Complex { diff --git a/src/indexes.rs b/src/indexes.rs index 0fa2b50f..762815e0 100644 --- a/src/indexes.rs +++ b/src/indexes.rs @@ -18,8 +18,7 @@ use crate::{ArrayBase, Data}; /// /// Iterator element type is `D`. #[derive(Clone)] -pub struct IndicesIter -{ +pub struct IndicesIter { dim: D, index: Option, } @@ -29,7 +28,8 @@ pub struct IndicesIter /// *Note:* prefer higher order methods, arithmetic operations and /// non-indexed iteration before using indices. pub fn indices(shape: E) -> Indices -where E: IntoDimension +where + E: IntoDimension, { let dim = shape.into_dimension(); Indices { @@ -51,12 +51,12 @@ where } impl Iterator for IndicesIter -where D: Dimension +where + D: Dimension, { type Item = D::Pattern; #[inline] - fn next(&mut self) -> Option - { + fn next(&mut self) -> Option { let index = match self.index { None => return None, Some(ref ix) => ix.clone(), @@ -65,8 +65,7 @@ where D: Dimension Some(index.into_pattern()) } - fn size_hint(&self) -> (usize, Option) - { + fn size_hint(&self) -> (usize, Option) { let l = match self.index { None => 0, Some(ref ix) => { @@ -84,7 +83,8 @@ where D: Dimension } fn fold(self, init: B, mut f: F) -> B - where F: FnMut(B, D::Pattern) -> B + where + F: FnMut(B, D::Pattern) -> B, { let IndicesIter { mut index, dim } = self; let ndim = dim.ndim(); @@ -112,12 +112,12 @@ where D: Dimension impl ExactSizeIterator for IndicesIter where D: Dimension {} impl IntoIterator for Indices -where D: Dimension +where + D: Dimension, { type Item = D::Pattern; type IntoIter = IndicesIter; - fn into_iter(self) -> Self::IntoIter - { + fn into_iter(self) -> Self::IntoIter { let sz = self.dim.size(); let index = if sz != 0 { Some(self.start) } else { None }; IndicesIter { index, dim: self.dim } @@ -129,26 +129,26 @@ where D: Dimension /// `Indices` is an `NdProducer` that produces the indices of an array shape. #[derive(Copy, Clone, Debug)] pub struct Indices -where D: Dimension +where + D: Dimension, { start: D, dim: D, } #[derive(Copy, Clone, Debug)] -pub struct IndexPtr -{ +pub struct IndexPtr { index: D, } impl Offset for IndexPtr -where D: Dimension + Copy +where + D: Dimension + Copy, { // stride: The axis to increment type Stride = usize; - unsafe fn stride_offset(mut self, stride: Self::Stride, index: usize) -> Self - { + unsafe fn stride_offset(mut self, stride: Self::Stride, index: usize) -> Self { self.index[stride] += index; self } @@ -169,8 +169,7 @@ where D: Dimension + Copy // [0, 0, 0].stride_offset(1, 10) => [0, 10, 0] axis 1 is incremented by 10. // // .as_ref() converts the Ptr value to an Item. For example [0, 10, 0] => (0, 10, 0) -impl NdProducer for Indices -{ +impl NdProducer for Indices { type Item = D::Pattern; type Dim = D; type Ptr = IndexPtr; @@ -178,23 +177,19 @@ impl NdProducer for Indices private_impl! 
{} - fn raw_dim(&self) -> Self::Dim - { + fn raw_dim(&self) -> Self::Dim { self.dim } - fn equal_dim(&self, dim: &Self::Dim) -> bool - { + fn equal_dim(&self, dim: &Self::Dim) -> bool { self.dim.equal(dim) } - fn as_ptr(&self) -> Self::Ptr - { + fn as_ptr(&self) -> Self::Ptr { IndexPtr { index: self.start } } - fn layout(&self) -> Layout - { + fn layout(&self) -> Layout { if self.dim.ndim() <= 1 { Layout::one_dimensional() } else { @@ -202,31 +197,26 @@ impl NdProducer for Indices } } - unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item - { + unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item { ptr.index.into_pattern() } - unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr - { + unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr { let mut index = *i; index += &self.start; IndexPtr { index } } - fn stride_of(&self, axis: Axis) -> Self::Stride - { + fn stride_of(&self, axis: Axis) -> Self::Stride { axis.index() } #[inline(always)] - fn contiguous_stride(&self) -> Self::Stride - { + fn contiguous_stride(&self) -> Self::Stride { 0 } - fn split_at(self, axis: Axis, index: usize) -> (Self, Self) - { + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) { let start_a = self.start; let mut start_b = start_a; let (a, b) = self.dim.split_at(axis, index); @@ -239,15 +229,15 @@ impl NdProducer for Indices /// /// Iterator element type is `D`. #[derive(Clone)] -pub struct IndicesIterF -{ +pub struct IndicesIterF { dim: D, index: D, has_remaining: bool, } pub fn indices_iter_f(shape: E) -> IndicesIterF -where E: IntoDimension +where + E: IntoDimension, { let dim = shape.into_dimension(); let zero = E::Dim::zeros(dim.ndim()); @@ -259,12 +249,12 @@ where E: IntoDimension } impl Iterator for IndicesIterF -where D: Dimension +where + D: Dimension, { type Item = D::Pattern; #[inline] - fn next(&mut self) -> Option - { + fn next(&mut self) -> Option { if !self.has_remaining { None } else { @@ -274,8 +264,7 @@ where D: Dimension } } - fn size_hint(&self) -> (usize, Option) - { + fn size_hint(&self) -> (usize, Option) { if !self.has_remaining { return (0, Some(0)); } @@ -294,14 +283,12 @@ where D: Dimension impl ExactSizeIterator for IndicesIterF where D: Dimension {} #[cfg(test)] -mod tests -{ +mod tests { use super::indices; use super::indices_iter_f; #[test] - fn test_indices_iter_c_size_hint() - { + fn test_indices_iter_c_size_hint() { let dim = (3, 4); let mut it = indices(dim).into_iter(); let mut len = dim.0 * dim.1; @@ -314,8 +301,7 @@ mod tests } #[test] - fn test_indices_iter_c_fold() - { + fn test_indices_iter_c_fold() { macro_rules! run_test { ($dim:expr) => { for num_consume in 0..3 { @@ -343,8 +329,7 @@ mod tests } #[test] - fn test_indices_iter_f_size_hint() - { + fn test_indices_iter_f_size_hint() { let dim = (3, 4); let mut it = indices_iter_f(dim); let mut len = dim.0 * dim.1; diff --git a/src/iterators/chunks.rs b/src/iterators/chunks.rs index 178ead7e..eb1b29de 100644 --- a/src/iterators/chunks.rs +++ b/src/iterators/chunks.rs @@ -30,30 +30,26 @@ impl_ndproducer! { /// See [`.exact_chunks()`](crate::ArrayRef::exact_chunks) for more /// information. //#[derive(Debug)] -pub struct ExactChunks<'a, A, D> -{ +pub struct ExactChunks<'a, A, D> { base: RawArrayView, life: PhantomData<&'a A>, chunk: D, inner_strides: D, } -impl<'a, A, D: Dimension> ExactChunks<'a, A, D> -{ +impl<'a, A, D: Dimension> ExactChunks<'a, A, D> { /// Creates a new exact chunks producer. 
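For context on the indexes.rs hunks: `indices` builds an `NdProducer` over index patterns, and `IndexPtr::stride_offset` interprets the stride as the axis to bump (hence the `[0, 0, 0].stride_offset(1, 10) => [0, 10, 0]` comment above). A hedged sketch of the public entry point, assuming the usual `ndarray::indices` export:

use ndarray::indices;

fn main() {
    // C-order (row-major) iteration over every index pattern of a 2 x 3 shape.
    let idx: Vec<(usize, usize)> = indices((2, 3)).into_iter().collect();
    assert_eq!(idx, vec![(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]);
}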
/// /// **Panics** if any chunk dimension is zero pub(crate) fn new(a: ArrayView<'a, A, D>, chunk: E) -> Self - where E: IntoDimension + where + E: IntoDimension, { let mut a = a.into_raw_view(); let chunk = chunk.into_dimension(); ndassert!( a.ndim() == chunk.ndim(), - concat!( - "Chunk dimension {} does not match array dimension {} ", - "(with array of shape {:?})" - ), + concat!("Chunk dimension {} does not match array dimension {} ", "(with array of shape {:?})"), chunk.ndim(), a.ndim(), a.shape() @@ -80,8 +76,7 @@ where { type Item = ::Item; type IntoIter = ExactChunksIter<'a, A, D>; - fn into_iter(self) -> Self::IntoIter - { + fn into_iter(self) -> Self::IntoIter { ExactChunksIter { iter: self.base.into_base_iter(), life: self.life, @@ -95,8 +90,7 @@ where /// /// See [`.exact_chunks()`](crate::ArrayRef::exact_chunks) for more /// information. -pub struct ExactChunksIter<'a, A, D> -{ +pub struct ExactChunksIter<'a, A, D> { iter: Baseiter, life: PhantomData<&'a A>, chunk: D, @@ -129,30 +123,26 @@ impl_ndproducer! { /// See [`.exact_chunks_mut()`](crate::ArrayRef::exact_chunks_mut) /// for more information. //#[derive(Debug)] -pub struct ExactChunksMut<'a, A, D> -{ +pub struct ExactChunksMut<'a, A, D> { base: RawArrayViewMut, life: PhantomData<&'a mut A>, chunk: D, inner_strides: D, } -impl<'a, A, D: Dimension> ExactChunksMut<'a, A, D> -{ +impl<'a, A, D: Dimension> ExactChunksMut<'a, A, D> { /// Creates a new exact chunks producer. /// /// **Panics** if any chunk dimension is zero pub(crate) fn new(a: ArrayViewMut<'a, A, D>, chunk: E) -> Self - where E: IntoDimension + where + E: IntoDimension, { let mut a = a.into_raw_view_mut(); let chunk = chunk.into_dimension(); ndassert!( a.ndim() == chunk.ndim(), - concat!( - "Chunk dimension {} does not match array dimension {} ", - "(with array of shape {:?})" - ), + concat!("Chunk dimension {} does not match array dimension {} ", "(with array of shape {:?})"), chunk.ndim(), a.ndim(), a.shape() @@ -179,8 +169,7 @@ where { type Item = ::Item; type IntoIter = ExactChunksIterMut<'a, A, D>; - fn into_iter(self) -> Self::IntoIter - { + fn into_iter(self) -> Self::IntoIter { ExactChunksIterMut { iter: self.base.into_base_iter(), life: self.life, @@ -239,8 +228,7 @@ impl_iterator! { /// /// See [`.exact_chunks_mut()`](crate::ArrayRef::exact_chunks_mut) /// for more information. 
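The ExactChunks/ExactChunksMut constructors above panic when a chunk dimension is zero or when the chunk rank differs from the array rank; trailing elements that do not fill a whole chunk are skipped. A small sketch of the corresponding public method, assuming the standard `exact_chunks` API:

use ndarray::Array2;

fn main() {
    let a = Array2::from_shape_fn((4, 5), |(i, j)| (i * 5 + j) as i32);
    // 2x2 chunks: the leftover fifth column (5 % 2 == 1) is never visited.
    let mut count = 0;
    for chunk in a.exact_chunks((2, 2)) {
        assert_eq!(chunk.shape(), &[2, 2]);
        count += 1;
    }
    assert_eq!(count, 4);
}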
-pub struct ExactChunksIterMut<'a, A, D> -{ +pub struct ExactChunksIterMut<'a, A, D> { iter: Baseiter, life: PhantomData<&'a mut A>, chunk: D, diff --git a/src/iterators/into_iter.rs b/src/iterators/into_iter.rs index cacafd2f..d3ffedcf 100644 --- a/src/iterators/into_iter.rs +++ b/src/iterators/into_iter.rs @@ -17,7 +17,8 @@ use crate::impl_owned_array::drop_unreachable_raw; /// By-value iterator for an array pub struct IntoIter -where D: Dimension +where + D: Dimension, { array_data: OwnedRepr, inner: Baseiter, @@ -30,11 +31,11 @@ where D: Dimension } impl IntoIter -where D: Dimension +where + D: Dimension, { /// Create a new by-value iterator that consumes `array` - pub(crate) fn new(array: Array) -> Self - { + pub(crate) fn new(array: Array) -> Self { unsafe { let array_head_ptr = array.parts.ptr; let mut array_data = array.data; @@ -54,35 +55,30 @@ where D: Dimension } } -impl Iterator for IntoIter -{ +impl Iterator for IntoIter { type Item = A; #[inline] - fn next(&mut self) -> Option - { + fn next(&mut self) -> Option { self.inner.next().map(|p| unsafe { p.as_ptr().read() }) } - fn size_hint(&self) -> (usize, Option) - { + fn size_hint(&self) -> (usize, Option) { self.inner.size_hint() } } -impl ExactSizeIterator for IntoIter -{ - fn len(&self) -> usize - { +impl ExactSizeIterator for IntoIter { + fn len(&self) -> usize { self.inner.len() } } impl Drop for IntoIter -where D: Dimension +where + D: Dimension, { - fn drop(&mut self) - { + fn drop(&mut self) { if !self.has_unreachable_elements || mem::size_of::() == 0 || !mem::needs_drop::() { return; } @@ -93,21 +89,25 @@ where D: Dimension unsafe { let data_ptr = self.array_data.as_nonnull_mut(); let view = RawArrayViewMut::new(self.array_head_ptr, self.inner.dim.clone(), self.inner.strides.clone()); - debug_assert!(self.inner.dim.size() < self.data_len, "data_len {} and dim size {}", - self.data_len, self.inner.dim.size()); + debug_assert!( + self.inner.dim.size() < self.data_len, + "data_len {} and dim size {}", + self.data_len, + self.inner.dim.size() + ); drop_unreachable_raw(view, data_ptr, self.data_len); } } } impl IntoIterator for Array -where D: Dimension +where + D: Dimension, { type Item = A; type IntoIter = IntoIter; - fn into_iter(self) -> Self::IntoIter - { + fn into_iter(self) -> Self::IntoIter { IntoIter::new(self) } } @@ -120,8 +120,7 @@ where type Item = A; type IntoIter = IntoIter; - fn into_iter(self) -> Self::IntoIter - { + fn into_iter(self) -> Self::IntoIter { IntoIter::new(self.into_owned()) } } @@ -134,8 +133,7 @@ where type Item = A; type IntoIter = IntoIter; - fn into_iter(self) -> Self::IntoIter - { + fn into_iter(self) -> Self::IntoIter { IntoIter::new(self.into_owned()) } } diff --git a/src/iterators/iter.rs b/src/iterators/iter.rs index 478987ee..6e84728f 100644 --- a/src/iterators/iter.rs +++ b/src/iterators/iter.rs @@ -9,23 +9,7 @@ pub use crate::dimension::Axes; pub use crate::indexes::{Indices, IndicesIter}; pub use crate::iterators::{ - AxisChunksIter, - AxisChunksIterMut, - AxisIter, - AxisIterMut, - AxisWindows, - ExactChunks, - ExactChunksIter, - ExactChunksIterMut, - ExactChunksMut, - IndexedIter, - IndexedIterMut, - IntoIter, - Iter, - IterMut, - Lanes, - LanesIter, - LanesIterMut, - LanesMut, - Windows, + AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, AxisWindows, ExactChunks, ExactChunksIter, + ExactChunksIterMut, ExactChunksMut, IndexedIter, IndexedIterMut, IntoIter, Iter, IterMut, Lanes, LanesIter, + LanesIterMut, LanesMut, Windows, }; diff --git a/src/iterators/lanes.rs 
b/src/iterators/lanes.rs index 9fd39607..3e2be107 100644 --- a/src/iterators/lanes.rs +++ b/src/iterators/lanes.rs @@ -25,17 +25,16 @@ impl_ndproducer! { /// See [`.lanes()`](crate::ArrayRef::lanes) /// for more information. -pub struct Lanes<'a, A, D> -{ +pub struct Lanes<'a, A, D> { base: ArrayView<'a, A, D>, inner_len: Ix, inner_stride: Ixs, } -impl<'a, A, D: Dimension> Lanes<'a, A, D> -{ +impl<'a, A, D: Dimension> Lanes<'a, A, D> { pub(crate) fn new(v: ArrayView<'a, A, Di>, axis: Axis) -> Self - where Di: Dimension + where + Di: Dimension, { let ndim = v.ndim(); let len; @@ -77,12 +76,12 @@ impl_ndproducer! { } impl<'a, A, D> IntoIterator for Lanes<'a, A, D> -where D: Dimension +where + D: Dimension, { type Item = ::Item; type IntoIter = LanesIter<'a, A, D>; - fn into_iter(self) -> Self::IntoIter - { + fn into_iter(self) -> Self::IntoIter { LanesIter { iter: self.base.into_base_iter(), inner_len: self.inner_len, @@ -94,17 +93,16 @@ where D: Dimension /// See [`.lanes_mut()`](crate::ArrayRef::lanes_mut) /// for more information. -pub struct LanesMut<'a, A, D> -{ +pub struct LanesMut<'a, A, D> { base: ArrayViewMut<'a, A, D>, inner_len: Ix, inner_stride: Ixs, } -impl<'a, A, D: Dimension> LanesMut<'a, A, D> -{ +impl<'a, A, D: Dimension> LanesMut<'a, A, D> { pub(crate) fn new(v: ArrayViewMut<'a, A, Di>, axis: Axis) -> Self - where Di: Dimension + where + Di: Dimension, { let ndim = v.ndim(); let len; @@ -128,12 +126,12 @@ impl<'a, A, D: Dimension> LanesMut<'a, A, D> } impl<'a, A, D> IntoIterator for LanesMut<'a, A, D> -where D: Dimension +where + D: Dimension, { type Item = ::Item; type IntoIter = LanesIterMut<'a, A, D>; - fn into_iter(self) -> Self::IntoIter - { + fn into_iter(self) -> Self::IntoIter { LanesIterMut { iter: self.base.into_base_iter(), inner_len: self.inner_len, diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index abca3579..e35bacc3 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -40,22 +40,19 @@ use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut}; /// /// Iterator element type is `NonNull`. #[derive(Debug)] -pub struct Baseiter -{ +pub struct Baseiter { ptr: NonNull, dim: D, strides: D, index: Option, } -impl Baseiter -{ +impl Baseiter { /// Creating a Baseiter is unsafe because shape and stride parameters need /// to be correct to avoid performing an unsafe pointer offset while /// iterating. 
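Lanes and LanesMut iterate the 1-D lanes that run along a chosen axis. A usage sketch, assuming the usual `lanes` method is still exposed:

use ndarray::{array, Axis};

fn main() {
    let a = array![[1, 2, 3], [4, 5, 6]];
    // Lanes along Axis(1) of a 2-D array are its rows.
    let sums: Vec<i32> = a.lanes(Axis(1)).into_iter().map(|lane| lane.sum()).collect();
    assert_eq!(sums, vec![6, 15]);
}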
#[inline] - pub unsafe fn new(ptr: NonNull, len: D, stride: D) -> Baseiter - { + pub unsafe fn new(ptr: NonNull, len: D, stride: D) -> Baseiter { Baseiter { ptr, index: len.first_index(), @@ -65,27 +62,25 @@ impl Baseiter } } -impl Iterator for Baseiter -{ +impl Iterator for Baseiter { type Item = NonNull; #[inline] - fn next(&mut self) -> Option - { + fn next(&mut self) -> Option { let index = self.index.take()?; let offset = D::stride_offset(&index, &self.strides); self.index = self.dim.next_for(index); unsafe { Some(self.ptr.offset(offset)) } } - fn size_hint(&self) -> (usize, Option) - { + fn size_hint(&self) -> (usize, Option) { let len = self.len(); (len, Some(len)) } fn fold(mut self, init: Acc, mut g: G) -> Acc - where G: FnMut(Acc, Self::Item) -> Acc + where + G: FnMut(Acc, Self::Item) -> Acc, { let ndim = self.dim.ndim(); debug_assert_ne!(ndim, 0); @@ -111,10 +106,8 @@ impl Iterator for Baseiter } } -impl ExactSizeIterator for Baseiter -{ - fn len(&self) -> usize - { +impl ExactSizeIterator for Baseiter { + fn len(&self) -> usize { match self.index { None => 0, Some(ref ix) => { @@ -131,11 +124,9 @@ impl ExactSizeIterator for Baseiter } } -impl DoubleEndedIterator for Baseiter -{ +impl DoubleEndedIterator for Baseiter { #[inline] - fn next_back(&mut self) -> Option - { + fn next_back(&mut self) -> Option { let index = self.index?; self.dim[0] -= 1; let offset = Ix1::stride_offset(&self.dim, &self.strides); @@ -146,8 +137,7 @@ impl DoubleEndedIterator for Baseiter unsafe { Some(self.ptr.offset(offset)) } } - fn nth_back(&mut self, n: usize) -> Option - { + fn nth_back(&mut self, n: usize) -> Option { let index = self.index?; let len = self.dim[0] - index[0]; if n < len { @@ -164,7 +154,8 @@ impl DoubleEndedIterator for Baseiter } fn rfold(mut self, init: Acc, mut g: G) -> Acc - where G: FnMut(Acc, Self::Item) -> Acc + where + G: FnMut(Acc, Self::Item) -> Acc, { let mut accum = init; if let Some(index) = self.index { @@ -207,10 +198,8 @@ clone_bounds!( } ); -impl<'a, A, D: Dimension> ElementsBase<'a, A, D> -{ - pub fn new(v: ArrayView<'a, A, D>) -> Self - { +impl<'a, A, D: Dimension> ElementsBase<'a, A, D> { + pub fn new(v: ArrayView<'a, A, D>) -> Self { ElementsBase { inner: v.into_base_iter(), life: PhantomData, @@ -218,47 +207,44 @@ impl<'a, A, D: Dimension> ElementsBase<'a, A, D> } } -impl<'a, A, D: Dimension> Iterator for ElementsBase<'a, A, D> -{ +impl<'a, A, D: Dimension> Iterator for ElementsBase<'a, A, D> { type Item = &'a A; #[inline] - fn next(&mut self) -> Option<&'a A> - { + fn next(&mut self) -> Option<&'a A> { self.inner.next().map(|p| unsafe { p.as_ref() }) } - fn size_hint(&self) -> (usize, Option) - { + fn size_hint(&self) -> (usize, Option) { self.inner.size_hint() } fn fold(self, init: Acc, mut g: G) -> Acc - where G: FnMut(Acc, Self::Item) -> Acc + where + G: FnMut(Acc, Self::Item) -> Acc, { unsafe { self.inner.fold(init, move |acc, ptr| g(acc, ptr.as_ref())) } } } -impl<'a, A> DoubleEndedIterator for ElementsBase<'a, A, Ix1> -{ +impl<'a, A> DoubleEndedIterator for ElementsBase<'a, A, Ix1> { #[inline] - fn next_back(&mut self) -> Option<&'a A> - { + fn next_back(&mut self) -> Option<&'a A> { self.inner.next_back().map(|p| unsafe { p.as_ref() }) } fn rfold(self, init: Acc, mut g: G) -> Acc - where G: FnMut(Acc, Self::Item) -> Acc + where + G: FnMut(Acc, Self::Item) -> Acc, { unsafe { self.inner.rfold(init, move |acc, ptr| g(acc, ptr.as_ref())) } } } impl ExactSizeIterator for ElementsBase<'_, A, D> -where D: Dimension +where + D: Dimension, { - fn len(&self) -> 
usize - { + fn len(&self) -> usize { self.inner.len() } } @@ -291,10 +277,10 @@ clone_bounds!( ); impl<'a, A, D> Iter<'a, A, D> -where D: Dimension +where + D: Dimension, { - pub(crate) fn new(self_: ArrayView<'a, A, D>) -> Self - { + pub(crate) fn new(self_: ArrayView<'a, A, D>) -> Self { Iter { inner: if let Some(slc) = self_.to_slice() { ElementsRepr::Slice(slc.iter()) @@ -306,10 +292,10 @@ where D: Dimension } impl<'a, A, D> IterMut<'a, A, D> -where D: Dimension +where + D: Dimension, { - pub(crate) fn new(self_: ArrayViewMut<'a, A, D>) -> Self - { + pub(crate) fn new(self_: ArrayViewMut<'a, A, D>) -> Self { IterMut { inner: match self_.try_into_slice() { Ok(x) => ElementsRepr::Slice(x.iter_mut()), @@ -320,8 +306,7 @@ where D: Dimension } #[derive(Clone, Debug)] -pub enum ElementsRepr -{ +pub enum ElementsRepr { Slice(S), Counted(C), } @@ -332,15 +317,13 @@ pub enum ElementsRepr /// /// See [`.iter()`](crate::ArrayRef::iter) for more information. #[derive(Debug)] -pub struct Iter<'a, A, D> -{ +pub struct Iter<'a, A, D> { inner: ElementsRepr, ElementsBase<'a, A, D>>, } /// Counted read only iterator #[derive(Debug)] -pub struct ElementsBase<'a, A, D> -{ +pub struct ElementsBase<'a, A, D> { inner: Baseiter, life: PhantomData<&'a A>, } @@ -351,8 +334,7 @@ pub struct ElementsBase<'a, A, D> /// /// See [`.iter_mut()`](crate::ArrayRef::iter_mut) for more information. #[derive(Debug)] -pub struct IterMut<'a, A, D> -{ +pub struct IterMut<'a, A, D> { inner: ElementsRepr, ElementsBaseMut<'a, A, D>>, } @@ -360,16 +342,13 @@ pub struct IterMut<'a, A, D> /// /// Iterator element type is `&'a mut A`. #[derive(Debug)] -pub struct ElementsBaseMut<'a, A, D> -{ +pub struct ElementsBaseMut<'a, A, D> { inner: Baseiter, life: PhantomData<&'a mut A>, } -impl<'a, A, D: Dimension> ElementsBaseMut<'a, A, D> -{ - pub fn new(v: ArrayViewMut<'a, A, D>) -> Self - { +impl<'a, A, D: Dimension> ElementsBaseMut<'a, A, D> { + pub fn new(v: ArrayViewMut<'a, A, D>) -> Self { ElementsBaseMut { inner: v.into_base_iter(), life: PhantomData, @@ -388,130 +367,127 @@ pub struct IndexedIter<'a, A, D>(ElementsBase<'a, A, D>); pub struct IndexedIterMut<'a, A, D>(ElementsBaseMut<'a, A, D>); impl<'a, A, D> IndexedIter<'a, A, D> -where D: Dimension +where + D: Dimension, { - pub(crate) fn new(x: ElementsBase<'a, A, D>) -> Self - { + pub(crate) fn new(x: ElementsBase<'a, A, D>) -> Self { IndexedIter(x) } } impl<'a, A, D> IndexedIterMut<'a, A, D> -where D: Dimension +where + D: Dimension, { - pub(crate) fn new(x: ElementsBaseMut<'a, A, D>) -> Self - { + pub(crate) fn new(x: ElementsBaseMut<'a, A, D>) -> Self { IndexedIterMut(x) } } -impl<'a, A, D: Dimension> Iterator for Iter<'a, A, D> -{ +impl<'a, A, D: Dimension> Iterator for Iter<'a, A, D> { type Item = &'a A; #[inline] - fn next(&mut self) -> Option<&'a A> - { + fn next(&mut self) -> Option<&'a A> { either_mut!(self.inner, iter => iter.next()) } - fn size_hint(&self) -> (usize, Option) - { + fn size_hint(&self) -> (usize, Option) { either!(self.inner, ref iter => iter.size_hint()) } fn fold(self, init: Acc, g: G) -> Acc - where G: FnMut(Acc, Self::Item) -> Acc + where + G: FnMut(Acc, Self::Item) -> Acc, { either!(self.inner, iter => iter.fold(init, g)) } - fn nth(&mut self, n: usize) -> Option - { + fn nth(&mut self, n: usize) -> Option { either_mut!(self.inner, iter => iter.nth(n)) } fn collect(self) -> B - where B: FromIterator + where + B: FromIterator, { either!(self.inner, iter => iter.collect()) } fn all(&mut self, f: F) -> bool - where F: FnMut(Self::Item) -> bool + where + 
F: FnMut(Self::Item) -> bool, { either_mut!(self.inner, iter => iter.all(f)) } fn any(&mut self, f: F) -> bool - where F: FnMut(Self::Item) -> bool + where + F: FnMut(Self::Item) -> bool, { either_mut!(self.inner, iter => iter.any(f)) } fn find
(&mut self, predicate: P) -> Option - where P: FnMut(&Self::Item) -> bool + where + P: FnMut(&Self::Item) -> bool, { either_mut!(self.inner, iter => iter.find(predicate)) } fn find_map(&mut self, f: F) -> Option - where F: FnMut(Self::Item) -> Option + where + F: FnMut(Self::Item) -> Option, { either_mut!(self.inner, iter => iter.find_map(f)) } - fn count(self) -> usize - { + fn count(self) -> usize { either!(self.inner, iter => iter.count()) } - fn last(self) -> Option - { + fn last(self) -> Option { either!(self.inner, iter => iter.last()) } fn position
(&mut self, predicate: P) -> Option - where P: FnMut(Self::Item) -> bool + where + P: FnMut(Self::Item) -> bool, { either_mut!(self.inner, iter => iter.position(predicate)) } } -impl<'a, A> DoubleEndedIterator for Iter<'a, A, Ix1> -{ +impl<'a, A> DoubleEndedIterator for Iter<'a, A, Ix1> { #[inline] - fn next_back(&mut self) -> Option<&'a A> - { + fn next_back(&mut self) -> Option<&'a A> { either_mut!(self.inner, iter => iter.next_back()) } - fn nth_back(&mut self, n: usize) -> Option<&'a A> - { + fn nth_back(&mut self, n: usize) -> Option<&'a A> { either_mut!(self.inner, iter => iter.nth_back(n)) } fn rfold(self, init: Acc, g: G) -> Acc - where G: FnMut(Acc, Self::Item) -> Acc + where + G: FnMut(Acc, Self::Item) -> Acc, { either!(self.inner, iter => iter.rfold(init, g)) } } impl ExactSizeIterator for Iter<'_, A, D> -where D: Dimension +where + D: Dimension, { - fn len(&self) -> usize - { + fn len(&self) -> usize { either!(self.inner, ref iter => iter.len()) } } -impl<'a, A, D: Dimension> Iterator for IndexedIter<'a, A, D> -{ +impl<'a, A, D: Dimension> Iterator for IndexedIter<'a, A, D> { type Item = (D::Pattern, &'a A); #[inline] - fn next(&mut self) -> Option - { + fn next(&mut self) -> Option { let index = match self.0.inner.index { None => return None, Some(ref ix) => ix.clone(), @@ -522,138 +498,134 @@ impl<'a, A, D: Dimension> Iterator for IndexedIter<'a, A, D> } } - fn size_hint(&self) -> (usize, Option) - { + fn size_hint(&self) -> (usize, Option) { self.0.size_hint() } } impl ExactSizeIterator for IndexedIter<'_, A, D> -where D: Dimension +where + D: Dimension, { - fn len(&self) -> usize - { + fn len(&self) -> usize { self.0.inner.len() } } -impl<'a, A, D: Dimension> Iterator for IterMut<'a, A, D> -{ +impl<'a, A, D: Dimension> Iterator for IterMut<'a, A, D> { type Item = &'a mut A; #[inline] - fn next(&mut self) -> Option<&'a mut A> - { + fn next(&mut self) -> Option<&'a mut A> { either_mut!(self.inner, iter => iter.next()) } - fn size_hint(&self) -> (usize, Option) - { + fn size_hint(&self) -> (usize, Option) { either!(self.inner, ref iter => iter.size_hint()) } fn fold(self, init: Acc, g: G) -> Acc - where G: FnMut(Acc, Self::Item) -> Acc + where + G: FnMut(Acc, Self::Item) -> Acc, { either!(self.inner, iter => iter.fold(init, g)) } - fn nth(&mut self, n: usize) -> Option - { + fn nth(&mut self, n: usize) -> Option { either_mut!(self.inner, iter => iter.nth(n)) } fn collect(self) -> B - where B: FromIterator + where + B: FromIterator, { either!(self.inner, iter => iter.collect()) } fn all(&mut self, f: F) -> bool - where F: FnMut(Self::Item) -> bool + where + F: FnMut(Self::Item) -> bool, { either_mut!(self.inner, iter => iter.all(f)) } fn any(&mut self, f: F) -> bool - where F: FnMut(Self::Item) -> bool + where + F: FnMut(Self::Item) -> bool, { either_mut!(self.inner, iter => iter.any(f)) } fn find
(&mut self, predicate: P) -> Option - where P: FnMut(&Self::Item) -> bool + where + P: FnMut(&Self::Item) -> bool, { either_mut!(self.inner, iter => iter.find(predicate)) } fn find_map(&mut self, f: F) -> Option - where F: FnMut(Self::Item) -> Option + where + F: FnMut(Self::Item) -> Option, { either_mut!(self.inner, iter => iter.find_map(f)) } - fn count(self) -> usize - { + fn count(self) -> usize { either!(self.inner, iter => iter.count()) } - fn last(self) -> Option - { + fn last(self) -> Option { either!(self.inner, iter => iter.last()) } fn position
(&mut self, predicate: P) -> Option - where P: FnMut(Self::Item) -> bool + where + P: FnMut(Self::Item) -> bool, { either_mut!(self.inner, iter => iter.position(predicate)) } } -impl<'a, A> DoubleEndedIterator for IterMut<'a, A, Ix1> -{ +impl<'a, A> DoubleEndedIterator for IterMut<'a, A, Ix1> { #[inline] - fn next_back(&mut self) -> Option<&'a mut A> - { + fn next_back(&mut self) -> Option<&'a mut A> { either_mut!(self.inner, iter => iter.next_back()) } - fn nth_back(&mut self, n: usize) -> Option<&'a mut A> - { + fn nth_back(&mut self, n: usize) -> Option<&'a mut A> { either_mut!(self.inner, iter => iter.nth_back(n)) } fn rfold(self, init: Acc, g: G) -> Acc - where G: FnMut(Acc, Self::Item) -> Acc + where + G: FnMut(Acc, Self::Item) -> Acc, { either!(self.inner, iter => iter.rfold(init, g)) } } impl ExactSizeIterator for IterMut<'_, A, D> -where D: Dimension +where + D: Dimension, { - fn len(&self) -> usize - { + fn len(&self) -> usize { either!(self.inner, ref iter => iter.len()) } } -impl<'a, A, D: Dimension> Iterator for ElementsBaseMut<'a, A, D> -{ +impl<'a, A, D: Dimension> Iterator for ElementsBaseMut<'a, A, D> { type Item = &'a mut A; #[inline] - fn next(&mut self) -> Option<&'a mut A> - { + fn next(&mut self) -> Option<&'a mut A> { self.inner.next().map(|mut p| unsafe { p.as_mut() }) } - fn size_hint(&self) -> (usize, Option) - { + fn size_hint(&self) -> (usize, Option) { self.inner.size_hint() } fn fold(self, init: Acc, mut g: G) -> Acc - where G: FnMut(Acc, Self::Item) -> Acc + where + G: FnMut(Acc, Self::Item) -> Acc, { unsafe { self.inner @@ -662,16 +634,15 @@ impl<'a, A, D: Dimension> Iterator for ElementsBaseMut<'a, A, D> } } -impl<'a, A> DoubleEndedIterator for ElementsBaseMut<'a, A, Ix1> -{ +impl<'a, A> DoubleEndedIterator for ElementsBaseMut<'a, A, Ix1> { #[inline] - fn next_back(&mut self) -> Option<&'a mut A> - { + fn next_back(&mut self) -> Option<&'a mut A> { self.inner.next_back().map(|mut p| unsafe { p.as_mut() }) } fn rfold(self, init: Acc, mut g: G) -> Acc - where G: FnMut(Acc, Self::Item) -> Acc + where + G: FnMut(Acc, Self::Item) -> Acc, { unsafe { self.inner @@ -681,20 +652,18 @@ impl<'a, A> DoubleEndedIterator for ElementsBaseMut<'a, A, Ix1> } impl ExactSizeIterator for ElementsBaseMut<'_, A, D> -where D: Dimension +where + D: Dimension, { - fn len(&self) -> usize - { + fn len(&self) -> usize { self.inner.len() } } -impl<'a, A, D: Dimension> Iterator for IndexedIterMut<'a, A, D> -{ +impl<'a, A, D: Dimension> Iterator for IndexedIterMut<'a, A, D> { type Item = (D::Pattern, &'a mut A); #[inline] - fn next(&mut self) -> Option - { + fn next(&mut self) -> Option { let index = match self.0.inner.index { None => return None, Some(ref ix) => ix.clone(), @@ -705,17 +674,16 @@ impl<'a, A, D: Dimension> Iterator for IndexedIterMut<'a, A, D> } } - fn size_hint(&self) -> (usize, Option) - { + fn size_hint(&self) -> (usize, Option) { self.0.size_hint() } } impl ExactSizeIterator for IndexedIterMut<'_, A, D> -where D: Dimension +where + D: Dimension, { - fn len(&self) -> usize - { + fn len(&self) -> usize { self.0.inner.len() } } @@ -724,8 +692,7 @@ where D: Dimension /// each lane along that axis. /// /// See [`.lanes()`](crate::ArrayRef::lanes) for more information. 
-pub struct LanesIter<'a, A, D> -{ +pub struct LanesIter<'a, A, D> { inner_len: Ix, inner_stride: Ixs, iter: Baseiter, @@ -745,35 +712,32 @@ clone_bounds!( ); impl<'a, A, D> Iterator for LanesIter<'a, A, D> -where D: Dimension +where + D: Dimension, { type Item = ArrayView<'a, A, Ix1>; - fn next(&mut self) -> Option - { + fn next(&mut self) -> Option { self.iter .next() .map(|ptr| unsafe { ArrayView::new(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) }) } - fn size_hint(&self) -> (usize, Option) - { + fn size_hint(&self) -> (usize, Option) { self.iter.size_hint() } } impl ExactSizeIterator for LanesIter<'_, A, D> -where D: Dimension +where + D: Dimension, { - fn len(&self) -> usize - { + fn len(&self) -> usize { self.iter.len() } } -impl DoubleEndedIterator for LanesIter<'_, A, Ix1> -{ - fn next_back(&mut self) -> Option - { +impl DoubleEndedIterator for LanesIter<'_, A, Ix1> { + fn next_back(&mut self) -> Option { self.iter .next_back() .map(|ptr| unsafe { ArrayView::new(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) }) @@ -788,8 +752,7 @@ impl DoubleEndedIterator for LanesIter<'_, A, Ix1> /// /// See [`.lanes_mut()`](crate::ArrayRef::lanes_mut) /// for more information. -pub struct LanesIterMut<'a, A, D> -{ +pub struct LanesIterMut<'a, A, D> { inner_len: Ix, inner_stride: Ixs, iter: Baseiter, @@ -797,35 +760,32 @@ pub struct LanesIterMut<'a, A, D> } impl<'a, A, D> Iterator for LanesIterMut<'a, A, D> -where D: Dimension +where + D: Dimension, { type Item = ArrayViewMut<'a, A, Ix1>; - fn next(&mut self) -> Option - { + fn next(&mut self) -> Option { self.iter .next() .map(|ptr| unsafe { ArrayViewMut::new(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) }) } - fn size_hint(&self) -> (usize, Option) - { + fn size_hint(&self) -> (usize, Option) { self.iter.size_hint() } } impl ExactSizeIterator for LanesIterMut<'_, A, D> -where D: Dimension +where + D: Dimension, { - fn len(&self) -> usize - { + fn len(&self) -> usize { self.iter.len() } } -impl DoubleEndedIterator for LanesIterMut<'_, A, Ix1> -{ - fn next_back(&mut self) -> Option - { +impl DoubleEndedIterator for LanesIterMut<'_, A, Ix1> { + fn next_back(&mut self) -> Option { self.iter .next_back() .map(|ptr| unsafe { ArrayViewMut::new(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) }) @@ -833,8 +793,7 @@ impl DoubleEndedIterator for LanesIterMut<'_, A, Ix1> } #[derive(Debug)] -pub struct AxisIterCore -{ +pub struct AxisIterCore { /// Index along the axis of the value of `.next()`, relative to the start /// of the axis. index: Ix, @@ -865,8 +824,7 @@ clone_bounds!( } ); -impl AxisIterCore -{ +impl AxisIterCore { /// Constructs a new iterator over the specified axis. fn new(v: ArrayBase, axis: Axis) -> Self where @@ -884,15 +842,8 @@ impl AxisIterCore } #[inline] - unsafe fn offset(&self, index: usize) -> *mut A - { - debug_assert!( - index < self.end, - "index={}, end={}, stride={}", - index, - self.end, - self.stride - ); + unsafe fn offset(&self, index: usize) -> *mut A { + debug_assert!(index < self.end, "index={}, end={}, stride={}", index, self.end, self.stride); self.ptr.offset(index as isize * self.stride) } @@ -904,8 +855,7 @@ impl AxisIterCore /// **Panics** if `index` is strictly greater than the iterator's remaining /// length. 
#[track_caller] - fn split_at(self, index: usize) -> (Self, Self) - { + fn split_at(self, index: usize) -> (Self, Self) { assert!(index <= self.len()); let mid = self.index + index; let left = AxisIterCore { @@ -929,27 +879,25 @@ impl AxisIterCore /// Does the same thing as `.next()` but also returns the index of the item /// relative to the start of the axis. - fn next_with_index(&mut self) -> Option<(usize, *mut A)> - { + fn next_with_index(&mut self) -> Option<(usize, *mut A)> { let index = self.index; self.next().map(|ptr| (index, ptr)) } /// Does the same thing as `.next_back()` but also returns the index of the /// item relative to the start of the axis. - fn next_back_with_index(&mut self) -> Option<(usize, *mut A)> - { + fn next_back_with_index(&mut self) -> Option<(usize, *mut A)> { self.next_back().map(|ptr| (self.end, ptr)) } } impl Iterator for AxisIterCore -where D: Dimension +where + D: Dimension, { type Item = *mut A; - fn next(&mut self) -> Option - { + fn next(&mut self) -> Option { if self.index >= self.end { None } else { @@ -959,18 +907,17 @@ where D: Dimension } } - fn size_hint(&self) -> (usize, Option) - { + fn size_hint(&self) -> (usize, Option) { let len = self.len(); (len, Some(len)) } } impl DoubleEndedIterator for AxisIterCore -where D: Dimension +where + D: Dimension, { - fn next_back(&mut self) -> Option - { + fn next_back(&mut self) -> Option { if self.index >= self.end { None } else { @@ -982,10 +929,10 @@ where D: Dimension } impl ExactSizeIterator for AxisIterCore -where D: Dimension +where + D: Dimension, { - fn len(&self) -> usize - { + fn len(&self) -> usize { self.end - self.index } } @@ -1005,8 +952,7 @@ where D: Dimension /// or [`.axis_iter()`](crate::ArrayRef::axis_iter) /// for more information. #[derive(Debug)] -pub struct AxisIter<'a, A, D> -{ +pub struct AxisIter<'a, A, D> { iter: AxisIterCore, life: PhantomData<&'a A>, } @@ -1021,11 +967,11 @@ clone_bounds!( } ); -impl<'a, A, D: Dimension> AxisIter<'a, A, D> -{ +impl<'a, A, D: Dimension> AxisIter<'a, A, D> { /// Creates a new iterator over the specified axis. pub(crate) fn new(v: ArrayView<'a, A, Di>, axis: Axis) -> Self - where Di: RemoveAxis + where + Di: RemoveAxis, { AxisIter { iter: AxisIterCore::new(v, axis), @@ -1041,8 +987,7 @@ impl<'a, A, D: Dimension> AxisIter<'a, A, D> /// **Panics** if `index` is strictly greater than the iterator's remaining /// length. 
#[track_caller] - pub fn split_at(self, index: usize) -> (Self, Self) - { + pub fn split_at(self, index: usize) -> (Self, Self) { let (left, right) = self.iter.split_at(index); ( AxisIter { @@ -1058,35 +1003,34 @@ impl<'a, A, D: Dimension> AxisIter<'a, A, D> } impl<'a, A, D> Iterator for AxisIter<'a, A, D> -where D: Dimension +where + D: Dimension, { type Item = ArrayView<'a, A, D>; - fn next(&mut self) -> Option - { + fn next(&mut self) -> Option { self.iter.next().map(|ptr| unsafe { self.as_ref(ptr) }) } - fn size_hint(&self) -> (usize, Option) - { + fn size_hint(&self) -> (usize, Option) { self.iter.size_hint() } } impl DoubleEndedIterator for AxisIter<'_, A, D> -where D: Dimension +where + D: Dimension, { - fn next_back(&mut self) -> Option - { + fn next_back(&mut self) -> Option { self.iter.next_back().map(|ptr| unsafe { self.as_ref(ptr) }) } } impl ExactSizeIterator for AxisIter<'_, A, D> -where D: Dimension +where + D: Dimension, { - fn len(&self) -> usize - { + fn len(&self) -> usize { self.iter.len() } } @@ -1105,17 +1049,16 @@ where D: Dimension /// See [`.outer_iter_mut()`](crate::ArrayRef::outer_iter_mut) /// or [`.axis_iter_mut()`](crate::ArrayRef::axis_iter_mut) /// for more information. -pub struct AxisIterMut<'a, A, D> -{ +pub struct AxisIterMut<'a, A, D> { iter: AxisIterCore, life: PhantomData<&'a mut A>, } -impl<'a, A, D: Dimension> AxisIterMut<'a, A, D> -{ +impl<'a, A, D: Dimension> AxisIterMut<'a, A, D> { /// Creates a new iterator over the specified axis. pub(crate) fn new(v: ArrayViewMut<'a, A, Di>, axis: Axis) -> Self - where Di: RemoveAxis + where + Di: RemoveAxis, { AxisIterMut { iter: AxisIterCore::new(v, axis), @@ -1131,8 +1074,7 @@ impl<'a, A, D: Dimension> AxisIterMut<'a, A, D> /// **Panics** if `index` is strictly greater than the iterator's remaining /// length. #[track_caller] - pub fn split_at(self, index: usize) -> (Self, Self) - { + pub fn split_at(self, index: usize) -> (Self, Self) { let (left, right) = self.iter.split_at(index); ( AxisIterMut { @@ -1148,58 +1090,53 @@ impl<'a, A, D: Dimension> AxisIterMut<'a, A, D> } impl<'a, A, D> Iterator for AxisIterMut<'a, A, D> -where D: Dimension +where + D: Dimension, { type Item = ArrayViewMut<'a, A, D>; - fn next(&mut self) -> Option - { + fn next(&mut self) -> Option { self.iter.next().map(|ptr| unsafe { self.as_ref(ptr) }) } - fn size_hint(&self) -> (usize, Option) - { + fn size_hint(&self) -> (usize, Option) { self.iter.size_hint() } } impl DoubleEndedIterator for AxisIterMut<'_, A, D> -where D: Dimension +where + D: Dimension, { - fn next_back(&mut self) -> Option - { + fn next_back(&mut self) -> Option { self.iter.next_back().map(|ptr| unsafe { self.as_ref(ptr) }) } } impl ExactSizeIterator for AxisIterMut<'_, A, D> -where D: Dimension +where + D: Dimension, { - fn len(&self) -> usize - { + fn len(&self) -> usize { self.iter.len() } } -impl NdProducer for AxisIter<'_, A, D> -{ +impl NdProducer for AxisIter<'_, A, D> { type Item = ::Item; type Dim = Ix1; type Ptr = *mut A; type Stride = isize; - fn layout(&self) -> crate::Layout - { + fn layout(&self) -> crate::Layout { crate::Layout::one_dimensional() } - fn raw_dim(&self) -> Self::Dim - { + fn raw_dim(&self) -> Self::Dim { Ix1(self.len()) } - fn as_ptr(&self) -> Self::Ptr - { + fn as_ptr(&self) -> Self::Ptr { if self.len() > 0 { // `self.iter.index` is guaranteed to be in-bounds if any of the // iterator remains (i.e. if `self.len() > 0`). 
@@ -1212,53 +1149,44 @@ impl NdProducer for AxisIter<'_, A, D> } } - fn contiguous_stride(&self) -> isize - { + fn contiguous_stride(&self) -> isize { self.iter.stride } - unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item - { + unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item { ArrayView::new_(ptr, self.iter.inner_dim.clone(), self.iter.inner_strides.clone()) } - unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr - { + unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr { self.iter.offset(self.iter.index + i[0]) } - fn stride_of(&self, _axis: Axis) -> isize - { + fn stride_of(&self, _axis: Axis) -> isize { self.contiguous_stride() } - fn split_at(self, _axis: Axis, index: usize) -> (Self, Self) - { + fn split_at(self, _axis: Axis, index: usize) -> (Self, Self) { self.split_at(index) } private_impl! {} } -impl NdProducer for AxisIterMut<'_, A, D> -{ +impl NdProducer for AxisIterMut<'_, A, D> { type Item = ::Item; type Dim = Ix1; type Ptr = *mut A; type Stride = isize; - fn layout(&self) -> crate::Layout - { + fn layout(&self) -> crate::Layout { crate::Layout::one_dimensional() } - fn raw_dim(&self) -> Self::Dim - { + fn raw_dim(&self) -> Self::Dim { Ix1(self.len()) } - fn as_ptr(&self) -> Self::Ptr - { + fn as_ptr(&self) -> Self::Ptr { if self.len() > 0 { // `self.iter.index` is guaranteed to be in-bounds if any of the // iterator remains (i.e. if `self.len() > 0`). @@ -1271,28 +1199,23 @@ impl NdProducer for AxisIterMut<'_, A, D> } } - fn contiguous_stride(&self) -> isize - { + fn contiguous_stride(&self) -> isize { self.iter.stride } - unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item - { + unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item { ArrayViewMut::new_(ptr, self.iter.inner_dim.clone(), self.iter.inner_strides.clone()) } - unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr - { + unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr { self.iter.offset(self.iter.index + i[0]) } - fn stride_of(&self, _axis: Axis) -> isize - { + fn stride_of(&self, _axis: Axis) -> isize { self.contiguous_stride() } - fn split_at(self, _axis: Axis, index: usize) -> (Self, Self) - { + fn split_at(self, _axis: Axis, index: usize) -> (Self, Self) { self.split_at(index) } @@ -1309,8 +1232,7 @@ impl NdProducer for AxisIterMut<'_, A, D> /// Iterator element type is `ArrayView<'a, A, D>`. /// /// See [`.axis_chunks_iter()`](crate::ArrayRef::axis_chunks_iter) for more information. -pub struct AxisChunksIter<'a, A, D> -{ +pub struct AxisChunksIter<'a, A, D> { iter: AxisIterCore, /// Index of the partial chunk (the chunk smaller than the specified chunk /// size due to the axis length not being evenly divisible). If the axis @@ -1343,9 +1265,9 @@ clone_bounds!( /// /// **Panics** if `size == 0`. 
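AxisIter/AxisIterMut back `axis_iter`/`outer_iter`; every item is a view with the iterated axis removed, and `split_at` lets parallel consumers divide the remaining range. A minimal sketch, assuming the standard `axis_iter` method:

use ndarray::{array, Axis};

fn main() {
    let a = array![[1, 2, 3], [4, 5, 6]];
    // Iterating Axis(0) of a 2-D array yields its rows as 1-D views.
    let rows: Vec<Vec<i32>> = a.axis_iter(Axis(0)).map(|row| row.to_vec()).collect();
    assert_eq!(rows, vec![vec![1, 2, 3], vec![4, 5, 6]]);
}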
#[track_caller] -fn chunk_iter_parts(v: ArrayView<'_, A, D>, axis: Axis, size: usize) - -> (AxisIterCore, usize, D) -{ +fn chunk_iter_parts( + v: ArrayView<'_, A, D>, axis: Axis, size: usize, +) -> (AxisIterCore, usize, D) { assert_ne!(size, 0, "Chunk size must be nonzero."); let axis_len = v.len_of(axis); let n_whole_chunks = axis_len / size; @@ -1382,10 +1304,8 @@ fn chunk_iter_parts(v: ArrayView<'_, A, D>, axis: Axis, size: u (iter, partial_chunk_index, partial_chunk_dim) } -impl<'a, A, D: Dimension> AxisChunksIter<'a, A, D> -{ - pub(crate) fn new(v: ArrayView<'a, A, D>, axis: Axis, size: usize) -> Self - { +impl<'a, A, D: Dimension> AxisChunksIter<'a, A, D> { + pub(crate) fn new(v: ArrayView<'a, A, D>, axis: Axis, size: usize) -> Self { let (iter, partial_chunk_index, partial_chunk_dim) = chunk_iter_parts(v, axis, size); AxisChunksIter { iter, @@ -1404,21 +1324,9 @@ macro_rules! chunk_iter_impl { { fn get_subview(&self, index: usize, ptr: *mut A) -> $array<'a, A, D> { if index != self.partial_chunk_index { - unsafe { - $array::new_( - ptr, - self.iter.inner_dim.clone(), - self.iter.inner_strides.clone(), - ) - } + unsafe { $array::new_(ptr, self.iter.inner_dim.clone(), self.iter.inner_strides.clone()) } } else { - unsafe { - $array::new_( - ptr, - self.partial_chunk_dim.clone(), - self.iter.inner_strides.clone(), - ) - } + unsafe { $array::new_(ptr, self.partial_chunk_dim.clone(), self.iter.inner_strides.clone()) } } } @@ -1492,18 +1400,15 @@ macro_rules! chunk_iter_impl { /// /// See [`.axis_chunks_iter_mut()`](crate::ArrayRef::axis_chunks_iter_mut) /// for more information. -pub struct AxisChunksIterMut<'a, A, D> -{ +pub struct AxisChunksIterMut<'a, A, D> { iter: AxisIterCore, partial_chunk_index: usize, partial_chunk_dim: D, life: PhantomData<&'a mut A>, } -impl<'a, A, D: Dimension> AxisChunksIterMut<'a, A, D> -{ - pub(crate) fn new(v: ArrayViewMut<'a, A, D>, axis: Axis, size: usize) -> Self - { +impl<'a, A, D: Dimension> AxisChunksIterMut<'a, A, D> { + pub(crate) fn new(v: ArrayViewMut<'a, A, D>, axis: Axis, size: usize) -> Self { let (iter, partial_chunk_index, partial_chunk_dim) = chunk_iter_parts(v.into_view(), axis, size); AxisChunksIterMut { iter, @@ -1563,7 +1468,8 @@ unsafe impl TrustedIterator for IntoIter where D: Dimension {} /// Like Iterator::collect, but only for trusted length iterators pub fn to_vec(iter: I) -> Vec -where I: TrustedIterator + ExactSizeIterator +where + I: TrustedIterator + ExactSizeIterator, { to_vec_mapped(iter, |x| x) } diff --git a/src/iterators/windows.rs b/src/iterators/windows.rs index e6fccce4..380cf8ee 100644 --- a/src/iterators/windows.rs +++ b/src/iterators/windows.rs @@ -11,18 +11,17 @@ use crate::Slice; /// /// See [`.windows()`](crate::ArrayRef::windows) for more /// information. 
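chunk_iter_parts records the index and shape of the single possibly-smaller trailing chunk (and panics when `size == 0`); the `chunk_iter_impl` macro then selects the partial shape only at that index. The observable behaviour, as a sketch assuming the standard `axis_chunks_iter` method:

use ndarray::{array, Axis};

fn main() {
    let a = array![0, 1, 2, 3, 4, 5, 6];
    let sizes: Vec<usize> = a.axis_chunks_iter(Axis(0), 3).map(|c| c.len()).collect();
    // Two whole chunks of 3 plus one partial trailing chunk of 1.
    assert_eq!(sizes, vec![3, 3, 1]);
}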
-pub struct Windows<'a, A, D> -{ +pub struct Windows<'a, A, D> { base: RawArrayView, life: PhantomData<&'a A>, window: D, strides: D, } -impl<'a, A, D: Dimension> Windows<'a, A, D> -{ +impl<'a, A, D: Dimension> Windows<'a, A, D> { pub(crate) fn new(a: ArrayView<'a, A, D>, window_size: E) -> Self - where E: IntoDimension + where + E: IntoDimension, { let window = window_size.into_dimension(); let ndim = window.ndim(); @@ -34,7 +33,8 @@ impl<'a, A, D: Dimension> Windows<'a, A, D> } pub(crate) fn new_with_stride(a: ArrayView<'a, A, D>, window_size: E, axis_strides: E) -> Self - where E: IntoDimension + where + E: IntoDimension, { let window = window_size.into_dimension(); @@ -78,8 +78,7 @@ where { type Item = ::Item; type IntoIter = WindowsIter<'a, A, D>; - fn into_iter(self) -> Self::IntoIter - { + fn into_iter(self) -> Self::IntoIter { WindowsIter { iter: self.base.into_base_iter(), life: self.life, @@ -93,8 +92,7 @@ where /// /// See [`.windows()`](crate::ArrayRef::windows) for more /// information. -pub struct WindowsIter<'a, A, D> -{ +pub struct WindowsIter<'a, A, D> { iter: Baseiter, life: PhantomData<&'a A>, window: D, @@ -131,18 +129,15 @@ send_sync_read_only!(WindowsIter); /// /// See [`.axis_windows()`](crate::ArrayRef::axis_windows) for more /// information. -pub struct AxisWindows<'a, A, D> -{ +pub struct AxisWindows<'a, A, D> { base: ArrayView<'a, A, D>, axis_idx: usize, window: D, strides: D, } -impl<'a, A, D: Dimension> AxisWindows<'a, A, D> -{ - pub(crate) fn new_with_stride(a: ArrayView<'a, A, D>, axis: Axis, window_size: usize, stride_size: usize) -> Self - { +impl<'a, A, D: Dimension> AxisWindows<'a, A, D> { + pub(crate) fn new_with_stride(a: ArrayView<'a, A, D>, axis: Axis, window_size: usize, stride_size: usize) -> Self { let window_strides = a.parts.strides.clone(); let axis_idx = axis.index(); @@ -164,53 +159,44 @@ impl<'a, A, D: Dimension> AxisWindows<'a, A, D> } } -impl<'a, A, D: Dimension> NdProducer for AxisWindows<'a, A, D> -{ +impl<'a, A, D: Dimension> NdProducer for AxisWindows<'a, A, D> { type Item = ArrayView<'a, A, D>; type Dim = Ix1; type Ptr = *mut A; type Stride = isize; - fn raw_dim(&self) -> Ix1 - { + fn raw_dim(&self) -> Ix1 { Ix1(self.base.raw_dim()[self.axis_idx]) } - fn layout(&self) -> Layout - { + fn layout(&self) -> Layout { self.base.layout() } - fn as_ptr(&self) -> *mut A - { + fn as_ptr(&self) -> *mut A { self.base.as_ptr() as *mut _ } - fn contiguous_stride(&self) -> isize - { + fn contiguous_stride(&self) -> isize { self.base.contiguous_stride() } - unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item - { + unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item { ArrayView::new_(ptr, self.window.clone(), self.strides.clone()) } - unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A - { + unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A { let mut d = D::zeros(self.base.ndim()); d[self.axis_idx] = i[0]; self.base.uget_ptr(&d) } - fn stride_of(&self, axis: Axis) -> isize - { + fn stride_of(&self, axis: Axis) -> isize { assert_eq!(axis, Axis(0)); self.base.stride_of(Axis(self.axis_idx)) } - fn split_at(self, axis: Axis, index: usize) -> (Self, Self) - { + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) { assert_eq!(axis, Axis(0)); let (a, b) = self.base.split_at(Axis(self.axis_idx), index); ( @@ -229,7 +215,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisWindows<'a, A, D> ) } - private_impl!{} + private_impl! 
{} } impl<'a, A, D> IntoIterator for AxisWindows<'a, A, D> @@ -239,8 +225,7 @@ where { type Item = ::Item; type IntoIter = WindowsIter<'a, A, D>; - fn into_iter(self) -> Self::IntoIter - { + fn into_iter(self) -> Self::IntoIter { WindowsIter { iter: self.base.into_base_iter(), life: PhantomData, @@ -252,14 +237,12 @@ where /// build the base array of the `Windows` and `AxisWindows` structs fn build_base(a: ArrayView, window: D, strides: D) -> ArrayView -where D: Dimension +where + D: Dimension, { ndassert!( a.ndim() == window.ndim(), - concat!( - "Window dimension {} does not match array dimension {} ", - "(with array of shape {:?})" - ), + concat!("Window dimension {} does not match array dimension {} ", "(with array of shape {:?})"), window.ndim(), a.ndim(), a.shape() @@ -267,10 +250,7 @@ where D: Dimension ndassert!( a.ndim() == strides.ndim(), - concat!( - "Stride dimension {} does not match array dimension {} ", - "(with array of shape {:?})" - ), + concat!("Stride dimension {} does not match array dimension {} ", "(with array of shape {:?})"), strides.ndim(), a.ndim(), a.shape() diff --git a/src/itertools.rs b/src/itertools.rs index d3562e68..8edfd75e 100644 --- a/src/itertools.rs +++ b/src/itertools.rs @@ -23,7 +23,8 @@ use std::iter; /// } /// ``` pub(crate) fn enumerate(iterable: I) -> iter::Enumerate -where I: IntoIterator +where + I: IntoIterator, { iterable.into_iter().enumerate() } diff --git a/src/layout/layoutfmt.rs b/src/layout/layoutfmt.rs index f20f0caa..3d7fad00 100644 --- a/src/layout/layoutfmt.rs +++ b/src/layout/layoutfmt.rs @@ -12,10 +12,8 @@ const LAYOUT_NAMES: &[&str] = &["C", "F", "c", "f"]; use std::fmt; -impl fmt::Debug for Layout -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result - { +impl fmt::Debug for Layout { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if self.0 == 0 { write!(f, "Custom")? 
} else { diff --git a/src/layout/mod.rs b/src/layout/mod.rs index 36853848..31f0b9a7 100644 --- a/src/layout/mod.rs +++ b/src/layout/mod.rs @@ -8,82 +8,70 @@ mod layoutfmt; #[derive(Copy, Clone)] pub struct Layout(u32); -impl Layout -{ +impl Layout { pub(crate) const CORDER: u32 = 0b01; pub(crate) const FORDER: u32 = 0b10; pub(crate) const CPREFER: u32 = 0b0100; pub(crate) const FPREFER: u32 = 0b1000; #[inline(always)] - pub(crate) fn is(self, flag: u32) -> bool - { + pub(crate) fn is(self, flag: u32) -> bool { self.0 & flag != 0 } /// Return layout common to both inputs #[inline(always)] - pub(crate) fn intersect(self, other: Layout) -> Layout - { + pub(crate) fn intersect(self, other: Layout) -> Layout { Layout(self.0 & other.0) } /// Return a layout that simultaneously "is" what both of the inputs are #[inline(always)] - pub(crate) fn also(self, other: Layout) -> Layout - { + pub(crate) fn also(self, other: Layout) -> Layout { Layout(self.0 | other.0) } #[inline(always)] - pub(crate) fn one_dimensional() -> Layout - { + pub(crate) fn one_dimensional() -> Layout { Layout::c().also(Layout::f()) } #[inline(always)] - pub(crate) fn c() -> Layout - { + pub(crate) fn c() -> Layout { Layout(Layout::CORDER | Layout::CPREFER) } #[inline(always)] - pub(crate) fn f() -> Layout - { + pub(crate) fn f() -> Layout { Layout(Layout::FORDER | Layout::FPREFER) } #[inline(always)] - pub(crate) fn cpref() -> Layout - { + pub(crate) fn cpref() -> Layout { Layout(Layout::CPREFER) } #[inline(always)] - pub(crate) fn fpref() -> Layout - { + pub(crate) fn fpref() -> Layout { Layout(Layout::FPREFER) } #[inline(always)] - pub(crate) fn none() -> Layout - { + pub(crate) fn none() -> Layout { Layout(0) } /// A simple "score" method which scores positive for preferring C-order, negative for F-order /// Subject to change when we can describe other layouts #[inline] - pub(crate) fn tendency(self) -> i32 - { + pub(crate) fn tendency(self) -> i32 { (self.is(Layout::CORDER) as i32 - self.is(Layout::FORDER) as i32) + (self.is(Layout::CPREFER) as i32 - self.is(Layout::FPREFER) as i32) } } #[cfg(test)] -mod tests -{ +mod tests { use super::*; use crate::imp_prelude::*; use crate::NdProducer; @@ -117,8 +105,7 @@ mod tests } #[test] - fn contig_layouts() - { + fn contig_layouts() { let a = M::zeros((5, 5)); let b = M::zeros((5, 5).f()); let ac = a.view().layout(); @@ -130,8 +117,7 @@ mod tests } #[test] - fn contig_cf_layouts() - { + fn contig_cf_layouts() { let a = M::zeros((5, 1)); let b = M::zeros((1, 5).f()); assert_layouts!(a, CORDER, CPREFER, FORDER, FPREFER); @@ -159,8 +145,7 @@ mod tests } #[test] - fn stride_layouts() - { + fn stride_layouts() { let a = M::zeros((5, 5)); { @@ -187,8 +172,7 @@ mod tests } #[test] - fn no_layouts() - { + fn no_layouts() { let a = M::zeros((5, 5)); let b = M::zeros((5, 5).f()); @@ -216,8 +200,7 @@ mod tests } #[test] - fn skip_layouts() - { + fn skip_layouts() { let a = M::zeros((5, 5)); { let v1 = a.slice(s![..;2, ..]).layout(); diff --git a/src/lib.rs b/src/lib.rs index 619457c9..10f8a967 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -280,13 +280,7 @@ pub mod backend; /// extra deps. Cognitive/research modules (p64_bridge, crystal_encoder, /// deepnsm, etc.) are gated behind `hpc-extras` inside `hpc/mod.rs`. 
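For reference, the Layout flags reshuffled above form a small bit-set algebra: `intersect` is bitwise AND, `also` is bitwise OR, and `tendency` scores C-order positively and F-order negatively. An illustrative stand-alone restatement of that logic (the real type is crate-private, so this is only a sketch):

// Mirrors the constants and scoring from src/layout/mod.rs.
const CORDER: u32 = 0b01;
const FORDER: u32 = 0b10;
const CPREFER: u32 = 0b0100;
const FPREFER: u32 = 0b1000;

fn is(layout: u32, flag: u32) -> bool {
    layout & flag != 0
}

fn tendency(layout: u32) -> i32 {
    (is(layout, CORDER) as i32 - is(layout, FORDER) as i32)
        + (is(layout, CPREFER) as i32 - is(layout, FPREFER) as i32)
}

fn main() {
    let c = CORDER | CPREFER; // Layout::c()
    let one_d = c | FORDER | FPREFER; // Layout::one_dimensional(): both orders at once
    assert_eq!(tendency(c), 2);
    assert_eq!(tendency(one_d), 0);
}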
#[cfg(feature = "std")] -#[allow( - clippy::all, - unused_imports, - unused_variables, - unused_mut, - dead_code -)] +#[allow(clippy::all, unused_imports, unused_variables, unused_mut, dead_code)] pub mod hpc; pub use crate::zip::{FoldWhile, IntoNdProducer, NdProducer, Zip}; @@ -294,24 +288,12 @@ pub use crate::zip::{FoldWhile, IntoNdProducer, NdProducer, Zip}; pub use crate::layout::Layout; /// Implementation's prelude. Common types used everywhere. -mod imp_prelude -{ +mod imp_prelude { pub use crate::dimension::DimensionExt; pub use crate::prelude::*; pub use crate::ArcArray; pub use crate::{ - CowRepr, - Data, - DataMut, - DataOwned, - DataShared, - Ix, - Ixs, - RawData, - RawDataMut, - RawViewRepr, - RemoveAxis, - ViewRepr, + CowRepr, Data, DataMut, DataOwned, DataShared, Ix, Ixs, RawData, RawDataMut, RawViewRepr, RemoveAxis, ViewRepr, }; } @@ -1362,7 +1344,8 @@ pub type Ixs = isize; // // [`.offset()`]: https://doc.rust-lang.org/stable/std/primitive.pointer.html#method.offset-1 pub struct ArrayBase::Elem> -where S: RawData +where + S: RawData, { /// Data buffer / ownership information. (If owned, contains the data /// buffer; if borrowed, contains the lifetime and mutability.) @@ -1377,8 +1360,7 @@ where S: RawData /// type, which needs to be sized inside of `ArrayBase` and unsized inside /// of the reference types. #[derive(Debug)] -struct ArrayParts -{ +struct ArrayParts { /// A non-null pointer into the buffer held by `data`; may point anywhere /// in its range. If `S: Data`, this pointer must be aligned. ptr: NonNull, @@ -1392,10 +1374,8 @@ struct ArrayParts type ArrayPartsSized = ArrayParts; type ArrayPartsUnsized = ArrayParts; -impl ArrayPartsSized -{ - const fn new(ptr: NonNull, dim: D, strides: D) -> ArrayPartsSized - { +impl ArrayPartsSized { + const fn new(ptr: NonNull, dim: D, strides: D) -> ArrayPartsSized { Self { ptr, dim, @@ -1507,23 +1487,19 @@ impl ArrayPartsSized #[repr(transparent)] pub struct LayoutRef(ArrayPartsUnsized); -impl LayoutRef -{ +impl LayoutRef { /// Get a reference to the data pointer. - fn _ptr(&self) -> &NonNull - { + fn _ptr(&self) -> &NonNull { &self.0.ptr } /// Get a reference to the array's dimension. - fn _dim(&self) -> &D - { + fn _dim(&self) -> &D { &self.0.dim } /// Get a reference to the array's strides. - fn _strides(&self) -> &D - { + fn _strides(&self) -> &D { &self.0.strides } } @@ -1759,10 +1735,8 @@ pub use data_repr::OwnedRepr; #[derive(Debug)] pub struct OwnedArcRepr(Arc>); -impl Clone for OwnedArcRepr -{ - fn clone(&self) -> Self - { +impl Clone for OwnedArcRepr { + fn clone(&self) -> Self { OwnedArcRepr(self.0.clone()) } } @@ -1773,16 +1747,13 @@ impl Clone for OwnedArcRepr /// [`RawArrayView`] / [`RawArrayViewMut`] for the array type!* #[derive(Copy, Clone)] // This is just a marker type, to carry the mutability and element type. -pub struct RawViewRepr -{ +pub struct RawViewRepr { ptr: PhantomData, } -impl RawViewRepr -{ +impl RawViewRepr { #[inline(always)] - const fn new() -> Self - { + const fn new() -> Self { RawViewRepr { ptr: PhantomData } } } @@ -1793,16 +1764,13 @@ impl RawViewRepr /// [`ArrayView`] / [`ArrayViewMut`] for the array type!* #[derive(Copy, Clone)] // This is just a marker type, to carry the lifetime parameter. 
-pub struct ViewRepr -{ +pub struct ViewRepr { life: PhantomData, } -impl ViewRepr -{ +impl ViewRepr { #[inline(always)] - const fn new() -> Self - { + const fn new() -> Self { ViewRepr { life: PhantomData } } } @@ -1811,19 +1779,16 @@ impl ViewRepr /// /// *Don't use this type directly—use the type alias /// [`CowArray`] for the array type!* -pub enum CowRepr<'a, A> -{ +pub enum CowRepr<'a, A> { /// Borrowed data. View(ViewRepr<&'a A>), /// Owned data. Owned(OwnedRepr), } -impl CowRepr<'_, A> -{ +impl CowRepr<'_, A> { /// Returns `true` iff the data is the `View` variant. - pub fn is_view(&self) -> bool - { + pub fn is_view(&self) -> bool { match self { CowRepr::View(_) => true, CowRepr::Owned(_) => false, @@ -1831,8 +1796,7 @@ impl CowRepr<'_, A> } /// Returns `true` iff the data is the `Owned` variant. - pub fn is_owned(&self) -> bool - { + pub fn is_owned(&self) -> bool { match self { CowRepr::View(_) => false, CowRepr::Owned(_) => true, @@ -1854,11 +1818,11 @@ mod impl_owned_array; mod impl_special_element_types; /// Private Methods -impl ArrayRef -{ +impl ArrayRef { #[inline] fn broadcast_unwrap(&self, dim: E) -> ArrayView<'_, A, E> - where E: Dimension + where + E: Dimension, { #[cold] #[inline(never)] @@ -1880,7 +1844,8 @@ impl ArrayRef // (Checked in debug assertions). #[inline] fn broadcast_assume(&self, dim: E) -> ArrayView<'_, A, E> - where E: Dimension + where + E: Dimension, { let dim = dim.into_dimension(); debug_assert_eq!(self.shape(), dim.slice()); @@ -1897,8 +1862,7 @@ where D: Dimension, { /// Remove array axis `axis` and return the result. - fn try_remove_axis(self, axis: Axis) -> ArrayBase - { + fn try_remove_axis(self, axis: Axis) -> ArrayBase { let d = self.parts.dim.try_remove_axis(axis); let s = self.parts.strides.try_remove_axis(axis); // safe because new dimension, strides allow access to a subset of old data @@ -1937,8 +1901,7 @@ mod impl_cow; mod impl_arc_array; /// Returns `true` if the pointer is aligned. -pub(crate) fn is_aligned(ptr: *const T) -> bool -{ +pub(crate) fn is_aligned(ptr: *const T) -> bool { (ptr as usize).is_multiple_of(::std::mem::align_of::()) } diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 81c942bc..e0b52375 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -45,8 +45,7 @@ const GEMM_BLAS_CUTOFF: usize = 7; #[allow(non_camel_case_types)] type blas_index = c_int; // blas index type -impl ArrayRef -{ +impl ArrayRef { /// Perform dot product or matrix multiplication of arrays `self` and `rhs`. /// /// `Rhs` may be either a one-dimensional or a two-dimensional array. @@ -66,13 +65,15 @@ impl ArrayRef /// layout allows. 
#[track_caller] pub fn dot(&self, rhs: &Rhs) -> >::Output - where Self: Dot + where + Self: Dot, { Dot::dot(self, rhs) } fn dot_generic(&self, rhs: &ArrayRef) -> A - where A: LinalgScalar + where + A: LinalgScalar, { debug_assert_eq!(self.len(), rhs.len()); assert!(self.len() == rhs.len()); @@ -92,14 +93,16 @@ impl ArrayRef #[cfg(not(feature = "blas"))] fn dot_impl(&self, rhs: &ArrayRef) -> A - where A: LinalgScalar + where + A: LinalgScalar, { self.dot_generic(rhs) } #[cfg(feature = "blas")] fn dot_impl(&self, rhs: &ArrayRef) -> A - where A: LinalgScalar + where + A: LinalgScalar, { // Use only if the vector is large enough to be worth it if self.len() >= DOT_BLAS_CUTOFF { @@ -111,15 +114,8 @@ impl ArrayRef unsafe { let (lhs_ptr, n, incx) = blas_1d_params(self._ptr().as_ptr(), self.len(), self.strides()[0]); - let (rhs_ptr, _, incy) = - blas_1d_params(rhs._ptr().as_ptr(), rhs.len(), rhs.strides()[0]); - let ret = blas_sys::$func( - n, - lhs_ptr as *const $ty, - incx, - rhs_ptr as *const $ty, - incy, - ); + let (rhs_ptr, _, incy) = blas_1d_params(rhs._ptr().as_ptr(), rhs.len(), rhs.strides()[0]); + let ret = blas_sys::$func(n, lhs_ptr as *const $ty, incx, rhs_ptr as *const $ty, incy); return cast_as::<$ty, A>(&ret); } } @@ -139,8 +135,7 @@ impl ArrayRef /// which agrees with our pointer for non-negative strides, but /// is at the opposite end for negative strides. #[cfg(feature = "blas")] -unsafe fn blas_1d_params(ptr: *const A, len: usize, stride: isize) -> (*const A, blas_index, blas_index) -{ +unsafe fn blas_1d_params(ptr: *const A, len: usize, stride: isize) -> (*const A, blas_index, blas_index) { // [x x x x] // ^--ptr // stride = -1 @@ -157,8 +152,7 @@ unsafe fn blas_1d_params(ptr: *const A, len: usize, stride: isize) -> (*const /// /// For two-dimensional arrays, the dot method computes the matrix /// multiplication. -pub trait Dot -{ +pub trait Dot { /// The result of the operation. /// /// For two-dimensional arrays: a rectangular array. @@ -183,8 +177,7 @@ macro_rules! impl_dots { { type Output = as Dot>>::Output; - fn dot(&self, rhs: &ArrayBase) -> Self::Output - { + fn dot(&self, rhs: &ArrayBase) -> Self::Output { Dot::dot(&**self, &**rhs) } } @@ -196,8 +189,7 @@ macro_rules! impl_dots { { type Output = as Dot>>::Output; - fn dot(&self, rhs: &ArrayRef) -> Self::Output - { + fn dot(&self, rhs: &ArrayRef) -> Self::Output { (**self).dot(rhs) } } @@ -209,8 +201,7 @@ macro_rules! impl_dots { { type Output = as Dot>>::Output; - fn dot(&self, rhs: &ArrayBase) -> Self::Output - { + fn dot(&self, rhs: &ArrayBase) -> Self::Output { self.dot(&**rhs) } } @@ -223,7 +214,8 @@ impl_dots!(Ix2, Ix1); impl_dots!(Ix2, Ix2); impl Dot> for ArrayRef -where A: LinalgScalar +where + A: LinalgScalar, { type Output = A; @@ -236,14 +228,14 @@ where A: LinalgScalar /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory /// layout allows. #[track_caller] - fn dot(&self, rhs: &ArrayRef) -> A - { + fn dot(&self, rhs: &ArrayRef) -> A { self.dot_impl(rhs) } } impl Dot> for ArrayRef -where A: LinalgScalar +where + A: LinalgScalar, { type Output = Array; @@ -257,14 +249,12 @@ where A: LinalgScalar /// /// **Panics** if shapes are incompatible. #[track_caller] - fn dot(&self, rhs: &ArrayRef) -> Array - { + fn dot(&self, rhs: &ArrayRef) -> Array { (*rhs.t()).dot(self) } } -impl ArrayRef -{ +impl ArrayRef { /// Perform matrix multiplication of rectangular arrays `self` and `rhs`. /// /// `Rhs` may be either a one-dimensional or a two-dimensional array. 
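The Dot impls in these linalg hunks cover vector-vector (scalar result), matrix-vector, vector-matrix and matrix-matrix products, with mismatched shapes routed through dot_shape_error. A quick sketch of the public `dot`, assuming the usual ndarray semantics:

use ndarray::array;

fn main() {
    let a = array![[1., 2.], [3., 4.]];
    let v = array![1., 1.];
    assert_eq!(v.dot(&v), 2.);             // vector . vector -> scalar
    assert_eq!(a.dot(&v), array![3., 7.]); // matrix . vector -> vector
    assert_eq!(a.dot(&a), array![[7., 10.], [15., 22.]]); // matrix . matrix
}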
@@ -296,19 +286,20 @@ impl ArrayRef /// ``` #[track_caller] pub fn dot(&self, rhs: &Rhs) -> >::Output - where Self: Dot + where + Self: Dot, { Dot::dot(self, rhs) } } impl Dot> for ArrayRef -where A: LinalgScalar +where + A: LinalgScalar, { type Output = Array2; - fn dot(&self, b: &ArrayRef) -> Array2 - { + fn dot(&self, b: &ArrayRef) -> Array2 { let a = self.view(); let b = b.view(); let ((m, k), (k2, n)) = (a.dim(), b.dim()); @@ -334,24 +325,21 @@ where A: LinalgScalar /// Assumes that `m` and `n` are ≤ `isize::MAX`. #[cold] #[inline(never)] -fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> ! -{ +fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> ! { match m.checked_mul(n) { Some(len) if len <= isize::MAX as usize => {} _ => panic!("ndarray: shape {} × {} overflows isize", m, n), } - panic!( - "ndarray: inputs {} × {} and {} × {} are not compatible for matrix multiplication", - m, k, k2, n - ); + panic!("ndarray: inputs {} × {} and {} × {} are not compatible for matrix multiplication", m, k, k2, n); } #[cold] #[inline(never)] -fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c2: usize) -> ! -{ - panic!("ndarray: inputs {} × {}, {} × {}, and output {} × {} are not compatible for matrix multiplication", - m, k, k2, n, c1, c2); +fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c2: usize) -> ! { + panic!( + "ndarray: inputs {} × {}, {} × {}, and output {} × {} are not compatible for matrix multiplication", + m, k, k2, n, c1, c2 + ); } /// Perform the matrix multiplication of the rectangular array `self` and @@ -364,13 +352,13 @@ fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c /// /// **Panics** if shapes are incompatible. impl Dot> for ArrayRef -where A: LinalgScalar +where + A: LinalgScalar, { type Output = Array; #[track_caller] - fn dot(&self, rhs: &ArrayRef) -> Array - { + fn dot(&self, rhs: &ArrayRef) -> Array { let ((m, a), n) = (self.dim(), rhs.dim()); if a != n { dot_shape_error(m, a, n, 1); @@ -386,7 +374,8 @@ where A: LinalgScalar } impl ArrayRef -where D: Dimension +where + D: Dimension, { /// Perform the operation `self += alpha * rhs` efficiently, where /// `alpha` is a scalar and `rhs` is another array. 
This operation is @@ -412,7 +401,8 @@ use self::mat_mul_general as mat_mul_impl; #[cfg(feature = "blas")] fn mat_mul_impl(alpha: A, a: &ArrayRef2, b: &ArrayRef2, beta: A, c: &mut ArrayRef2) -where A: LinalgScalar +where + A: LinalgScalar, { let ((m, k), (k2, n)) = (a.dim(), b.dim()); debug_assert_eq!(k, k2); @@ -467,17 +457,17 @@ where A: LinalgScalar cblas_layout, a_trans, b_trans, - m as blas_index, // m, rows of Op(a) - n as blas_index, // n, cols of Op(b) - k as blas_index, // k, cols of Op(a) - gemm_scalar_cast!($ty, alpha), // alpha - a._ptr().as_ptr() as *const _, // a - lda, // lda - b._ptr().as_ptr() as *const _, // b - ldb, // ldb - gemm_scalar_cast!($ty, beta), // beta - c._ptr().as_ptr() as *mut _, // c - ldc, // ldc + m as blas_index, // m, rows of Op(a) + n as blas_index, // n, cols of Op(b) + k as blas_index, // k, cols of Op(a) + gemm_scalar_cast!($ty, alpha), // alpha + a._ptr().as_ptr() as *const _, // a + lda, // lda + b._ptr().as_ptr() as *const _, // b + ldb, // ldb + gemm_scalar_cast!($ty, beta), // beta + c._ptr().as_ptr() as *mut _, // c + ldc, // ldc ); } return; @@ -498,7 +488,8 @@ where A: LinalgScalar /// C ← α A B + β C fn mat_mul_general(alpha: A, lhs: &ArrayRef2, rhs: &ArrayRef2, beta: A, c: &mut ArrayRef2) -where A: LinalgScalar +where + A: LinalgScalar, { let ((m, k), (_, n)) = (lhs.dim(), rhs.dim()); @@ -631,7 +622,8 @@ where A: LinalgScalar /// `f32, f64` for all memory layouts. #[track_caller] pub fn general_mat_mul(alpha: A, a: &ArrayRef2, b: &ArrayRef2, beta: A, c: &mut ArrayRef2) -where A: LinalgScalar +where + A: LinalgScalar, { let ((m, k), (k2, n)) = (a.dim(), b.dim()); let (m2, n2) = c.dim(); @@ -655,7 +647,8 @@ where A: LinalgScalar #[track_caller] #[allow(clippy::collapsible_if)] pub fn general_mat_vec_mul(alpha: A, a: &ArrayRef2, x: &ArrayRef1, beta: A, y: &mut ArrayRef1) -where A: LinalgScalar +where + A: LinalgScalar, { unsafe { general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut()) } } @@ -669,9 +662,9 @@ where A: LinalgScalar /// The caller must ensure that the raw view is valid for writing. /// the destination may be uninitialized iff beta is zero. #[allow(clippy::collapsible_else_if)] -unsafe fn general_mat_vec_mul_impl( - alpha: A, a: &ArrayRef2, x: &ArrayRef1, beta: A, y: RawArrayViewMut, -) where A: LinalgScalar +unsafe fn general_mat_vec_mul_impl(alpha: A, a: &ArrayRef2, x: &ArrayRef1, beta: A, y: RawArrayViewMut) +where + A: LinalgScalar, { let ((m, k), k2) = (a.dim(), x.dim()); let m2 = y.dim(); @@ -705,15 +698,15 @@ unsafe fn general_mat_vec_mul_impl( blas_sys::$gemv( cblas_layout, a_trans, - m as blas_index, // m, rows of Op(a) - k as blas_index, // n, cols of Op(a) - cast_as(&alpha), // alpha + m as blas_index, // m, rows of Op(a) + k as blas_index, // n, cols of Op(a) + cast_as(&alpha), // alpha a._ptr().as_ptr() as *const _, // a - a_stride, // lda - x_ptr as *const _, // x + a_stride, // lda + x_ptr as *const _, // x x_stride, - cast_as(&beta), // beta - y_ptr as *mut _, // y + cast_as(&beta), // beta + y_ptr as *mut _, // y y_stride, ); return; @@ -747,7 +740,8 @@ unsafe fn general_mat_vec_mul_impl( /// The kronecker product of a LxN matrix A and a MxR matrix B is a (L*M)x(N*R) /// matrix K formed by the block multiplication A_ij * B. 
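A usage sketch of the shape rule just stated, assuming the usual `ndarray::linalg::kron` re-export:

use ndarray::{array, linalg::kron};

fn main() {
    let a = array![[1.0_f64, 2.0], [3.0, 4.0]]; // 2 x 2  (L x N)
    let b = array![[0.0_f64, 1.0], [1.0, 0.0]]; // 2 x 2  (M x R)
    let k = kron(&a, &b);                       // (L*M) x (N*R) = 4 x 4
    assert_eq!(k.shape(), &[4, 4]);
    // Block (i, j) of K equals a[[i, j]] * B.
    assert_eq!(k[[0, 1]], 1.0); // block (0, 0) is 1 * B
    assert_eq!(k[[2, 1]], 3.0); // block (1, 0) is 3 * B
}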
pub fn kron(a: &ArrayRef2, b: &ArrayRef2) -> Array -where A: LinalgScalar +where + A: LinalgScalar, { let dimar = a.shape()[0]; let dimac = a.shape()[1]; @@ -773,25 +767,26 @@ where A: LinalgScalar #[inline(always)] /// Return `true` if `A` and `B` are the same type -fn same_type() -> bool -{ +fn same_type() -> bool { TypeId::of::() == TypeId::of::() } // Read pointer to type `A` as type `B`. // // **Panics** if `A` and `B` are not the same type -fn cast_as(a: &A) -> B -{ - assert!(same_type::(), "expect type {} and {} to match", - std::any::type_name::(), std::any::type_name::()); +fn cast_as(a: &A) -> B { + assert!( + same_type::(), + "expect type {} and {} to match", + std::any::type_name::(), + std::any::type_name::() + ); unsafe { ::std::ptr::read(a as *const _ as *const B) } } /// Return the complex in the form of an array [re, im] #[inline] -fn complex_array(z: Complex) -> [A; 2] -{ +fn complex_array(z: Complex) -> [A; 2] { [z.re, z.im] } @@ -817,17 +812,14 @@ where #[cfg(feature = "blas")] #[derive(Copy, Clone)] #[cfg_attr(test, derive(PartialEq, Eq, Debug))] -enum BlasOrder -{ +enum BlasOrder { C, F, } #[cfg(feature = "blas")] -impl BlasOrder -{ - fn transpose(self) -> Self - { +impl BlasOrder { + fn transpose(self) -> Self { match self { Self::C => Self::F, Self::F => Self::C, @@ -836,16 +828,14 @@ impl BlasOrder #[inline] /// Axis of leading stride (opposite of contiguous axis) - fn get_blas_lead_axis(self) -> usize - { + fn get_blas_lead_axis(self) -> usize { match self { Self::C => 0, Self::F => 1, } } - fn to_cblas_layout(self) -> CBLAS_LAYOUT - { + fn to_cblas_layout(self) -> CBLAS_LAYOUT { match self { Self::C => CBLAS_LAYOUT::CblasRowMajor, Self::F => CBLAS_LAYOUT::CblasColMajor, @@ -854,8 +844,7 @@ impl BlasOrder /// When using cblas_sgemm (etc) with C matrix using `for_layout`, /// how should this `self` matrix be transposed - fn to_cblas_transpose_for(self, for_layout: CBLAS_LAYOUT) -> CBLAS_TRANSPOSE - { + fn to_cblas_transpose_for(self, for_layout: CBLAS_LAYOUT) -> CBLAS_TRANSPOSE { let effective_order = match for_layout { CBLAS_LAYOUT::CblasRowMajor => self, CBLAS_LAYOUT::CblasColMajor => self.transpose(), @@ -869,8 +858,7 @@ impl BlasOrder } #[cfg(feature = "blas")] -fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool -{ +fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool { let (m, n) = dim.into_pattern(); let s0 = stride[0] as isize; let s1 = stride[1] as isize; @@ -907,8 +895,7 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool /// Get BLAS compatible layout if any (C or F, preferring the former) #[cfg(feature = "blas")] -fn get_blas_compatible_layout(a: &ArrayRef) -> Option -{ +fn get_blas_compatible_layout(a: &ArrayRef) -> Option { if is_blas_2d(a._dim(), a._strides(), BlasOrder::C) { Some(BlasOrder::C) } else if is_blas_2d(a._dim(), a._strides(), BlasOrder::F) { @@ -923,8 +910,7 @@ fn get_blas_compatible_layout(a: &ArrayRef) -> Option /// /// Return leading stride (lda, ldb, ldc) of array #[cfg(feature = "blas")] -fn blas_stride(a: &ArrayRef, order: BlasOrder) -> blas_index -{ +fn blas_stride(a: &ArrayRef, order: BlasOrder) -> blas_index { let axis = order.get_blas_lead_axis(); let other_axis = 1 - axis; let len_this = a.shape()[axis]; @@ -986,12 +972,12 @@ where /// - The array shapes are incompatible for the operation /// - For vector dot product: the vectors have different lengths impl Dot> for ArrayRef -where A: LinalgScalar +where + A: LinalgScalar, { type Output = Array; - fn dot(&self, rhs: &ArrayRef) -> 
Self::Output - { + fn dot(&self, rhs: &ArrayRef) -> Self::Output { match (self.ndim(), rhs.ndim()) { (1, 1) => { let a = self.view().into_dimensionality::().unwrap(); @@ -1027,37 +1013,32 @@ where A: LinalgScalar #[cfg(test)] #[cfg(feature = "blas")] -mod blas_tests -{ +mod blas_tests { use super::*; #[test] - fn blas_row_major_2d_normal_matrix() - { + fn blas_row_major_2d_normal_matrix() { let m: Array2 = Array2::zeros((3, 5)); assert!(blas_row_major_2d::(&m)); assert!(!blas_column_major_2d::(&m)); } #[test] - fn blas_row_major_2d_row_matrix() - { + fn blas_row_major_2d_row_matrix() { let m: Array2 = Array2::zeros((1, 5)); assert!(blas_row_major_2d::(&m)); assert!(blas_column_major_2d::(&m)); } #[test] - fn blas_row_major_2d_column_matrix() - { + fn blas_row_major_2d_column_matrix() { let m: Array2 = Array2::zeros((5, 1)); assert!(blas_row_major_2d::(&m)); assert!(blas_column_major_2d::(&m)); } #[test] - fn blas_row_major_2d_transposed_row_matrix() - { + fn blas_row_major_2d_transposed_row_matrix() { let m: Array2 = Array2::zeros((1, 5)); let m_t = m.t(); assert!(blas_row_major_2d::(&m_t)); @@ -1065,8 +1046,7 @@ mod blas_tests } #[test] - fn blas_row_major_2d_transposed_column_matrix() - { + fn blas_row_major_2d_transposed_column_matrix() { let m: Array2 = Array2::zeros((5, 1)); let m_t = m.t(); assert!(blas_row_major_2d::(&m_t)); @@ -1074,16 +1054,14 @@ mod blas_tests } #[test] - fn blas_column_major_2d_normal_matrix() - { + fn blas_column_major_2d_normal_matrix() { let m: Array2 = Array2::zeros((3, 5).f()); assert!(!blas_row_major_2d::(&m)); assert!(blas_column_major_2d::(&m)); } #[test] - fn blas_row_major_2d_skip_rows_ok() - { + fn blas_row_major_2d_skip_rows_ok() { let m: Array2 = Array2::zeros((5, 5)); let mv = m.slice(s![..;2, ..]); assert!(blas_row_major_2d::(&mv)); @@ -1091,8 +1069,7 @@ mod blas_tests } #[test] - fn blas_row_major_2d_skip_columns_fail() - { + fn blas_row_major_2d_skip_columns_fail() { let m: Array2 = Array2::zeros((5, 5)); let mv = m.slice(s![.., ..;2]); assert!(!blas_row_major_2d::(&mv)); @@ -1100,8 +1077,7 @@ mod blas_tests } #[test] - fn blas_col_major_2d_skip_columns_ok() - { + fn blas_col_major_2d_skip_columns_ok() { let m: Array2 = Array2::zeros((5, 5).f()); let mv = m.slice(s![.., ..;2]); assert!(blas_column_major_2d::(&mv)); @@ -1109,8 +1085,7 @@ mod blas_tests } #[test] - fn blas_col_major_2d_skip_rows_fail() - { + fn blas_col_major_2d_skip_rows_fail() { let m: Array2 = Array2::zeros((5, 5).f()); let mv = m.slice(s![..;2, ..]); assert!(!blas_column_major_2d::(&mv)); @@ -1118,8 +1093,7 @@ mod blas_tests } #[test] - fn blas_too_short_stride() - { + fn blas_too_short_stride() { // leading stride must be longer than the other dimension // Example, in a 5 x 5 matrix, the leading stride must be >= 5 for BLAS. diff --git a/src/linalg_traits.rs b/src/linalg_traits.rs index e7723583..a951d355 100644 --- a/src/linalg_traits.rs +++ b/src/linalg_traits.rs @@ -28,8 +28,10 @@ pub trait LinalgScalar: { } -impl LinalgScalar for T where T: 'static + Copy + Zero + One + Add + Sub + Mul + Div -{} +impl LinalgScalar for T where + T: 'static + Copy + Zero + One + Add + Sub + Mul + Div +{ +} /// Floating-point element types `f32` and `f64`. /// diff --git a/src/linspace.rs b/src/linspace.rs index ff52bf0c..fb79b777 100644 --- a/src/linspace.rs +++ b/src/linspace.rs @@ -14,8 +14,7 @@ use num_traits::Float; /// An iterator of a sequence of evenly spaced floats. /// /// Iterator element type is `F`. 
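In practice this iterator is usually reached through `Array::linspace`; a brief sketch (both endpoints are included):

use ndarray::{array, Array1};

fn main() {
    // Five evenly spaced values from 0.0 to 1.0 inclusive, step = 0.25.
    let xs = Array1::linspace(0.0_f64, 1.0, 5);
    assert_eq!(xs, array![0.0, 0.25, 0.5, 0.75, 1.0]);
}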
-pub struct Linspace -{ +pub struct Linspace { start: F, step: F, index: usize, @@ -23,13 +22,13 @@ pub struct Linspace } impl Iterator for Linspace -where F: Float +where + F: Float, { type Item = F; #[inline] - fn next(&mut self) -> Option - { + fn next(&mut self) -> Option { if self.index >= self.len { None } else { @@ -41,19 +40,18 @@ where F: Float } #[inline] - fn size_hint(&self) -> (usize, Option) - { + fn size_hint(&self) -> (usize, Option) { let n = self.len - self.index; (n, Some(n)) } } impl DoubleEndedIterator for Linspace -where F: Float +where + F: Float, { #[inline] - fn next_back(&mut self) -> Option - { + fn next_back(&mut self) -> Option { if self.index >= self.len { None } else { @@ -111,7 +109,8 @@ where /// **Panics** if converting `((b - a) / step).ceil()` to type `F` fails. #[inline] pub fn range(a: F, b: F, step: F) -> Linspace -where F: Float +where + F: Float, { let len = b - a; let steps = F::ceil(len / step); diff --git a/src/logspace.rs b/src/logspace.rs index dd1b7ae1..fa019244 100644 --- a/src/logspace.rs +++ b/src/logspace.rs @@ -14,8 +14,7 @@ use num_traits::Float; /// An iterator of a sequence of logarithmically spaced number. /// /// Iterator element type is `F`. -pub struct Logspace -{ +pub struct Logspace { sign: F, base: F, start: F, @@ -25,13 +24,13 @@ pub struct Logspace } impl Iterator for Logspace -where F: Float +where + F: Float, { type Item = F; #[inline] - fn next(&mut self) -> Option - { + fn next(&mut self) -> Option { if self.index >= self.len { None } else { @@ -44,19 +43,18 @@ where F: Float } #[inline] - fn size_hint(&self) -> (usize, Option) - { + fn size_hint(&self) -> (usize, Option) { let n = self.len - self.index; (n, Some(n)) } } impl DoubleEndedIterator for Logspace -where F: Float +where + F: Float, { #[inline] - fn next_back(&mut self) -> Option - { + fn next_back(&mut self) -> Option { if self.index >= self.len { None } else { @@ -109,14 +107,12 @@ where } #[cfg(test)] -mod tests -{ +mod tests { use super::logspace; #[test] #[cfg(feature = "approx")] - fn valid() - { + fn valid() { use crate::{arr1, Array1}; use approx::assert_abs_diff_eq; @@ -134,8 +130,7 @@ mod tests } #[test] - fn iter_forward() - { + fn iter_forward() { let mut iter = logspace(10.0f64, 0.0..=3.0, 4); assert!(iter.size_hint() == (4, Some(4))); @@ -150,8 +145,7 @@ mod tests } #[test] - fn iter_backward() - { + fn iter_backward() { let mut iter = logspace(10.0f64, 0.0..=3.0, 4); assert!(iter.size_hint() == (4, Some(4))); diff --git a/src/low_level_util.rs b/src/low_level_util.rs index 5a615a18..e7555488 100644 --- a/src/low_level_util.rs +++ b/src/low_level_util.rs @@ -13,22 +13,18 @@ #[must_use] pub(crate) struct AbortIfPanic(pub(crate) &'static &'static str); -impl AbortIfPanic -{ +impl AbortIfPanic { /// Defuse the AbortIfPanic guard. This *must* be done when finished. #[inline] - pub(crate) fn defuse(self) - { + pub(crate) fn defuse(self) { std::mem::forget(self); } } -impl Drop for AbortIfPanic -{ +impl Drop for AbortIfPanic { // The compiler should be able to remove this, if it can see through that there // is no panic in the code section. - fn drop(&mut self) - { + fn drop(&mut self) { #[cfg(feature = "std")] { eprintln!("ndarray: panic in no-panic section, aborting: {}", self.0); diff --git a/src/macro_utils.rs b/src/macro_utils.rs index 34c700e6..54507056 100644 --- a/src/macro_utils.rs +++ b/src/macro_utils.rs @@ -61,12 +61,11 @@ macro_rules! expand_if { #[cfg(debug_assertions)] macro_rules! 
debug_bounds_check { ($self_:ident, $index:expr) => { - if $index.index_checked(&$self_._dim(), &$self_._strides()).is_none() { - panic!( - "ndarray: index {:?} is out of bounds for array of shape {:?}", - $index, - $self_.shape() - ); + if $index + .index_checked(&$self_._dim(), &$self_._strides()) + .is_none() + { + panic!("ndarray: index {:?} is out of bounds for array of shape {:?}", $index, $self_.shape()); } }; } @@ -79,12 +78,11 @@ macro_rules! debug_bounds_check { #[cfg(debug_assertions)] macro_rules! debug_bounds_check_ref { ($self_:ident, $index:expr) => { - if $index.index_checked(&$self_._dim(), &$self_._strides()).is_none() { - panic!( - "ndarray: index {:?} is out of bounds for array of shape {:?}", - $index, - $self_.shape() - ); + if $index + .index_checked(&$self_._dim(), &$self_._strides()) + .is_none() + { + panic!("ndarray: index {:?} is out of bounds for array of shape {:?}", $index, $self_.shape()); } }; } diff --git a/src/math_cell.rs b/src/math_cell.rs index 629e5575..7971b45d 100644 --- a/src/math_cell.rs +++ b/src/math_cell.rs @@ -13,61 +13,53 @@ use std::ops::{Deref, DerefMut}; #[derive(Default)] pub struct MathCell(Cell); -impl MathCell -{ +impl MathCell { /// Create a new cell with the given value #[inline(always)] - pub const fn new(value: T) -> Self - { + pub const fn new(value: T) -> Self { MathCell(Cell::new(value)) } /// Return the inner value - pub fn into_inner(self) -> T - { + pub fn into_inner(self) -> T { Cell::into_inner(self.0) } /// Swap value with another cell - pub fn swap(&self, other: &Self) - { + pub fn swap(&self, other: &Self) { Cell::swap(&self.0, &other.0) } } -impl Deref for MathCell -{ +impl Deref for MathCell { type Target = Cell; #[inline(always)] - fn deref(&self) -> &Self::Target - { + fn deref(&self) -> &Self::Target { &self.0 } } -impl DerefMut for MathCell -{ +impl DerefMut for MathCell { #[inline(always)] - fn deref_mut(&mut self) -> &mut Self::Target - { + fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } impl Clone for MathCell -where T: Copy +where + T: Copy, { - fn clone(&self) -> Self - { + fn clone(&self) -> Self { MathCell::new(self.get()) } } impl PartialEq for MathCell -where T: Copy + PartialEq +where + T: Copy + PartialEq, { - fn eq(&self, rhs: &Self) -> bool - { + fn eq(&self, rhs: &Self) -> bool { self.get() == rhs.get() } } @@ -75,57 +67,51 @@ where T: Copy + PartialEq impl Eq for MathCell where T: Copy + Eq {} impl PartialOrd for MathCell -where T: Copy + PartialOrd +where + T: Copy + PartialOrd, { - fn partial_cmp(&self, rhs: &Self) -> Option - { + fn partial_cmp(&self, rhs: &Self) -> Option { self.get().partial_cmp(&rhs.get()) } - fn lt(&self, rhs: &Self) -> bool - { + fn lt(&self, rhs: &Self) -> bool { self.get().lt(&rhs.get()) } - fn le(&self, rhs: &Self) -> bool - { + fn le(&self, rhs: &Self) -> bool { self.get().le(&rhs.get()) } - fn gt(&self, rhs: &Self) -> bool - { + fn gt(&self, rhs: &Self) -> bool { self.get().gt(&rhs.get()) } - fn ge(&self, rhs: &Self) -> bool - { + fn ge(&self, rhs: &Self) -> bool { self.get().ge(&rhs.get()) } } impl Ord for MathCell -where T: Copy + Ord +where + T: Copy + Ord, { - fn cmp(&self, rhs: &Self) -> Ordering - { + fn cmp(&self, rhs: &Self) -> Ordering { self.get().cmp(&rhs.get()) } } impl fmt::Debug for MathCell -where T: Copy + fmt::Debug +where + T: Copy + fmt::Debug, { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result - { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.get().fmt(f) } } #[cfg(test)] -mod tests -{ +mod tests { use 
super::MathCell; #[test] - fn test_basic() - { + fn test_basic() { let c = &MathCell::new(0); c.set(1); assert_eq!(c.get(), 1); diff --git a/src/numeric/impl_float_maths.rs b/src/numeric/impl_float_maths.rs index 6d6ebce5..33f3c325 100644 --- a/src/numeric/impl_float_maths.rs +++ b/src/numeric/impl_float_maths.rs @@ -162,8 +162,7 @@ where /// Square (two powers) of each element. #[must_use = "method returns a new array and does not mutate the original value"] - pub fn pow2(&self) -> Array - { + pub fn pow2(&self) -> Array { self.mapv(|v: A| v * v) } } @@ -192,8 +191,7 @@ where /// # See Also /// [ArrayRef::to_complex_im] #[must_use = "method returns a new array and does not mutate the original value"] - pub fn to_complex_re(&self) -> Array, D> - { + pub fn to_complex_re(&self) -> Array, D> { self.mapv(|v| Complex::new(v, A::zero())) } @@ -215,8 +213,7 @@ where /// # See Also /// [ArrayRef::to_complex_re] #[must_use = "method returns a new array and does not mutate the original value"] - pub fn to_complex_im(&self) -> Array, D> - { + pub fn to_complex_im(&self) -> Array, D> { self.mapv(|v| Complex::new(A::zero(), v)) } } @@ -254,8 +251,7 @@ where /// assert!((angles[2] - PI/4.0).abs() < 1e-10); /// ``` #[must_use = "method returns a new array and does not mutate the original value"] - pub fn angle(&self) -> Array - { + pub fn angle(&self) -> Array { self.mapv(|v| v.im.atan2(v.re)) } } @@ -278,16 +274,14 @@ where /// # Panics /// /// Panics if `!(min <= max)`. - pub fn clamp(&self, min: A, max: A) -> Array - { + pub fn clamp(&self, min: A, max: A) -> Array { assert!(min <= max, "min must be less than or equal to max"); self.mapv(|a| num_traits::clamp(a, min.clone(), max.clone())) } } #[cfg(all(test, feature = "std"))] -mod angle_tests -{ +mod angle_tests { use crate::Array; use num_complex::Complex; use std::f64::consts::PI; @@ -311,10 +305,9 @@ mod angle_tests } #[test] - fn test_complex_numbers_radians() - { + fn test_complex_numbers_radians() { let arr = Array::from_vec(vec![ - Complex::new(1.0f64, 0.0), // 0 + Complex::new(1.0f64, 0.0), // 0 Complex::new(0.0, 1.0), // π/2 Complex::new(-1.0, 0.0), // π Complex::new(0.0, -1.0), // -π/2 @@ -332,8 +325,7 @@ mod angle_tests } #[test] - fn test_complex_numbers_degrees() - { + fn test_complex_numbers_degrees() { let arr = Array::from_vec(vec![ Complex::new(1.0f64, 0.0), Complex::new(0.0, 1.0), @@ -349,8 +341,7 @@ mod angle_tests } #[test] - fn test_signed_zeros() - { + fn test_signed_zeros() { let arr = Array::from_vec(vec![ Complex::new(0.0f64, 0.0), Complex::new(-0.0, 0.0), @@ -366,8 +357,7 @@ mod angle_tests } #[test] - fn test_edge_cases() - { + fn test_edge_cases() { let arr = Array::from_vec(vec![ Complex::new(f64::INFINITY, 0.0), Complex::new(0.0, f64::INFINITY), @@ -383,8 +373,7 @@ mod angle_tests } #[test] - fn test_range_validation() - { + fn test_range_validation() { let n = 16; let complex_arr: Vec<_> = (0..n) .map(|i| { diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 90e9b6ec..56c6a75f 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -18,7 +18,8 @@ use crate::Slice; /// # Numerical Methods for Arrays impl ArrayRef -where D: Dimension +where + D: Dimension, { /// Return the sum of all elements in the array. 
/// @@ -30,7 +31,8 @@ where D: Dimension /// assert_eq!(a.sum(), 10.); /// ``` pub fn sum(&self) -> A - where A: Clone + Add + num_traits::Zero + where + A: Clone + Add + num_traits::Zero, { if let Some(slc) = self.as_slice_memory_order() { return numeric_util::unrolled_fold(slc, A::zero, A::add); @@ -60,7 +62,8 @@ where D: Dimension /// /// [arithmetic mean]: https://en.wikipedia.org/wiki/Arithmetic_mean pub fn mean(&self) -> Option - where A: Clone + FromPrimitive + Add + Div + Zero + where + A: Clone + FromPrimitive + Add + Div + Zero, { let n_elements = self.len(); if n_elements == 0 { @@ -81,7 +84,8 @@ where D: Dimension /// assert_eq!(a.product(), 24.); /// ``` pub fn product(&self) -> A - where A: Clone + Mul + num_traits::One + where + A: Clone + Mul + num_traits::One, { if let Some(slc) = self.as_slice_memory_order() { return numeric_util::unrolled_fold(slc, A::one, A::mul); @@ -178,7 +182,8 @@ where D: Dimension #[track_caller] #[cfg(feature = "std")] pub fn var(&self, ddof: A) -> A - where A: Float + FromPrimitive + where + A: Float + FromPrimitive, { let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail."); let n = A::from_usize(self.len()).expect("Converting length to `A` must not fail."); @@ -243,7 +248,8 @@ where D: Dimension #[track_caller] #[cfg(feature = "std")] pub fn std(&self, ddof: A) -> A - where A: Float + FromPrimitive + where + A: Float + FromPrimitive, { self.var(ddof).sqrt() } @@ -493,7 +499,8 @@ where D: Dimension /// array![1.0, 2.0, 3.0].diff(10, Axis(0)); /// ``` pub fn diff(&self, n: usize, axis: Axis) -> Array - where A: Sub + Zero + Clone + where + A: Sub + Zero + Clone, { if n == 0 { return self.to_owned(); diff --git a/src/numeric_util.rs b/src/numeric_util.rs index 9d5ce66c..1879abdb 100644 --- a/src/numeric_util.rs +++ b/src/numeric_util.rs @@ -54,7 +54,8 @@ where /// /// `xs` and `ys` must be the same length pub fn unrolled_dot(xs: &[A], ys: &[A]) -> A -where A: LinalgScalar +where + A: LinalgScalar, { debug_assert_eq!(xs.len(), ys.len()); // eightfold unrolled so that floating point can be vectorized @@ -96,7 +97,8 @@ where A: LinalgScalar /// /// `xs` and `ys` must be the same length pub fn unrolled_eq(xs: &[A], ys: &[B]) -> bool -where A: PartialEq +where + A: PartialEq, { debug_assert_eq!(xs.len(), ys.len()); // eightfold unrolled for performance (this is not done by llvm automatically) diff --git a/src/order.rs b/src/order.rs index a52a32e2..4ab8c84e 100644 --- a/src/order.rs +++ b/src/order.rs @@ -30,16 +30,14 @@ /// or "Fortran" order. 
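A short sketch of how the `Order` flag defined below behaves (the `C`/`F` aliases and `transpose`):

use ndarray::Order;

fn main() {
    // `C` and `F` are aliases for RowMajor and ColumnMajor.
    assert!(Order::C.is_row_major());
    assert!(Order::F.is_column_major());
    assert_eq!(Order::RowMajor.transpose(), Order::ColumnMajor);
}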
#[derive(Copy, Clone, Debug, PartialEq, Eq)] #[non_exhaustive] -pub enum Order -{ +pub enum Order { /// Row major or "C" order RowMajor, /// Column major or "F" order ColumnMajor, } -impl Order -{ +impl Order { /// "C" is an alias for row major ordering pub const C: Order = Order::RowMajor; @@ -48,8 +46,7 @@ impl Order /// Return true if input is Order::RowMajor, false otherwise #[inline] - pub fn is_row_major(self) -> bool - { + pub fn is_row_major(self) -> bool { match self { Order::RowMajor => true, Order::ColumnMajor => false, @@ -58,15 +55,13 @@ impl Order /// Return true if input is Order::ColumnMajor, false otherwise #[inline] - pub fn is_column_major(self) -> bool - { + pub fn is_column_major(self) -> bool { !self.is_row_major() } /// Return Order::RowMajor if the input is true, Order::ColumnMajor otherwise #[inline] - pub fn row_major(row_major: bool) -> Order - { + pub fn row_major(row_major: bool) -> Order { if row_major { Order::RowMajor } else { @@ -76,15 +71,13 @@ impl Order /// Return Order::ColumnMajor if the input is true, Order::RowMajor otherwise #[inline] - pub fn column_major(column_major: bool) -> Order - { + pub fn column_major(column_major: bool) -> Order { Self::row_major(!column_major) } /// Return the transpose: row major becomes column major and vice versa. #[inline] - pub fn transpose(self) -> Order - { + pub fn transpose(self) -> Order { match self { Order::RowMajor => Order::ColumnMajor, Order::ColumnMajor => Order::RowMajor, diff --git a/src/parallel/impl_par_methods.rs b/src/parallel/impl_par_methods.rs index f825daa5..8f99adbb 100644 --- a/src/parallel/impl_par_methods.rs +++ b/src/parallel/impl_par_methods.rs @@ -19,7 +19,8 @@ where /// /// Elements are visited in arbitrary order. pub fn par_map_inplace(&mut self, f: F) - where F: Fn(&mut A) + Sync + Send + where + F: Fn(&mut A) + Sync + Send, { self.view_mut().into_par_iter().for_each(f) } diff --git a/src/parallel/into_impls.rs b/src/parallel/into_impls.rs index 75bded7d..c1a5388f 100644 --- a/src/parallel/into_impls.rs +++ b/src/parallel/into_impls.rs @@ -11,8 +11,7 @@ where { type Item = &'a A; type Iter = Parallel>; - fn into_par_iter(self) -> Self::Iter - { + fn into_par_iter(self) -> Self::Iter { self.view().into_par_iter() } } @@ -26,8 +25,7 @@ where { type Item = &'a A; type Iter = Parallel>; - fn into_par_iter(self) -> Self::Iter - { + fn into_par_iter(self) -> Self::Iter { self.view().into_par_iter() } } @@ -40,8 +38,7 @@ where { type Item = &'a mut A; type Iter = Parallel>; - fn into_par_iter(self) -> Self::Iter - { + fn into_par_iter(self) -> Self::Iter { self.view_mut().into_par_iter() } } @@ -55,8 +52,7 @@ where { type Item = &'a mut A; type Iter = Parallel>; - fn into_par_iter(self) -> Self::Iter - { + fn into_par_iter(self) -> Self::Iter { self.view_mut().into_par_iter() } } diff --git a/src/parallel/mod.rs b/src/parallel/mod.rs index 3ac0d4b0..15537cf7 100644 --- a/src/parallel/mod.rs +++ b/src/parallel/mod.rs @@ -123,14 +123,10 @@ use crate::iter::{AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut}; use crate::{ArcArray, Array, ArrayBase, ArrayView, ArrayViewMut, Zip}; /// Into- traits for creating parallelized iterators and/or using [`par_azip!`] -pub mod prelude -{ +pub mod prelude { #[doc(no_inline)] pub use rayon::prelude::{ - IndexedParallelIterator, - IntoParallelIterator, - IntoParallelRefIterator, - IntoParallelRefMutIterator, + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, }; diff --git 
a/src/parallel/par.rs b/src/parallel/par.rs index efff8e6b..c3bcc7a8 100644 --- a/src/parallel/par.rs +++ b/src/parallel/par.rs @@ -19,8 +19,7 @@ use crate::{ArrayView, ArrayViewMut}; /// Parallel iterator wrapper. #[derive(Copy, Clone, Debug)] -pub struct Parallel -{ +pub struct Parallel { iter: I, min_len: usize, } @@ -313,15 +312,15 @@ zip_impl! { } impl Parallel> -where D: Dimension +where + D: Dimension, { /// Sets the minimum number of elements desired to process in each job. This will not be /// split any smaller than this length, but of course a producer could already be smaller /// to begin with. /// /// ***Panics*** if `min_len` is zero. - pub fn with_min_len(self, min_len: usize) -> Self - { + pub fn with_min_len(self, min_len: usize) -> Self { assert_ne!(min_len, 0, "Minimum number of elements must at least be one to avoid splitting off empty tasks."); Self { min_len, ..self } @@ -330,36 +329,36 @@ where D: Dimension /// A parallel iterator (unindexed) that produces the splits of the array /// or producer `P`. -pub(crate) struct ParallelSplits

-{ +pub(crate) struct ParallelSplits<P> { pub(crate) iter: P, pub(crate) max_splits: usize, } impl<P> ParallelIterator for ParallelSplits<P> -where P: SplitPreference + Send +where + P: SplitPreference + Send, { type Item = P; fn drive_unindexed<C>(self, consumer: C) -> C::Result - where C: UnindexedConsumer<Self::Item> + where + C: UnindexedConsumer<Self::Item>, { bridge_unindexed(self, consumer) } - fn opt_len(&self) -> Option<usize> - { + fn opt_len(&self) -> Option<usize> { None } } impl<P> UnindexedProducer for ParallelSplits<P> -where P: SplitPreference + Send +where + P: SplitPreference + Send, { type Item = P; - fn split(self) -> (Self, Option<Self>) - { + fn split(self) -> (Self, Option<Self>) { if self.max_splits == 0 || !self.iter.can_split() { return (self, None); } @@ -377,7 +376,8 @@ where P: SplitPreference + Send } fn fold_with<Fold>(self, folder: Fold) -> Fold - where Fold: Folder<Self::Item> + where + Fold: Folder<Self::Item>, { folder.consume(self.iter) }
diff --git a/src/parallel/send_producer.rs b/src/parallel/send_producer.rs index ecfb77af..f15cc89a 100644 --- a/src/parallel/send_producer.rs +++ b/src/parallel/send_producer.rs @@ -4,41 +4,35 @@ use std::ops::{Deref, DerefMut}; /// An NdProducer that is unconditionally `Send`. #[repr(transparent)] -pub(crate) struct SendProducer<T> -{ +pub(crate) struct SendProducer<T> { inner: T, } -impl<T> SendProducer<T> -{ +impl<T> SendProducer<T> { /// Create an unconditionally `Send` ndproducer from the producer - pub(crate) unsafe fn new(producer: T) -> Self - { + pub(crate) unsafe fn new(producer: T) -> Self { Self { inner: producer } } } unsafe impl<P> Send for SendProducer<P> {} -impl<P> Deref for SendProducer<P> -{ +impl<P> Deref for SendProducer<P> { type Target = P; - fn deref(&self) -> &P - { + fn deref(&self) -> &P { &self.inner } } -impl<P> DerefMut for SendProducer<P> -{ - fn deref_mut(&mut self) -> &mut P - { +impl<P> DerefMut for SendProducer<P> { + fn deref_mut(&mut self) -> &mut P { &mut self.inner } } impl<P> NdProducer for SendProducer<P>

-where P: NdProducer +where + P: NdProducer, { type Item = P::Item; type Dim = P::Dim; @@ -48,55 +42,46 @@ where P: NdProducer private_impl! {} #[inline(always)] - fn raw_dim(&self) -> Self::Dim - { + fn raw_dim(&self) -> Self::Dim { self.inner.raw_dim() } #[inline(always)] - fn equal_dim(&self, dim: &Self::Dim) -> bool - { + fn equal_dim(&self, dim: &Self::Dim) -> bool { self.inner.equal_dim(dim) } #[inline(always)] - fn as_ptr(&self) -> Self::Ptr - { + fn as_ptr(&self) -> Self::Ptr { self.inner.as_ptr() } #[inline(always)] - fn layout(&self) -> Layout - { + fn layout(&self) -> Layout { self.inner.layout() } #[inline(always)] - unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item - { + unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item { self.inner.as_ref(ptr) } #[inline(always)] - unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr - { + unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr { self.inner.uget_ptr(i) } #[inline(always)] - fn stride_of(&self, axis: Axis) -> Self::Stride - { + fn stride_of(&self, axis: Axis) -> Self::Stride { self.inner.stride_of(axis) } #[inline(always)] - fn contiguous_stride(&self) -> Self::Stride - { + fn contiguous_stride(&self) -> Self::Stride { self.inner.contiguous_stride() } - fn split_at(self, axis: Axis, index: usize) -> (Self, Self) - { + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) { let (a, b) = self.inner.split_at(axis, index); (Self { inner: a }, Self { inner: b }) } diff --git a/src/partial.rs b/src/partial.rs index dbaa0e10..51f6eb8a 100644 --- a/src/partial.rs +++ b/src/partial.rs @@ -12,16 +12,14 @@ use std::ptr; /// it is the owner of the elements, but not the allocation, /// and will drop the elements on drop. #[must_use] -pub(crate) struct Partial -{ +pub(crate) struct Partial { /// Data pointer ptr: *mut T, /// Current length pub(crate) len: usize, } -impl Partial -{ +impl Partial { /// Create an empty partial for this data pointer /// /// ## Safety @@ -31,14 +29,12 @@ impl Partial /// the `len` elements following it valid. /// /// The Partial has an accessible length field which must only be modified in trusted code. - pub(crate) unsafe fn new(ptr: *mut T) -> Self - { + pub(crate) unsafe fn new(ptr: *mut T) -> Self { Self { ptr, len: 0 } } #[cfg(feature = "rayon")] - pub(crate) fn stub() -> Self - { + pub(crate) fn stub() -> Self { Self { len: 0, ptr: ptr::null_mut(), @@ -46,14 +42,12 @@ impl Partial } #[cfg(feature = "rayon")] - pub(crate) fn is_stub(&self) -> bool - { + pub(crate) fn is_stub(&self) -> bool { self.ptr.is_null() } /// Release Partial's ownership of the written elements, and return the current length - pub(crate) fn release_ownership(mut self) -> usize - { + pub(crate) fn release_ownership(mut self) -> usize { let ret = self.len; self.len = 0; ret @@ -62,8 +56,7 @@ impl Partial #[cfg(feature = "rayon")] /// Merge if they are in order (left to right) and contiguous. /// Skips merge if T does not need drop. 
- pub(crate) fn try_merge(mut left: Self, right: Self) -> Self - { + pub(crate) fn try_merge(mut left: Self, right: Self) -> Self { if !std::mem::needs_drop::() { return left; } @@ -84,10 +77,8 @@ impl Partial unsafe impl Send for Partial where T: Send {} -impl Drop for Partial -{ - fn drop(&mut self) - { +impl Drop for Partial { + fn drop(&mut self) { if !self.ptr.is_null() { unsafe { ptr::drop_in_place(core::ptr::slice_from_raw_parts_mut(self.ptr, self.len)); diff --git a/src/prelude.rs b/src/prelude.rs index 072eb482..fcaa1839 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -19,16 +19,7 @@ #[doc(no_inline)] pub use crate::{ - ArcArray, - Array, - ArrayBase, - ArrayRef, - ArrayView, - ArrayViewMut, - CowArray, - LayoutRef, - RawArrayView, - RawArrayViewMut, + ArcArray, Array, ArrayBase, ArrayRef, ArrayView, ArrayViewMut, CowArray, LayoutRef, RawArrayView, RawArrayViewMut, RawRef, }; @@ -46,13 +37,7 @@ pub use crate::{ArrayView0, ArrayView1, ArrayView2, ArrayView3, ArrayView4, Arra #[doc(no_inline)] pub use crate::{ - ArrayViewMut0, - ArrayViewMut1, - ArrayViewMut2, - ArrayViewMut3, - ArrayViewMut4, - ArrayViewMut5, - ArrayViewMut6, + ArrayViewMut0, ArrayViewMut1, ArrayViewMut2, ArrayViewMut3, ArrayViewMut4, ArrayViewMut5, ArrayViewMut6, ArrayViewMutD, }; diff --git a/src/shape_builder.rs b/src/shape_builder.rs index 50c00b04..c42ea910 100644 --- a/src/shape_builder.rs +++ b/src/shape_builder.rs @@ -6,8 +6,7 @@ use crate::Dimension; /// /// Either c- or f- memory ordered (*c* a.k.a *row major* is the default). #[derive(Copy, Clone, Debug)] -pub struct Shape -{ +pub struct Shape { /// Shape (axis lengths) pub(crate) dim: D, /// Strides can only be C or F here @@ -17,41 +16,36 @@ pub struct Shape #[derive(Copy, Clone, Debug)] pub(crate) enum Contiguous {} -impl Shape -{ - pub(crate) fn is_c(&self) -> bool - { +impl Shape { + pub(crate) fn is_c(&self) -> bool { matches!(self.strides, Strides::C) } } /// An array shape of n dimensions in c-order, f-order or custom strides. #[derive(Copy, Clone, Debug)] -pub struct StrideShape -{ +pub struct StrideShape { pub(crate) dim: D, pub(crate) strides: Strides, } impl StrideShape -where D: Dimension +where + D: Dimension, { /// Return a reference to the dimension - pub fn raw_dim(&self) -> &D - { + pub fn raw_dim(&self) -> &D { &self.dim } /// Return the size of the shape in number of elements - pub fn size(&self) -> usize - { + pub fn size(&self) -> usize { self.dim.size() } } /// Stride description #[derive(Copy, Clone, Debug)] -pub(crate) enum Strides -{ +pub(crate) enum Strides { /// Row-major ("C"-order) C, /// Column-major ("F"-order) @@ -60,11 +54,11 @@ pub(crate) enum Strides Custom(D), } -impl Strides -{ +impl Strides { /// Return strides for `dim` (computed from dimension if c/f, else return the custom stride) pub(crate) fn strides_for_dim(self, dim: &D) -> D - where D: Dimension + where + D: Dimension, { match self { Strides::C => dim.default_strides(), @@ -83,8 +77,7 @@ impl Strides } #[inline] - pub(crate) fn is_custom(&self) -> bool - { + pub(crate) fn is_custom(&self) -> bool { matches!(*self, Strides::Custom(_)) } } @@ -94,8 +87,7 @@ impl Strides /// /// This trait is used together with array constructor methods like /// `Array::from_shape_vec`. -pub trait ShapeBuilder -{ +pub trait ShapeBuilder { /// The type that captures the built shape's dimensionality. 
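A small sketch of how `ShapeBuilder` is typically used when constructing arrays (row-major "C" order is the default; `.f()` requests column-major):

use ndarray::{Array2, ShapeBuilder};

fn main() {
    let c: Array2<f64> = Array2::zeros((3, 4));     // C order (default)
    let f: Array2<f64> = Array2::zeros((3, 4).f()); // F order
    assert_eq!(c.strides(), &[4, 1]);
    assert_eq!(f.strides(), &[1, 3]);
}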
type Dim: Dimension; @@ -116,11 +108,11 @@ pub trait ShapeBuilder } impl From for Shape -where D: Dimension +where + D: Dimension, { /// Create a `Shape` from `dimension`, using the default memory layout. - fn from(dimension: D) -> Shape - { + fn from(dimension: D) -> Shape { dimension.into_shape_with_order() } } @@ -130,8 +122,7 @@ where D: Dimension, T: ShapeBuilder, { - fn from(value: T) -> Self - { + fn from(value: T) -> Self { let shape = value.into_shape_with_order(); let st = if shape.is_c() { Strides::C } else { Strides::F }; StrideShape { @@ -142,55 +133,49 @@ where } impl ShapeBuilder for T -where T: IntoDimension +where + T: IntoDimension, { type Dim = T::Dim; type Strides = T; - fn into_shape_with_order(self) -> Shape - { + fn into_shape_with_order(self) -> Shape { Shape { dim: self.into_dimension(), strides: Strides::C, } } - fn f(self) -> Shape - { + fn f(self) -> Shape { self.set_f(true) } - fn set_f(self, is_f: bool) -> Shape - { + fn set_f(self, is_f: bool) -> Shape { self.into_shape_with_order().set_f(is_f) } - fn strides(self, st: T) -> StrideShape - { + fn strides(self, st: T) -> StrideShape { self.into_shape_with_order().strides(st.into_dimension()) } } impl ShapeBuilder for Shape -where D: Dimension +where + D: Dimension, { type Dim = D; type Strides = D; - fn into_shape_with_order(self) -> Shape - { + fn into_shape_with_order(self) -> Shape { self } - fn f(self) -> Self - { + fn f(self) -> Self { self.set_f(true) } - fn set_f(mut self, is_f: bool) -> Self - { + fn set_f(mut self, is_f: bool) -> Self { self.strides = if !is_f { Strides::C } else { Strides::F }; self } - fn strides(self, st: D) -> StrideShape - { + fn strides(self, st: D) -> StrideShape { StrideShape { dim: self.dim, strides: Strides::Custom(st), @@ -199,16 +184,15 @@ where D: Dimension } impl Shape -where D: Dimension +where + D: Dimension, { /// Return a reference to the dimension - pub fn raw_dim(&self) -> &D - { + pub fn raw_dim(&self) -> &D { &self.dim } /// Return the size of the shape in number of elements - pub fn size(&self) -> usize - { + pub fn size(&self) -> usize { self.dim.size() } } @@ -221,8 +205,7 @@ where D: Dimension /// (optionally) an ordering argument. /// /// See for example [`.to_shape()`](crate::ArrayRef::to_shape). -pub trait ShapeArg -{ +pub trait ShapeArg { /// The type that captures the shape's dimensionality. type Dim: Dimension; @@ -231,23 +214,23 @@ pub trait ShapeArg } impl ShapeArg for T -where T: IntoDimension +where + T: IntoDimension, { type Dim = T::Dim; - fn into_shape_and_order(self) -> (Self::Dim, Option) - { + fn into_shape_and_order(self) -> (Self::Dim, Option) { (self.into_dimension(), None) } } impl ShapeArg for (T, Order) -where T: IntoDimension +where + T: IntoDimension, { type Dim = T::Dim; - fn into_shape_and_order(self) -> (Self::Dim, Option) - { + fn into_shape_and_order(self) -> (Self::Dim, Option) { (self.0.into_dimension(), Some(self.1)) } } diff --git a/src/simd.rs b/src/simd.rs index f2878e8d..eb75bfa0 100644 --- a/src/simd.rs +++ b/src/simd.rs @@ -44,14 +44,20 @@ impl Tier { fn detect_tier() -> Tier { #[cfg(all(feature = "std", target_arch = "x86_64"))] { - if is_x86_feature_detected!("avx512f") { return Tier::Avx512; } - if is_x86_feature_detected!("avx2") { return Tier::Avx2; } + if is_x86_feature_detected!("avx512f") { + return Tier::Avx512; + } + if is_x86_feature_detected!("avx2") { + return Tier::Avx2; + } } #[cfg(all(feature = "std", target_arch = "aarch64"))] { // NEON is mandatory on aarch64 — always available. 
// dotprod (ARMv8.2+) distinguishes Pi 5 from Pi 3/4. - if std::arch::is_aarch64_feature_detected!("dotprod") { return Tier::NeonDotProd; } + if std::arch::is_aarch64_feature_detected!("dotprod") { + return Tier::NeonDotProd; + } return Tier::Neon; } #[cfg(all(not(feature = "std"), target_arch = "aarch64"))] @@ -83,7 +89,9 @@ static TIER: LazyLock = LazyLock::new(detect_tier); #[cfg(feature = "std")] #[inline(always)] #[allow(dead_code)] -fn tier() -> Tier { *TIER } +fn tier() -> Tier { + *TIER +} // ── no_std path: portable-atomic + critical-section polyfill ──────── #[cfg(all(not(feature = "std"), feature = "portable-atomic-critical-section"))] @@ -111,7 +119,9 @@ fn tier() -> Tier { #[cfg(all(not(feature = "std"), not(feature = "portable-atomic-critical-section")))] #[inline(always)] #[allow(dead_code)] -fn tier() -> Tier { detect_tier() } +fn tier() -> Tier { + detect_tier() +} // BF16 tier detection happens inline in bf16_to_f32_batch() via // is_x86_feature_detected!("avx512bf16") — no LazyLock needed. @@ -190,22 +200,42 @@ pub const PREFERRED_I16_LANES: usize = 16; #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] pub use crate::simd_avx512::{ - // 256-bit (AVX2 baseline, __m256/__m256d/__m256i) - F32x8, F64x4, I8x32, I16x16, f32x8, f64x4, i8x32, i16x16, + f32x16, + f32x8, + f64x4, + f64x8, + i16x16, + i16x32, + i32x16, + i64x8, + i8x32, + i8x64, + u32x16, + u64x8, + u8x64, + F32Mask16, // 512-bit (native AVX-512, __m512/__m512d/__m512i) - F32x16, F64x8, U8x64, I32x16, I64x8, U16x32, U32x16, U64x8, - I8x64, I16x32, - F32Mask16, F64Mask8, - f32x16, f64x8, u8x64, i32x16, i64x8, u32x16, u64x8, - i8x64, i16x32, + F32x16, + // 256-bit (AVX2 baseline, __m256/__m256d/__m256i) + F32x8, + F64Mask8, + F64x4, + F64x8, + I16x16, + I16x32, + I32x16, + I64x8, + I8x32, + I8x64, + U16x32, + U32x16, + U64x8, + U8x64, }; // BF16 types + batch conversion (always available — scalar fallback built in) #[cfg(target_arch = "x86_64")] -pub use crate::simd_avx512::{ - bf16_to_f32_scalar, f32_to_bf16_scalar, - bf16_to_f32_batch, f32_to_bf16_batch, -}; +pub use crate::simd_avx512::{bf16_to_f32_batch, bf16_to_f32_scalar, f32_to_bf16_batch, f32_to_bf16_scalar}; // BF16 RNE (round-to-nearest-even) path — pure AVX-512-F, byte-exact vs // hardware `_mm512_cvtneps_pbh` on Sapphire Rapids+ (verified on 1M inputs @@ -216,24 +246,18 @@ pub use crate::simd_avx512::{ // loops per the workspace-wide "never scalar ever" rule for F32→BF16. // See lance-graph/CLAUDE.md § Certification Process. 
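A scalar sketch of the round-to-nearest-even narrowing those comments refer to (standard bit trick; `to_bf16_rne` is an illustrative name, not necessarily the crate's `f32_to_bf16_scalar_rne`):

fn to_bf16_rne(x: f32) -> u16 {
    let bits = x.to_bits();
    if x.is_nan() {
        // Force a quiet NaN so truncating the mantissa cannot yield Inf.
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Adding 0x7FFF plus the lowest kept bit rounds ties to an even mantissa.
    let bias = 0x7FFF + ((bits >> 16) & 1);
    ((bits + bias) >> 16) as u16
}

fn main() {
    assert_eq!(to_bf16_rne(1.0), 0x3F80); // 1.0 round-trips exactly
}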
#[cfg(target_arch = "x86_64")] -pub use crate::simd_avx512::{ - f32_to_bf16_scalar_rne, - f32_to_bf16_batch_rne, -}; +pub use crate::simd_avx512::{f32_to_bf16_batch_rne, f32_to_bf16_scalar_rne}; // BF16 SIMD types only available when avx512bf16 is enabled at compile time #[cfg(all(target_arch = "x86_64", target_feature = "avx512bf16"))] pub use crate::simd_avx512::{BF16x16, BF16x8}; #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))] -pub use crate::simd_avx512::{F32x8, F64x4, I8x32, I16x16, f32x8, f64x4, i8x32, i16x16}; +pub use crate::simd_avx512::{f32x8, f64x4, i16x16, i8x32, F32x8, F64x4, I16x16, I8x32}; #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))] pub use crate::simd_avx2::{ - F32x16, F64x8, U8x64, I32x16, I64x8, U16x32, U32x16, U64x8, - I8x64, I16x32, - F32Mask16, F64Mask8, - f32x16, f64x8, u8x64, i32x16, i64x8, u32x16, u64x8, - i8x64, i16x32, + f32x16, f64x8, i16x32, i32x16, i64x8, i8x64, u32x16, u64x8, u8x64, F32Mask16, F32x16, F64Mask8, F64x8, I16x32, + I32x16, I64x8, I8x64, U16x32, U32x16, U64x8, U8x64, }; // ============================================================================ @@ -244,8 +268,8 @@ pub use crate::simd_avx2::{ pub(crate) mod scalar { use core::fmt; use core::ops::{ - Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, - Div, DivAssign, Mul, MulAssign, Neg, Not, Shl, Shr, Sub, SubAssign, + Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign, Mul, MulAssign, + Neg, Not, Shl, Shr, Sub, SubAssign, }; // ── Macros for scalar fallback boilerplate ──────────────────────── @@ -258,14 +282,18 @@ pub(crate) mod scalar { impl Default for $name { #[inline(always)] - fn default() -> Self { Self([0.0; $lanes]) } + fn default() -> Self { + Self([0.0; $lanes]) + } } impl $name { pub const LANES: usize = $lanes; #[inline(always)] - pub fn splat(v: $elem) -> Self { Self([v; $lanes]) } + pub fn splat(v: $elem) -> Self { + Self([v; $lanes]) + } #[inline(always)] pub fn from_slice(s: &[$elem]) -> Self { @@ -276,10 +304,14 @@ pub(crate) mod scalar { } #[inline(always)] - pub fn from_array(arr: [$elem; $lanes]) -> Self { Self(arr) } + pub fn from_array(arr: [$elem; $lanes]) -> Self { + Self(arr) + } #[inline(always)] - pub fn to_array(self) -> [$elem; $lanes] { self.0 } + pub fn to_array(self) -> [$elem; $lanes] { + self.0 + } #[inline(always)] pub fn copy_to_slice(self, s: &mut [$elem]) { @@ -288,7 +320,9 @@ pub(crate) mod scalar { } #[inline(always)] - pub fn reduce_sum(self) -> $elem { self.0.iter().sum() } + pub fn reduce_sum(self) -> $elem { + self.0.iter().sum() + } #[inline(always)] pub fn reduce_min(self) -> $elem { @@ -297,20 +331,27 @@ pub(crate) mod scalar { #[inline(always)] pub fn reduce_max(self) -> $elem { - self.0.iter().copied().fold(<$elem>::NEG_INFINITY, <$elem>::max) + self.0 + .iter() + .copied() + .fold(<$elem>::NEG_INFINITY, <$elem>::max) } #[inline(always)] pub fn simd_min(self, other: Self) -> Self { let mut out = [0.0 as $elem; $lanes]; - for i in 0..$lanes { out[i] = self.0[i].min(other.0[i]); } + for i in 0..$lanes { + out[i] = self.0[i].min(other.0[i]); + } Self(out) } #[inline(always)] pub fn simd_max(self, other: Self) -> Self { let mut out = [0.0 as $elem; $lanes]; - for i in 0..$lanes { out[i] = self.0[i].max(other.0[i]); } + for i in 0..$lanes { + out[i] = self.0[i].max(other.0[i]); + } Self(out) } @@ -322,69 +363,99 @@ pub(crate) mod scalar { #[inline(always)] pub fn mul_add(self, b: Self, c: Self) -> Self { let mut out = [0.0 as $elem; 
$lanes]; - for i in 0..$lanes { out[i] = self.0[i].mul_add(b.0[i], c.0[i]); } + for i in 0..$lanes { + out[i] = self.0[i].mul_add(b.0[i], c.0[i]); + } Self(out) } #[inline(always)] pub fn sqrt(self) -> Self { let mut out = [0.0 as $elem; $lanes]; - for i in 0..$lanes { out[i] = self.0[i].sqrt(); } + for i in 0..$lanes { + out[i] = self.0[i].sqrt(); + } Self(out) } #[inline(always)] pub fn round(self) -> Self { let mut out = [0.0 as $elem; $lanes]; - for i in 0..$lanes { out[i] = self.0[i].round(); } + for i in 0..$lanes { + out[i] = self.0[i].round(); + } Self(out) } #[inline(always)] pub fn floor(self) -> Self { let mut out = [0.0 as $elem; $lanes]; - for i in 0..$lanes { out[i] = self.0[i].floor(); } + for i in 0..$lanes { + out[i] = self.0[i].floor(); + } Self(out) } #[inline(always)] pub fn abs(self) -> Self { let mut out = [0.0 as $elem; $lanes]; - for i in 0..$lanes { out[i] = self.0[i].abs(); } + for i in 0..$lanes { + out[i] = self.0[i].abs(); + } Self(out) } #[inline(always)] pub fn simd_lt(self, other: Self) -> $mask { let mut bits: $mask_prim = 0; - for i in 0..$lanes { if self.0[i] < other.0[i] { bits |= 1 << i; } } + for i in 0..$lanes { + if self.0[i] < other.0[i] { + bits |= 1 << i; + } + } $mask(bits) } #[inline(always)] pub fn simd_le(self, other: Self) -> $mask { let mut bits: $mask_prim = 0; - for i in 0..$lanes { if self.0[i] <= other.0[i] { bits |= 1 << i; } } + for i in 0..$lanes { + if self.0[i] <= other.0[i] { + bits |= 1 << i; + } + } $mask(bits) } #[inline(always)] - pub fn simd_gt(self, other: Self) -> $mask { other.simd_lt(self) } + pub fn simd_gt(self, other: Self) -> $mask { + other.simd_lt(self) + } #[inline(always)] - pub fn simd_ge(self, other: Self) -> $mask { other.simd_le(self) } + pub fn simd_ge(self, other: Self) -> $mask { + other.simd_le(self) + } #[inline(always)] pub fn simd_eq(self, other: Self) -> $mask { let mut bits: $mask_prim = 0; - for i in 0..$lanes { if self.0[i] == other.0[i] { bits |= 1 << i; } } + for i in 0..$lanes { + if self.0[i] == other.0[i] { + bits |= 1 << i; + } + } $mask(bits) } #[inline(always)] pub fn simd_ne(self, other: Self) -> $mask { let mut bits: $mask_prim = 0; - for i in 0..$lanes { if self.0[i] != other.0[i] { bits |= 1 << i; } } + for i in 0..$lanes { + if self.0[i] != other.0[i] { + bits |= 1 << i; + } + } $mask(bits) } } @@ -394,7 +465,9 @@ pub(crate) mod scalar { #[inline(always)] fn add(self, rhs: Self) -> Self { let mut out = [0.0 as $elem; $lanes]; - for i in 0..$lanes { out[i] = self.0[i] + rhs.0[i]; } + for i in 0..$lanes { + out[i] = self.0[i] + rhs.0[i]; + } Self(out) } } @@ -403,7 +476,9 @@ pub(crate) mod scalar { #[inline(always)] fn sub(self, rhs: Self) -> Self { let mut out = [0.0 as $elem; $lanes]; - for i in 0..$lanes { out[i] = self.0[i] - rhs.0[i]; } + for i in 0..$lanes { + out[i] = self.0[i] - rhs.0[i]; + } Self(out) } } @@ -412,7 +487,9 @@ pub(crate) mod scalar { #[inline(always)] fn mul(self, rhs: Self) -> Self { let mut out = [0.0 as $elem; $lanes]; - for i in 0..$lanes { out[i] = self.0[i] * rhs.0[i]; } + for i in 0..$lanes { + out[i] = self.0[i] * rhs.0[i]; + } Self(out) } } @@ -421,32 +498,52 @@ pub(crate) mod scalar { #[inline(always)] fn div(self, rhs: Self) -> Self { let mut out = [0.0 as $elem; $lanes]; - for i in 0..$lanes { out[i] = self.0[i] / rhs.0[i]; } + for i in 0..$lanes { + out[i] = self.0[i] / rhs.0[i]; + } Self(out) } } impl AddAssign for $name { #[inline(always)] - fn add_assign(&mut self, rhs: Self) { for i in 0..$lanes { self.0[i] += rhs.0[i]; } } + fn add_assign(&mut 
self, rhs: Self) { + for i in 0..$lanes { + self.0[i] += rhs.0[i]; + } + } } impl SubAssign for $name { #[inline(always)] - fn sub_assign(&mut self, rhs: Self) { for i in 0..$lanes { self.0[i] -= rhs.0[i]; } } + fn sub_assign(&mut self, rhs: Self) { + for i in 0..$lanes { + self.0[i] -= rhs.0[i]; + } + } } impl MulAssign for $name { #[inline(always)] - fn mul_assign(&mut self, rhs: Self) { for i in 0..$lanes { self.0[i] *= rhs.0[i]; } } + fn mul_assign(&mut self, rhs: Self) { + for i in 0..$lanes { + self.0[i] *= rhs.0[i]; + } + } } impl DivAssign for $name { #[inline(always)] - fn div_assign(&mut self, rhs: Self) { for i in 0..$lanes { self.0[i] /= rhs.0[i]; } } + fn div_assign(&mut self, rhs: Self) { + for i in 0..$lanes { + self.0[i] /= rhs.0[i]; + } + } } impl Neg for $name { type Output = Self; #[inline(always)] fn neg(self) -> Self { let mut out = [0.0 as $elem; $lanes]; - for i in 0..$lanes { out[i] = -self.0[i]; } + for i in 0..$lanes { + out[i] = -self.0[i]; + } Self(out) } } @@ -456,7 +553,9 @@ pub(crate) mod scalar { } } impl PartialEq for $name { - fn eq(&self, other: &Self) -> bool { self.0 == other.0 } + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } } // Mask type @@ -468,7 +567,11 @@ pub(crate) mod scalar { pub fn select(self, true_val: $name, false_val: $name) -> $name { let mut out = [0.0 as $elem; $lanes]; for i in 0..$lanes { - out[i] = if (self.0 >> i) & 1 == 1 { true_val.0[i] } else { false_val.0[i] }; + out[i] = if (self.0 >> i) & 1 == 1 { + true_val.0[i] + } else { + false_val.0[i] + }; } $name(out) } @@ -484,14 +587,18 @@ pub(crate) mod scalar { impl Default for $name { #[inline(always)] - fn default() -> Self { Self([$zero; $lanes]) } + fn default() -> Self { + Self([$zero; $lanes]) + } } impl $name { pub const LANES: usize = $lanes; #[inline(always)] - pub fn splat(v: $elem) -> Self { Self([v; $lanes]) } + pub fn splat(v: $elem) -> Self { + Self([v; $lanes]) + } #[inline(always)] pub fn from_slice(s: &[$elem]) -> Self { @@ -502,10 +609,14 @@ pub(crate) mod scalar { } #[inline(always)] - pub fn from_array(arr: [$elem; $lanes]) -> Self { Self(arr) } + pub fn from_array(arr: [$elem; $lanes]) -> Self { + Self(arr) + } #[inline(always)] - pub fn to_array(self) -> [$elem; $lanes] { self.0 } + pub fn to_array(self) -> [$elem; $lanes] { + self.0 + } #[inline(always)] pub fn copy_to_slice(self, s: &mut [$elem]) { @@ -516,7 +627,9 @@ pub(crate) mod scalar { #[inline(always)] pub fn reduce_sum(self) -> $elem { let mut s: $elem = $zero; - for i in 0..$lanes { s = s.wrapping_add(self.0[i]); } + for i in 0..$lanes { + s = s.wrapping_add(self.0[i]); + } s } } @@ -526,7 +639,9 @@ pub(crate) mod scalar { #[inline(always)] fn add(self, rhs: Self) -> Self { let mut out = [$zero; $lanes]; - for i in 0..$lanes { out[i] = self.0[i].wrapping_add(rhs.0[i]); } + for i in 0..$lanes { + out[i] = self.0[i].wrapping_add(rhs.0[i]); + } Self(out) } } @@ -535,20 +650,26 @@ pub(crate) mod scalar { #[inline(always)] fn sub(self, rhs: Self) -> Self { let mut out = [$zero; $lanes]; - for i in 0..$lanes { out[i] = self.0[i].wrapping_sub(rhs.0[i]); } + for i in 0..$lanes { + out[i] = self.0[i].wrapping_sub(rhs.0[i]); + } Self(out) } } impl AddAssign for $name { #[inline(always)] fn add_assign(&mut self, rhs: Self) { - for i in 0..$lanes { self.0[i] = self.0[i].wrapping_add(rhs.0[i]); } + for i in 0..$lanes { + self.0[i] = self.0[i].wrapping_add(rhs.0[i]); + } } } impl SubAssign for $name { #[inline(always)] fn sub_assign(&mut self, rhs: Self) { - for i in 0..$lanes { self.0[i] = 
self.0[i].wrapping_sub(rhs.0[i]); } + for i in 0..$lanes { + self.0[i] = self.0[i].wrapping_sub(rhs.0[i]); + } } } impl BitAnd for $name { @@ -556,7 +677,9 @@ pub(crate) mod scalar { #[inline(always)] fn bitand(self, rhs: Self) -> Self { let mut out = [$zero; $lanes]; - for i in 0..$lanes { out[i] = self.0[i] & rhs.0[i]; } + for i in 0..$lanes { + out[i] = self.0[i] & rhs.0[i]; + } Self(out) } } @@ -565,7 +688,9 @@ pub(crate) mod scalar { #[inline(always)] fn bitor(self, rhs: Self) -> Self { let mut out = [$zero; $lanes]; - for i in 0..$lanes { out[i] = self.0[i] | rhs.0[i]; } + for i in 0..$lanes { + out[i] = self.0[i] | rhs.0[i]; + } Self(out) } } @@ -574,28 +699,44 @@ pub(crate) mod scalar { #[inline(always)] fn bitxor(self, rhs: Self) -> Self { let mut out = [$zero; $lanes]; - for i in 0..$lanes { out[i] = self.0[i] ^ rhs.0[i]; } + for i in 0..$lanes { + out[i] = self.0[i] ^ rhs.0[i]; + } Self(out) } } impl BitAndAssign for $name { #[inline(always)] - fn bitand_assign(&mut self, rhs: Self) { for i in 0..$lanes { self.0[i] &= rhs.0[i]; } } + fn bitand_assign(&mut self, rhs: Self) { + for i in 0..$lanes { + self.0[i] &= rhs.0[i]; + } + } } impl BitOrAssign for $name { #[inline(always)] - fn bitor_assign(&mut self, rhs: Self) { for i in 0..$lanes { self.0[i] |= rhs.0[i]; } } + fn bitor_assign(&mut self, rhs: Self) { + for i in 0..$lanes { + self.0[i] |= rhs.0[i]; + } + } } impl BitXorAssign for $name { #[inline(always)] - fn bitxor_assign(&mut self, rhs: Self) { for i in 0..$lanes { self.0[i] ^= rhs.0[i]; } } + fn bitxor_assign(&mut self, rhs: Self) { + for i in 0..$lanes { + self.0[i] ^= rhs.0[i]; + } + } } impl Not for $name { type Output = Self; #[inline(always)] fn not(self) -> Self { let mut out = [$zero; $lanes]; - for i in 0..$lanes { out[i] = !self.0[i]; } + for i in 0..$lanes { + out[i] = !self.0[i]; + } Self(out) } } @@ -605,7 +746,9 @@ pub(crate) mod scalar { } } impl PartialEq for $name { - fn eq(&self, other: &Self) -> bool { self.0 == other.0 } + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } } }; } @@ -642,50 +785,194 @@ pub(crate) mod scalar { // I8x64 / I8x32 / I16x32 / I16x16 — AVX-512BW-style methods (scalar shape) impl I8x64 { - #[inline(always)] pub fn zero() -> Self { Self([0i8; 64]) } - #[inline(always)] pub fn add(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].wrapping_add(other.0[i]); } Self(o) } - #[inline(always)] pub fn sub(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].wrapping_sub(other.0[i]); } Self(o) } - #[inline(always)] pub fn min(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].min(other.0[i]); } Self(o) } - #[inline(always)] pub fn max(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].max(other.0[i]); } Self(o) } - #[inline(always)] pub fn cmp_gt(self, other: Self) -> u64 { + #[inline(always)] + pub fn zero() -> Self { + Self([0i8; 64]) + } + #[inline(always)] + pub fn add(self, other: Self) -> Self { + let mut o = [0i8; 64]; + for i in 0..64 { + o[i] = self.0[i].wrapping_add(other.0[i]); + } + Self(o) + } + #[inline(always)] + pub fn sub(self, other: Self) -> Self { + let mut o = [0i8; 64]; + for i in 0..64 { + o[i] = self.0[i].wrapping_sub(other.0[i]); + } + Self(o) + } + #[inline(always)] + pub fn min(self, other: Self) -> Self { + let mut o = [0i8; 64]; + for i in 0..64 { + o[i] = self.0[i].min(other.0[i]); + } + Self(o) + } + #[inline(always)] + pub fn max(self, other: Self) -> Self { 
+ let mut o = [0i8; 64]; + for i in 0..64 { + o[i] = self.0[i].max(other.0[i]); + } + Self(o) + } + #[inline(always)] + pub fn cmp_gt(self, other: Self) -> u64 { let mut m: u64 = 0; - for i in 0..64 { if self.0[i] > other.0[i] { m |= 1u64 << i; } } + for i in 0..64 { + if self.0[i] > other.0[i] { + m |= 1u64 << i; + } + } m } } impl I8x32 { - #[inline(always)] pub fn zero() -> Self { Self([0i8; 32]) } - #[inline(always)] pub fn add(self, other: Self) -> Self { let mut o = [0i8; 32]; for i in 0..32 { o[i] = self.0[i].wrapping_add(other.0[i]); } Self(o) } - #[inline(always)] pub fn sub(self, other: Self) -> Self { let mut o = [0i8; 32]; for i in 0..32 { o[i] = self.0[i].wrapping_sub(other.0[i]); } Self(o) } - #[inline(always)] pub fn min(self, other: Self) -> Self { let mut o = [0i8; 32]; for i in 0..32 { o[i] = self.0[i].min(other.0[i]); } Self(o) } - #[inline(always)] pub fn max(self, other: Self) -> Self { let mut o = [0i8; 32]; for i in 0..32 { o[i] = self.0[i].max(other.0[i]); } Self(o) } - #[inline(always)] pub fn cmp_gt(self, other: Self) -> u32 { + #[inline(always)] + pub fn zero() -> Self { + Self([0i8; 32]) + } + #[inline(always)] + pub fn add(self, other: Self) -> Self { + let mut o = [0i8; 32]; + for i in 0..32 { + o[i] = self.0[i].wrapping_add(other.0[i]); + } + Self(o) + } + #[inline(always)] + pub fn sub(self, other: Self) -> Self { + let mut o = [0i8; 32]; + for i in 0..32 { + o[i] = self.0[i].wrapping_sub(other.0[i]); + } + Self(o) + } + #[inline(always)] + pub fn min(self, other: Self) -> Self { + let mut o = [0i8; 32]; + for i in 0..32 { + o[i] = self.0[i].min(other.0[i]); + } + Self(o) + } + #[inline(always)] + pub fn max(self, other: Self) -> Self { + let mut o = [0i8; 32]; + for i in 0..32 { + o[i] = self.0[i].max(other.0[i]); + } + Self(o) + } + #[inline(always)] + pub fn cmp_gt(self, other: Self) -> u32 { let mut m: u32 = 0; - for i in 0..32 { if self.0[i] > other.0[i] { m |= 1u32 << i; } } + for i in 0..32 { + if self.0[i] > other.0[i] { + m |= 1u32 << i; + } + } m } } impl I16x32 { - #[inline(always)] pub fn zero() -> Self { Self([0i16; 32]) } - #[inline(always)] pub fn add(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].wrapping_add(other.0[i]); } Self(o) } - #[inline(always)] pub fn sub(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].wrapping_sub(other.0[i]); } Self(o) } - #[inline(always)] pub fn min(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].min(other.0[i]); } Self(o) } - #[inline(always)] pub fn max(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].max(other.0[i]); } Self(o) } - #[inline(always)] pub fn cmp_gt(self, other: Self) -> u32 { + #[inline(always)] + pub fn zero() -> Self { + Self([0i16; 32]) + } + #[inline(always)] + pub fn add(self, other: Self) -> Self { + let mut o = [0i16; 32]; + for i in 0..32 { + o[i] = self.0[i].wrapping_add(other.0[i]); + } + Self(o) + } + #[inline(always)] + pub fn sub(self, other: Self) -> Self { + let mut o = [0i16; 32]; + for i in 0..32 { + o[i] = self.0[i].wrapping_sub(other.0[i]); + } + Self(o) + } + #[inline(always)] + pub fn min(self, other: Self) -> Self { + let mut o = [0i16; 32]; + for i in 0..32 { + o[i] = self.0[i].min(other.0[i]); + } + Self(o) + } + #[inline(always)] + pub fn max(self, other: Self) -> Self { + let mut o = [0i16; 32]; + for i in 0..32 { + o[i] = self.0[i].max(other.0[i]); + } + Self(o) + } + #[inline(always)] + pub fn cmp_gt(self, other: 
Self) -> u32 { let mut m: u32 = 0; - for i in 0..32 { if self.0[i] > other.0[i] { m |= 1u32 << i; } } + for i in 0..32 { + if self.0[i] > other.0[i] { + m |= 1u32 << i; + } + } m } } impl I16x16 { - #[inline(always)] pub fn zero() -> Self { Self([0i16; 16]) } - #[inline(always)] pub fn add(self, other: Self) -> Self { let mut o = [0i16; 16]; for i in 0..16 { o[i] = self.0[i].wrapping_add(other.0[i]); } Self(o) } - #[inline(always)] pub fn sub(self, other: Self) -> Self { let mut o = [0i16; 16]; for i in 0..16 { o[i] = self.0[i].wrapping_sub(other.0[i]); } Self(o) } - #[inline(always)] pub fn min(self, other: Self) -> Self { let mut o = [0i16; 16]; for i in 0..16 { o[i] = self.0[i].min(other.0[i]); } Self(o) } - #[inline(always)] pub fn max(self, other: Self) -> Self { let mut o = [0i16; 16]; for i in 0..16 { o[i] = self.0[i].max(other.0[i]); } Self(o) } - #[inline(always)] pub fn cmp_gt(self, other: Self) -> u16 { + #[inline(always)] + pub fn zero() -> Self { + Self([0i16; 16]) + } + #[inline(always)] + pub fn add(self, other: Self) -> Self { + let mut o = [0i16; 16]; + for i in 0..16 { + o[i] = self.0[i].wrapping_add(other.0[i]); + } + Self(o) + } + #[inline(always)] + pub fn sub(self, other: Self) -> Self { + let mut o = [0i16; 16]; + for i in 0..16 { + o[i] = self.0[i].wrapping_sub(other.0[i]); + } + Self(o) + } + #[inline(always)] + pub fn min(self, other: Self) -> Self { + let mut o = [0i16; 16]; + for i in 0..16 { + o[i] = self.0[i].min(other.0[i]); + } + Self(o) + } + #[inline(always)] + pub fn max(self, other: Self) -> Self { + let mut o = [0i16; 16]; + for i in 0..16 { + o[i] = self.0[i].max(other.0[i]); + } + Self(o) + } + #[inline(always)] + pub fn cmp_gt(self, other: Self) -> u16 { let mut m: u16 = 0; - for i in 0..16 { if self.0[i] > other.0[i] { m |= 1u16 << i; } } + for i in 0..16 { + if self.0[i] > other.0[i] { + m |= 1u16 << i; + } + } m } } @@ -694,80 +981,124 @@ pub(crate) mod scalar { impl U16x32 { #[inline(always)] pub fn from_u8x64_lo(v: U8x64) -> Self { - let mut out = [0u16; 32]; for i in 0..32 { out[i] = v.0[i] as u16; } Self(out) + let mut out = [0u16; 32]; + for i in 0..32 { + out[i] = v.0[i] as u16; + } + Self(out) } #[inline(always)] pub fn from_u8x64_hi(v: U8x64) -> Self { - let mut out = [0u16; 32]; for i in 0..32 { out[i] = v.0[32 + i] as u16; } Self(out) + let mut out = [0u16; 32]; + for i in 0..32 { + out[i] = v.0[32 + i] as u16; + } + Self(out) } #[inline(always)] pub fn pack_saturate_u8(self, other: Self) -> U8x64 { let mut out = [0u8; 64]; - for i in 0..32 { out[i] = self.0[i].min(255) as u8; } - for i in 0..32 { out[32 + i] = other.0[i].min(255) as u8; } + for i in 0..32 { + out[i] = self.0[i].min(255) as u8; + } + for i in 0..32 { + out[32 + i] = other.0[i].min(255) as u8; + } U8x64(out) } #[inline(always)] pub fn shr(self, imm: u32) -> Self { - let mut out = [0u16; 32]; for i in 0..32 { out[i] = if imm < 16 { self.0[i] >> imm } else { 0 }; } Self(out) + let mut out = [0u16; 32]; + for i in 0..32 { + out[i] = if imm < 16 { self.0[i] >> imm } else { 0 }; + } + Self(out) } #[inline(always)] pub fn shl(self, imm: u32) -> Self { - let mut out = [0u16; 32]; for i in 0..32 { out[i] = if imm < 16 { self.0[i] << imm } else { 0 }; } Self(out) + let mut out = [0u16; 32]; + for i in 0..32 { + out[i] = if imm < 16 { self.0[i] << imm } else { 0 }; + } + Self(out) } #[inline(always)] pub fn mullo(self, other: Self) -> Self { - let mut out = [0u16; 32]; for i in 0..32 { out[i] = self.0[i].wrapping_mul(other.0[i]); } Self(out) + let mut out = [0u16; 32]; + for i in 
0..32 { + out[i] = self.0[i].wrapping_mul(other.0[i]); + } + Self(out) } } // Extra methods for I32x16 that float types have via the macro impl I32x16 { #[inline(always)] - pub fn reduce_min(self) -> i32 { *self.0.iter().min().unwrap_or(&0) } + pub fn reduce_min(self) -> i32 { + *self.0.iter().min().unwrap_or(&0) + } #[inline(always)] - pub fn reduce_max(self) -> i32 { *self.0.iter().max().unwrap_or(&0) } + pub fn reduce_max(self) -> i32 { + *self.0.iter().max().unwrap_or(&0) + } #[inline(always)] pub fn simd_min(self, other: Self) -> Self { let mut out = [0i32; 16]; - for i in 0..16 { out[i] = self.0[i].min(other.0[i]); } + for i in 0..16 { + out[i] = self.0[i].min(other.0[i]); + } Self(out) } #[inline(always)] pub fn simd_max(self, other: Self) -> Self { let mut out = [0i32; 16]; - for i in 0..16 { out[i] = self.0[i].max(other.0[i]); } + for i in 0..16 { + out[i] = self.0[i].max(other.0[i]); + } Self(out) } #[inline(always)] pub fn cast_f32(self) -> F32x16 { let mut out = [0.0f32; 16]; - for i in 0..16 { out[i] = self.0[i] as f32; } + for i in 0..16 { + out[i] = self.0[i] as f32; + } F32x16(out) } #[inline(always)] pub fn abs(self) -> Self { let mut out = [0i32; 16]; - for i in 0..16 { out[i] = self.0[i].abs(); } + for i in 0..16 { + out[i] = self.0[i].abs(); + } Self(out) } #[inline(always)] pub fn from_i16_slice(s: &[i16]) -> Self { assert!(s.len() >= 16); let mut o = [0i32; 16]; - for i in 0..16 { o[i] = s[i] as i32; } + for i in 0..16 { + o[i] = s[i] as i32; + } Self(o) } #[inline(always)] pub fn to_i16_array(self) -> [i16; 16] { let mut o = [0i16; 16]; - for i in 0..16 { o[i] = self.0[i] as i16; } + for i in 0..16 { + o[i] = self.0[i] as i16; + } o } #[inline(always)] pub fn cmpge_zero_mask(self) -> u16 { let mut mask = 0u16; - for i in 0..16 { if self.0[i] >= 0 { mask |= 1 << i; } } + for i in 0..16 { + if self.0[i] >= 0 { + mask |= 1 << i; + } + } mask } } @@ -777,20 +1108,26 @@ pub(crate) mod scalar { #[inline(always)] fn mul(self, rhs: Self) -> Self { let mut out = [0i32; 16]; - for i in 0..16 { out[i] = self.0[i].wrapping_mul(rhs.0[i]); } + for i in 0..16 { + out[i] = self.0[i].wrapping_mul(rhs.0[i]); + } Self(out) } } impl MulAssign for I32x16 { #[inline(always)] - fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; } + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } } impl Neg for I32x16 { type Output = Self; #[inline(always)] fn neg(self) -> Self { let mut out = [0i32; 16]; - for i in 0..16 { out[i] = -self.0[i]; } + for i in 0..16 { + out[i] = -self.0[i]; + } Self(out) } } @@ -800,19 +1137,25 @@ pub(crate) mod scalar { #[inline(always)] pub fn to_bits(self) -> U32x16 { let mut out = [0u32; 16]; - for i in 0..16 { out[i] = self.0[i].to_bits(); } + for i in 0..16 { + out[i] = self.0[i].to_bits(); + } U32x16(out) } #[inline(always)] pub fn from_bits(bits: U32x16) -> Self { let mut out = [0.0f32; 16]; - for i in 0..16 { out[i] = f32::from_bits(bits.0[i]); } + for i in 0..16 { + out[i] = f32::from_bits(bits.0[i]); + } Self(out) } #[inline(always)] pub fn cast_i32(self) -> I32x16 { let mut out = [0i32; 16]; - for i in 0..16 { out[i] = self.0[i] as i32; } + for i in 0..16 { + out[i] = self.0[i] as i32; + } I32x16(out) } } @@ -822,13 +1165,17 @@ pub(crate) mod scalar { #[inline(always)] pub fn to_bits(self) -> U64x8 { let mut out = [0u64; 8]; - for i in 0..8 { out[i] = self.0[i].to_bits(); } + for i in 0..8 { + out[i] = self.0[i].to_bits(); + } U64x8(out) } #[inline(always)] pub fn from_bits(bits: U64x8) -> Self { let mut out = [0.0f64; 8]; - for i in 0..8 { 
out[i] = f64::from_bits(bits.0[i]); } + for i in 0..8 { + out[i] = f64::from_bits(bits.0[i]); + } Self(out) } } @@ -836,25 +1183,35 @@ pub(crate) mod scalar { // Extra for I64x8 impl I64x8 { #[inline(always)] - pub fn reduce_min(self) -> i64 { *self.0.iter().min().unwrap_or(&0) } + pub fn reduce_min(self) -> i64 { + *self.0.iter().min().unwrap_or(&0) + } #[inline(always)] - pub fn reduce_max(self) -> i64 { *self.0.iter().max().unwrap_or(&0) } + pub fn reduce_max(self) -> i64 { + *self.0.iter().max().unwrap_or(&0) + } #[inline(always)] pub fn simd_min(self, other: Self) -> Self { let mut out = [0i64; 8]; - for i in 0..8 { out[i] = self.0[i].min(other.0[i]); } + for i in 0..8 { + out[i] = self.0[i].min(other.0[i]); + } Self(out) } #[inline(always)] pub fn simd_max(self, other: Self) -> Self { let mut out = [0i64; 8]; - for i in 0..8 { out[i] = self.0[i].max(other.0[i]); } + for i in 0..8 { + out[i] = self.0[i].max(other.0[i]); + } Self(out) } #[inline(always)] pub fn abs(self) -> Self { let mut out = [0i64; 8]; - for i in 0..8 { out[i] = self.0[i].abs(); } + for i in 0..8 { + out[i] = self.0[i].abs(); + } Self(out) } } @@ -864,20 +1221,26 @@ pub(crate) mod scalar { #[inline(always)] fn mul(self, rhs: Self) -> Self { let mut out = [0i64; 8]; - for i in 0..8 { out[i] = self.0[i].wrapping_mul(rhs.0[i]); } + for i in 0..8 { + out[i] = self.0[i].wrapping_mul(rhs.0[i]); + } Self(out) } } impl MulAssign for I64x8 { #[inline(always)] - fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; } + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } } impl Neg for I64x8 { type Output = Self; #[inline(always)] fn neg(self) -> Self { let mut out = [0i64; 8]; - for i in 0..8 { out[i] = -self.0[i]; } + for i in 0..8 { + out[i] = -self.0[i]; + } Self(out) } } @@ -888,7 +1251,9 @@ pub(crate) mod scalar { #[inline(always)] fn shr(self, rhs: Self) -> Self { let mut out = [0u32; 16]; - for i in 0..16 { out[i] = self.0[i] >> rhs.0[i]; } + for i in 0..16 { + out[i] = self.0[i] >> rhs.0[i]; + } Self(out) } } @@ -897,7 +1262,9 @@ pub(crate) mod scalar { #[inline(always)] fn shl(self, rhs: Self) -> Self { let mut out = [0u32; 16]; - for i in 0..16 { out[i] = self.0[i] << rhs.0[i]; } + for i in 0..16 { + out[i] = self.0[i] << rhs.0[i]; + } Self(out) } } @@ -908,7 +1275,9 @@ pub(crate) mod scalar { #[inline(always)] fn shr(self, rhs: Self) -> Self { let mut out = [0u64; 8]; - for i in 0..8 { out[i] = self.0[i] >> rhs.0[i]; } + for i in 0..8 { + out[i] = self.0[i] >> rhs.0[i]; + } Self(out) } } @@ -917,7 +1286,9 @@ pub(crate) mod scalar { #[inline(always)] fn shl(self, rhs: Self) -> Self { let mut out = [0u64; 8]; - for i in 0..8 { out[i] = self.0[i] << rhs.0[i]; } + for i in 0..8 { + out[i] = self.0[i] << rhs.0[i]; + } Self(out) } } @@ -928,33 +1299,53 @@ pub(crate) mod scalar { #[inline(always)] fn mul(self, rhs: Self) -> Self { let mut out = [0u8; 64]; - for i in 0..64 { out[i] = self.0[i].wrapping_mul(rhs.0[i]); } + for i in 0..64 { + out[i] = self.0[i].wrapping_mul(rhs.0[i]); + } Self(out) } } impl MulAssign for U8x64 { #[inline(always)] - fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; } + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } } // U8x64 extra methods — byte-level operations for palette codec, nibble, byte scan impl U8x64 { #[inline(always)] - pub fn reduce_min(self) -> u8 { *self.0.iter().min().unwrap_or(&0) } + pub fn reduce_min(self) -> u8 { + *self.0.iter().min().unwrap_or(&0) + } #[inline(always)] - pub fn reduce_max(self) -> u8 { 
*self.0.iter().max().unwrap_or(&0) } + pub fn reduce_max(self) -> u8 { + *self.0.iter().max().unwrap_or(&0) + } #[inline(always)] pub fn simd_min(self, other: Self) -> Self { - let mut out = [0u8; 64]; for i in 0..64 { out[i] = self.0[i].min(other.0[i]); } Self(out) + let mut out = [0u8; 64]; + for i in 0..64 { + out[i] = self.0[i].min(other.0[i]); + } + Self(out) } #[inline(always)] pub fn simd_max(self, other: Self) -> Self { - let mut out = [0u8; 64]; for i in 0..64 { out[i] = self.0[i].max(other.0[i]); } Self(out) + let mut out = [0u8; 64]; + for i in 0..64 { + out[i] = self.0[i].max(other.0[i]); + } + Self(out) } #[inline(always)] pub fn cmpeq_mask(self, other: Self) -> u64 { let mut mask = 0u64; - for i in 0..64 { if self.0[i] == other.0[i] { mask |= 1u64 << i; } } + for i in 0..64 { + if self.0[i] == other.0[i] { + mask |= 1u64 << i; + } + } mask } #[inline(always)] @@ -964,64 +1355,115 @@ pub(crate) mod scalar { let val = u16::from_le_bytes([self.0[i], self.0[i + 1]]); let shifted = val >> imm; let bytes = shifted.to_le_bytes(); - out[i] = bytes[0]; out[i + 1] = bytes[1]; + out[i] = bytes[0]; + out[i + 1] = bytes[1]; } Self(out) } #[inline(always)] pub fn saturating_sub(self, other: Self) -> Self { - let mut out = [0u8; 64]; for i in 0..64 { out[i] = self.0[i].saturating_sub(other.0[i]); } Self(out) + let mut out = [0u8; 64]; + for i in 0..64 { + out[i] = self.0[i].saturating_sub(other.0[i]); + } + Self(out) } // ── Tier 1: seismon rasterizer primitives (scalar fallbacks) ── #[inline(always)] pub fn pairwise_avg(self, other: Self) -> Self { - let mut out = [0u8; 64]; for i in 0..64 { out[i] = ((self.0[i] as u16 + other.0[i] as u16 + 1) >> 1) as u8; } Self(out) + let mut out = [0u8; 64]; + for i in 0..64 { + out[i] = ((self.0[i] as u16 + other.0[i] as u16 + 1) >> 1) as u8; + } + Self(out) } #[inline(always)] pub fn cmpgt_mask(self, other: Self) -> u64 { - let mut m: u64 = 0; for i in 0..64 { if self.0[i] > other.0[i] { m |= 1 << i; } } m + let mut m: u64 = 0; + for i in 0..64 { + if self.0[i] > other.0[i] { + m |= 1 << i; + } + } + m } #[inline(always)] pub fn mask_blend(mask: u64, a: Self, b: Self) -> Self { - let mut out = [0u8; 64]; for i in 0..64 { out[i] = if mask & (1 << i) != 0 { b.0[i] } else { a.0[i] }; } Self(out) + let mut out = [0u8; 64]; + for i in 0..64 { + out[i] = if mask & (1 << i) != 0 { b.0[i] } else { a.0[i] }; + } + Self(out) } #[inline(always)] pub fn shl_epi16(self, imm: u32) -> Self { let mut out = [0u8; 64]; for i in (0..64).step_by(2) { - let v = u16::from_le_bytes([self.0[i], self.0[i+1]]); + let v = u16::from_le_bytes([self.0[i], self.0[i + 1]]); let s = if imm < 16 { v << imm } else { 0 }; - let b = s.to_le_bytes(); out[i] = b[0]; out[i+1] = b[1]; + let b = s.to_le_bytes(); + out[i] = b[0]; + out[i + 1] = b[1]; } Self(out) } // ── Tier 2: sprite blit + palette remap (scalar fallbacks) ── #[inline(always)] pub unsafe fn mask_store(self, ptr: *mut u8, mask: u64) { - for i in 0..64 { if mask & (1 << i) != 0 { *ptr.add(i) = self.0[i]; } } + for i in 0..64 { + if mask & (1 << i) != 0 { + *ptr.add(i) = self.0[i]; + } + } } #[inline(always)] pub fn saturating_add(self, other: Self) -> Self { - let mut out = [0u8; 64]; for i in 0..64 { out[i] = self.0[i].saturating_add(other.0[i]); } Self(out) + let mut out = [0u8; 64]; + for i in 0..64 { + out[i] = self.0[i].saturating_add(other.0[i]); + } + Self(out) } #[inline(always)] pub fn permute_bytes(self, idx: Self) -> Self { - let mut out = [0u8; 64]; for i in 0..64 { out[i] = self.0[(idx.0[i] & 63) as usize]; } 
Self(out) + let mut out = [0u8; 64]; + for i in 0..64 { + out[i] = self.0[(idx.0[i] & 63) as usize]; + } + Self(out) } #[inline(always)] pub fn movemask(self) -> u64 { - let mut m: u64 = 0; for i in 0..64 { if self.0[i] & 0x80 != 0 { m |= 1 << i; } } m + let mut m: u64 = 0; + for i in 0..64 { + if self.0[i] & 0x80 != 0 { + m |= 1 << i; + } + } + m } #[inline(always)] pub fn unpack_lo_epi8(self, other: Self) -> Self { let mut out = [0u8; 64]; - for lane in 0..4 { let b = lane * 16; for i in 0..8 { out[b+i*2] = self.0[b+i]; out[b+i*2+1] = other.0[b+i]; } } + for lane in 0..4 { + let b = lane * 16; + for i in 0..8 { + out[b + i * 2] = self.0[b + i]; + out[b + i * 2 + 1] = other.0[b + i]; + } + } Self(out) } #[inline(always)] pub fn unpack_hi_epi8(self, other: Self) -> Self { let mut out = [0u8; 64]; - for lane in 0..4 { let b = lane * 16; for i in 0..8 { out[b+i*2] = self.0[b+8+i]; out[b+i*2+1] = other.0[b+8+i]; } } + for lane in 0..4 { + let b = lane * 16; + for i in 0..8 { + out[b + i * 2] = self.0[b + 8 + i]; + out[b + i * 2 + 1] = other.0[b + 8 + i]; + } + } Self(out) } /// Byte-wise shuffle: use `self` as a LUT, `idx` selects bytes within each 128-bit (16-byte) lane. @@ -1044,9 +1486,11 @@ pub(crate) mod scalar { /// Build a nibble-popcount lookup table (replicated across 4 x 16-byte lanes). #[inline(always)] pub fn nibble_popcount_lut() -> Self { - let lane: [u8; 16] = [0,1,1,2,1,2,2,3,1,2,2,3,2,3,3,4]; + let lane: [u8; 16] = [0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4]; let mut arr = [0u8; 64]; - for l in 0..4 { arr[l*16..(l+1)*16].copy_from_slice(&lane); } + for l in 0..4 { + arr[l * 16..(l + 1) * 16].copy_from_slice(&lane); + } Self(arr) } } @@ -1057,25 +1501,40 @@ pub(crate) mod scalar { #[inline(always)] fn mul(self, rhs: Self) -> Self { let mut out = [0u32; 16]; - for i in 0..16 { out[i] = self.0[i].wrapping_mul(rhs.0[i]); } + for i in 0..16 { + out[i] = self.0[i].wrapping_mul(rhs.0[i]); + } Self(out) } } // Lowercase aliases - #[allow(non_camel_case_types)] pub type f32x16 = F32x16; - #[allow(non_camel_case_types)] pub type f64x8 = F64x8; - #[allow(non_camel_case_types)] pub type u8x64 = U8x64; - #[allow(non_camel_case_types)] pub type i32x16 = I32x16; - #[allow(non_camel_case_types)] pub type i64x8 = I64x8; - #[allow(non_camel_case_types)] pub type u32x16 = U32x16; - #[allow(non_camel_case_types)] pub type u64x8 = U64x8; - #[allow(non_camel_case_types)] pub type f32x8 = F32x8; - #[allow(non_camel_case_types)] pub type f64x4 = F64x4; - #[allow(non_camel_case_types)] pub type i8x64 = I8x64; - #[allow(non_camel_case_types)] pub type i8x32 = I8x32; - #[allow(non_camel_case_types)] pub type i16x32 = I16x32; - #[allow(non_camel_case_types)] pub type i16x16 = I16x16; + #[allow(non_camel_case_types)] + pub type f32x16 = F32x16; + #[allow(non_camel_case_types)] + pub type f64x8 = F64x8; + #[allow(non_camel_case_types)] + pub type u8x64 = U8x64; + #[allow(non_camel_case_types)] + pub type i32x16 = I32x16; + #[allow(non_camel_case_types)] + pub type i64x8 = I64x8; + #[allow(non_camel_case_types)] + pub type u32x16 = U32x16; + #[allow(non_camel_case_types)] + pub type u64x8 = U64x8; + #[allow(non_camel_case_types)] + pub type f32x8 = F32x8; + #[allow(non_camel_case_types)] + pub type f64x4 = F64x4; + #[allow(non_camel_case_types)] + pub type i8x64 = I8x64; + #[allow(non_camel_case_types)] + pub type i8x32 = I8x32; + #[allow(non_camel_case_types)] + pub type i16x32 = I16x32; + #[allow(non_camel_case_types)] + pub type i16x16 = I16x16; } // aarch64: F32x16/F64x8 come from the real NEON 
paired-load implementation @@ -1083,42 +1542,43 @@ pub(crate) mod scalar { // Integer + 256-bit float types still come from the scalar fallback; they're // not on the critical path for f32 BLAS-1 / VML kernels. #[cfg(target_arch = "aarch64")] -pub use crate::simd_neon::aarch64_simd::{ - F32x16, F64x8, F32Mask16, F64Mask8, - f32x16, f64x8, -}; +pub use crate::simd_neon::aarch64_simd::{f32x16, f64x8, F32Mask16, F32x16, F64Mask8, F64x8}; #[cfg(target_arch = "aarch64")] pub use scalar::{ - U8x64, I32x16, I64x8, U16x32, U32x16, U64x8, - F32x8, F64x4, - u8x64, i32x16, i64x8, u32x16, u64x8, - f32x8, f64x4, + f32x8, f64x4, i32x16, i64x8, u32x16, u64x8, u8x64, F32x8, F64x4, I32x16, I64x8, U16x32, U32x16, U64x8, U8x64, }; // Other non-x86 targets (wasm, riscv, etc.): full scalar fallback. #[cfg(all(not(target_arch = "x86_64"), not(target_arch = "aarch64")))] pub use scalar::{ - F32x16, F64x8, U8x64, I32x16, I64x8, U16x32, U32x16, U64x8, - F32x8, F64x4, - I8x64, I8x32, I16x32, I16x16, - F32Mask16, F64Mask8, - f32x16, f64x8, u8x64, i32x16, i64x8, u32x16, u64x8, - f32x8, f64x4, - i8x64, i8x32, i16x32, i16x16, + f32x16, f32x8, f64x4, f64x8, i16x16, i16x32, i32x16, i64x8, i8x32, i8x64, u32x16, u64x8, u8x64, F32Mask16, F32x16, + F32x8, F64Mask8, F64x4, F64x8, I16x16, I16x32, I32x16, I64x8, I8x32, I8x64, U16x32, U32x16, U64x8, U8x64, }; // Scalar BF16 conversion — always available on all platforms #[cfg(not(target_arch = "x86_64"))] -pub fn bf16_to_f32_scalar(bits: u16) -> f32 { f32::from_bits((bits as u32) << 16) } +pub fn bf16_to_f32_scalar(bits: u16) -> f32 { + f32::from_bits((bits as u32) << 16) +} #[cfg(not(target_arch = "x86_64"))] -pub fn f32_to_bf16_scalar(v: f32) -> u16 { (v.to_bits() >> 16) as u16 } +pub fn f32_to_bf16_scalar(v: f32) -> u16 { + (v.to_bits() >> 16) as u16 +} #[cfg(not(target_arch = "x86_64"))] pub fn bf16_to_f32_batch(input: &[u16], output: &mut [f32]) { - for (i, &b) in input.iter().enumerate() { if i < output.len() { output[i] = bf16_to_f32_scalar(b); } } + for (i, &b) in input.iter().enumerate() { + if i < output.len() { + output[i] = bf16_to_f32_scalar(b); + } + } } #[cfg(not(target_arch = "x86_64"))] pub fn f32_to_bf16_batch(input: &[f32], output: &mut [u16]) { - for (i, &v) in input.iter().enumerate() { if i < output.len() { output[i] = f32_to_bf16_scalar(v); } } + for (i, &v) in input.iter().enumerate() { + if i < output.len() { + output[i] = f32_to_bf16_scalar(v); + } + } } // ============================================================================ @@ -1188,21 +1648,15 @@ pub fn simd_ln_f32(x: F32x16) -> F32x16 { // Without `hpc-extras`, consumers still get the SIMD polyfill types above // (F32x16, I8x32, etc.) but NOT the domain-specific functions below. 
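// [Editorial sketch — not part of the patch.] The scalar BF16 helpers above are plain
// bit truncation: BF16 keeps the f32 sign, the full 8-bit exponent, and the top 7
// mantissa bits, so f32 -> bf16 -> f32 simply zeroes the low 16 mantissa bits
// (truncation, not round-to-nearest-even). A standalone demonstration of the same
// formulas:
fn f32_to_bf16(v: f32) -> u16 { (v.to_bits() >> 16) as u16 }
fn bf16_to_f32(bits: u16) -> f32 { f32::from_bits((bits as u32) << 16) }

fn main() {
    // Values with short mantissas survive the round trip exactly.
    assert_eq!(bf16_to_f32(f32_to_bf16(1.0)), 1.0);
    assert_eq!(bf16_to_f32(f32_to_bf16(-0.5)), -0.5);
    // Values with long mantissas lose the low 16 bits: pi becomes 3.140625.
    let pi_bf16 = bf16_to_f32(f32_to_bf16(core::f32::consts::PI));
    assert_eq!(pi_bf16, 3.140625);
}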
-pub use crate::hpc::fingerprint::{ - Fingerprint, - Fingerprint2K, Fingerprint1K, Fingerprint64K, - VectorWidth, VectorConfig, vector_config, -}; +pub use crate::hpc::bitwise::{hamming_distance_raw, popcount_raw}; pub use crate::hpc::bnn_cross_plane::CollapseGate; -pub use crate::hpc::bitwise::{ - hamming_distance_raw, popcount_raw, -}; pub use crate::hpc::fft::{wht_f32, wht_f32_new}; +pub use crate::hpc::fingerprint::{ + vector_config, Fingerprint, Fingerprint1K, Fingerprint2K, Fingerprint64K, VectorConfig, VectorWidth, +}; pub use crate::hpc::quantized::{ - quantize_f32_to_i4, dequantize_i4_to_f32, - quantize_f32_to_i2, dequantize_i2_to_f32, - quantize_f32_to_i8, dequantize_i8_to_f32, - QuantParams, + dequantize_i2_to_f32, dequantize_i4_to_f32, dequantize_i8_to_f32, quantize_f32_to_i2, quantize_f32_to_i4, + quantize_f32_to_i8, QuantParams, }; // Half-precision SIMD vectors (BF16x16, F16x16) — portable scalar impl, always @@ -1216,17 +1670,14 @@ pub use crate::hpc::quantized::{ // Always re-export F16x16 + all slice-level ops (no naming conflict). #[cfg(feature = "std")] pub use crate::simd_half::{ - F16x16, - add_bf16_inplace, mul_bf16_inplace, - add_f16_inplace, mul_f16_inplace, - cast_bf16_to_f32_batch, cast_f16_to_f32_batch, - cast_f32_to_bf16_batch, cast_f32_to_f16_batch, + add_bf16_inplace, add_f16_inplace, cast_bf16_to_f32_batch, cast_f16_to_f32_batch, cast_f32_to_bf16_batch, + cast_f32_to_f16_batch, mul_bf16_inplace, mul_f16_inplace, F16x16, }; // Re-export portable BF16x16 only when the hardware-native avx512bf16 variant // is NOT active (otherwise `simd_avx512::BF16x16` already occupies the name). #[cfg(all(feature = "std", not(all(target_arch = "x86_64", target_feature = "avx512bf16"))))] -pub use crate::simd_half::BF16x16 as BF16x16; +pub use crate::simd_half::BF16x16; // K-means + L2 distance @@ -1239,10 +1690,8 @@ pub use crate::hpc::heel_f64x8::cosine_f32_to_f64_simd; // Elementwise slice ops — polyfill-dispatched (F32x16/F64x8 chunks + scalar tail). 
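// [Editorial sketch — not part of the patch.] The elementwise slice ops re-exported
// below share one dispatch shape: walk the slices in 16-lane chunks (each chunk is a
// single F32x16 load/op/store on whatever backend is active) and finish the ragged end
// with a scalar loop. The real simd_ops implementation is not shown in this diff; this
// standalone version only illustrates the chunk + scalar-tail structure.
fn add_f32_sketch(a: &[f32], b: &[f32], out: &mut [f32]) {
    assert!(a.len() == b.len() && a.len() == out.len());
    const LANES: usize = 16;
    let chunks = a.len() / LANES;
    for c in 0..chunks {
        let base = c * LANES;
        // In the crate this whole inner loop would be one F32x16 add.
        for i in 0..LANES {
            out[base + i] = a[base + i] + b[base + i];
        }
    }
    // Scalar tail for the last len % 16 elements.
    for i in (chunks * LANES)..a.len() {
        out[i] = a[i] + b[i];
    }
}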
#[cfg(feature = "std")] pub use crate::simd_ops::{ - add_f32, sub_f32, mul_f32, div_f32, - add_f32_inplace, sub_f32_inplace, mul_f32_inplace, div_f32_inplace, - scale_f32, add_scalar_f32, scale_f32_inplace, - add_f64, mul_f64, add_f64_inplace, + add_f32, add_f32_inplace, add_f64, add_f64_inplace, add_scalar_f32, div_f32, div_f32_inplace, mul_f32, + mul_f32_inplace, mul_f64, scale_f32, scale_f32_inplace, sub_f32, sub_f32_inplace, }; // ============================================================================ @@ -1261,8 +1710,7 @@ mod tests { #[test] fn f32x16_from_array_roundtrip() { - let data: [f32; 16] = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, - 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0]; + let data: [f32; 16] = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0]; let v = F32x16::from_array(data); assert_eq!(v.to_array(), data); } @@ -1288,10 +1736,8 @@ mod tests { #[test] fn f32x16_mask_select() { - let a = F32x16::from_array([ - 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, - 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, - ]); + let a = + F32x16::from_array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]); let threshold = F32x16::splat(8.5); let mask = a.simd_lt(threshold); let result = mask.select(F32x16::splat(1.0), F32x16::splat(0.0)); diff --git a/src/simd_amx.rs b/src/simd_amx.rs index 9bc688af..2e41857d 100644 --- a/src/simd_amx.rs +++ b/src/simd_amx.rs @@ -50,12 +50,16 @@ pub fn amx_available() -> bool { let cpuid = core::arch::x86_64::__cpuid_count(7, 0); let amx_tile = (cpuid.edx >> 24) & 1; let amx_int8 = (cpuid.edx >> 25) & 1; - if amx_tile == 0 || amx_int8 == 0 { return false; } + if amx_tile == 0 || amx_int8 == 0 { + return false; + } // Step 2: OS enabled XSAVE? (CPUID.01H:ECX bit 27 = OSXSAVE) let cpuid_01 = core::arch::x86_64::__cpuid(1); let osxsave = (cpuid_01.ecx >> 27) & 1; - if osxsave == 0 { return false; } + if osxsave == 0 { + return false; + } // Step 3: OS actually enabled tile state in XCR0? // _xgetbv(0) reads the ACTUAL XCR0 register (what the OS set), @@ -64,7 +68,9 @@ pub fn amx_available() -> bool { let xcr0: u64 = unsafe { core::arch::x86_64::_xgetbv(0) }; let tilecfg = (xcr0 >> 17) & 1; let tiledata = (xcr0 >> 18) & 1; - if tilecfg == 0 || tiledata == 0 { return false; } + if tilecfg == 0 || tiledata == 0 { + return false; + } // Step 4: Request XCOMP_PERM for TILEDATA. // Linux kernel 5.19+: processes must call prctl(ARCH_REQ_XCOMP_PERM, 18) @@ -95,14 +101,18 @@ pub fn amx_available() -> bool { options(nostack), ); } - if ret != 0 { return false; } + if ret != 0 { + return false; + } } true } #[cfg(not(target_arch = "x86_64"))] -pub fn amx_available() -> bool { false } +pub fn amx_available() -> bool { + false +} /// AMX capability report. 
pub fn amx_report() -> String { @@ -115,7 +125,9 @@ pub fn amx_report() -> String { format!("AMX: TILE={} INT8={} BF16={} available={}", tile, int8, bf16, amx_available()) } #[cfg(not(target_arch = "x86_64"))] - { "AMX: not x86_64".to_string() } + { + "AMX: not x86_64".to_string() + } } // ═══════════════════════════════════════════════════════════════════════════ @@ -132,8 +144,8 @@ pub fn amx_report() -> String { #[target_feature(enable = "avx512vnni")] pub unsafe fn vnni_dpbusd( acc: core::arch::x86_64::__m512i, - a: core::arch::x86_64::__m512i, // 64 × u8 - b: core::arch::x86_64::__m512i, // 64 × i8 (energy, quantized) + a: core::arch::x86_64::__m512i, // 64 × u8 + b: core::arch::x86_64::__m512i, // 64 × i8 (energy, quantized) ) -> core::arch::x86_64::__m512i { core::arch::x86_64::_mm512_dpbusd_epi32(acc, a, b) } @@ -173,14 +185,12 @@ pub unsafe fn vnni_dot_u8_i8(row: &[u8], energy: &[i8]) -> i32 { /// This IS the ThinkingEngine's core loop at VNNI resolution. #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512vnni")] -pub unsafe fn vnni_matvec( - table: &[u8], - energy_i8: &[i8], - result: &mut [i32], - n: usize, -) { +pub unsafe fn vnni_matvec(table: &[u8], energy_i8: &[i8], result: &mut [i32], n: usize) { for i in 0..n { - if energy_i8.iter().all(|&e| e == 0) { result[i] = 0; continue; } + if energy_i8.iter().all(|&e| e == 0) { + result[i] = 0; + continue; + } let row = &table[i * n..(i + 1) * n]; result[i] = vnni_dot_u8_i8(row, energy_i8); } @@ -223,12 +233,7 @@ pub unsafe fn vnni2_dot_u8_i8(row: &[u8], energy: &[i8]) -> i32 { /// VNNI2 MatVec for the entire distance table × energy vector (ymm path). #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avxvnniint8")] -pub unsafe fn vnni2_matvec( - table: &[u8], - energy_i8: &[i8], - result: &mut [i32], - n: usize, -) { +pub unsafe fn vnni2_matvec(table: &[u8], energy_i8: &[i8], result: &mut [i32], n: usize) { for i in 0..n { let row = &table[i * n..(i + 1) * n]; result[i] = vnni2_dot_u8_i8(row, energy_i8); @@ -246,12 +251,7 @@ pub fn vnni_dot_u8_i8_scalar(row: &[u8], energy: &[i8]) -> i32 { } /// Scalar MatVec fallback. 
-pub fn vnni_matvec_scalar( - table: &[u8], - energy_i8: &[i8], - result: &mut [i32], - n: usize, -) { +pub fn vnni_matvec_scalar(table: &[u8], energy_i8: &[i8], result: &mut [i32], n: usize) { for i in 0..n { let row = &table[i * n..(i + 1) * n]; result[i] = vnni_dot_u8_i8_scalar(row, energy_i8); @@ -279,20 +279,19 @@ pub fn vnni_matvec_scalar( /// The thinking engine's cycle_auto() dispatches: /// VNNI detected → cycle_vnni() → this function /// No VNNI → cycle() → F32x16 FMA (never reaches here) -pub fn matvec_dispatch( - table: &[u8], - energy_i8: &[i8], - result: &mut [i32], - n: usize, -) { +pub fn matvec_dispatch(table: &[u8], energy_i8: &[i8], result: &mut [i32], n: usize) { #[cfg(target_arch = "x86_64")] { if is_x86_feature_detected!("avx512vnni") { - unsafe { vnni_matvec(table, energy_i8, result, n); } + unsafe { + vnni_matvec(table, energy_i8, result, n); + } return; } if is_x86_feature_detected!("avxvnniint8") { - unsafe { vnni2_matvec(table, energy_i8, result, n); } + unsafe { + vnni2_matvec(table, energy_i8, result, n); + } return; } } @@ -311,7 +310,9 @@ pub fn quantize_energy_i8(energy: &[f64], output: &mut [i8]) { let n = energy.len().min(output.len()); let max_e = energy.iter().cloned().fold(0.0f64, f64::max); if max_e < 1e-15 { - for o in output[..n].iter_mut() { *o = 0; } + for o in output[..n].iter_mut() { + *o = 0; + } return; } let scale = 127.0 / max_e; @@ -343,7 +344,7 @@ mod tests { #[test] fn test_vnni_dot_scalar() { - let row = vec![128u8; 64]; // similarity = 0.5 + let row = vec![128u8; 64]; // similarity = 0.5 let energy = vec![10i8; 64]; // energy = 10 let dot = vnni_dot_u8_i8_scalar(&row, &energy); assert_eq!(dot, 128 * 10 * 64); @@ -354,7 +355,9 @@ mod tests { fn test_vnni_matvec_scalar() { let n = 64; let mut table = vec![128u8; n * n]; - for i in 0..n { table[i * n + i] = 255; } // diagonal = max + for i in 0..n { + table[i * n + i] = 255; + } // diagonal = max let energy = vec![10i8; n]; let mut result = vec![0i32; n]; @@ -369,7 +372,9 @@ mod tests { fn test_vnni_dispatch() { let n = 64; let mut table = vec![128u8; n * n]; - for i in 0..n { table[i * n + i] = 255; } + for i in 0..n { + table[i * n + i] = 255; + } let energy = vec![10i8; n]; let mut result = vec![0i32; n]; @@ -396,7 +401,7 @@ mod tests { #[test] fn test_vnni_matches_scalar() { let n = 128; - let table: Vec = (0..n*n).map(|i| (i % 256) as u8).collect(); + let table: Vec = (0..n * n).map(|i| (i % 256) as u8).collect(); let energy: Vec = (0..n).map(|i| (i % 100) as i8).collect(); let mut scalar_result = vec![0i32; n]; @@ -406,8 +411,11 @@ mod tests { matvec_dispatch(&table, &energy, &mut dispatch_result, n); for i in 0..n { - assert_eq!(scalar_result[i], dispatch_result[i], - "mismatch at row {}: scalar={} dispatch={}", i, scalar_result[i], dispatch_result[i]); + assert_eq!( + scalar_result[i], dispatch_result[i], + "mismatch at row {}: scalar={} dispatch={}", + i, scalar_result[i], dispatch_result[i] + ); } } } diff --git a/src/simd_avx2.rs b/src/simd_avx2.rs index 116563d3..0f7799fe 100644 --- a/src/simd_avx2.rs +++ b/src/simd_avx2.rs @@ -10,7 +10,7 @@ use crate::simd_avx512::{f32x8, f64x4}; // AVX2-native I8x32 / I16x16 live in simd_avx512.rs (256-bit __m256i types). // Re-export so consumers see a unified `crate::simd_avx2::I8x32` symbol. 
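// [Editorial sketch — not part of the patch.] Putting the simd_amx pieces above
// together: quantize_energy_i8 maps an f64 energy vector onto i8 with scale
// 127 / max(energy), and matvec_dispatch then picks AVX-512 VNNI, AVX-VNNI-INT8, or
// the scalar fallback at runtime. The signatures are taken from this patch; the
// module path (crate::simd_amx) is assumed.
fn thinking_cycle_sketch(table: &[u8], energy: &[f64], n: usize) -> Vec<i32> {
    use crate::simd_amx::{matvec_dispatch, quantize_energy_i8};

    // Quantize the f64 energies into the i8 operand VNNI expects.
    let mut energy_i8 = vec![0i8; n];
    quantize_energy_i8(energy, &mut energy_i8);

    // result[i] = dot(table row i as u8, energy_i8 as i8), accumulated in i32.
    let mut result = vec![0i32; n];
    matvec_dispatch(table, &energy_i8, &mut result, n);
    result
}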
-pub use crate::simd_avx512::{I8x32, I16x16, i8x32, i16x16}; +pub use crate::simd_avx512::{i16x16, i8x32, I16x16, I8x32}; // ============================================================================ // AVX2 lane counts (half of AVX-512) @@ -67,12 +67,9 @@ pub fn dot_f32(a: &[f32], b: &[f32]) -> f32 { for i in 0..full_iters { let base = i * 4 * F32_LANES; acc0 += f32x8::from_slice(&a[base..]) * f32x8::from_slice(&b[base..]); - acc1 += - f32x8::from_slice(&a[base + F32_LANES..]) * f32x8::from_slice(&b[base + F32_LANES..]); - acc2 += f32x8::from_slice(&a[base + 2 * F32_LANES..]) - * f32x8::from_slice(&b[base + 2 * F32_LANES..]); - acc3 += f32x8::from_slice(&a[base + 3 * F32_LANES..]) - * f32x8::from_slice(&b[base + 3 * F32_LANES..]); + acc1 += f32x8::from_slice(&a[base + F32_LANES..]) * f32x8::from_slice(&b[base + F32_LANES..]); + acc2 += f32x8::from_slice(&a[base + 2 * F32_LANES..]) * f32x8::from_slice(&b[base + 2 * F32_LANES..]); + acc3 += f32x8::from_slice(&a[base + 3 * F32_LANES..]) * f32x8::from_slice(&b[base + 3 * F32_LANES..]); } for i in (full_iters * 4)..chunks { @@ -102,12 +99,9 @@ pub fn dot_f64(a: &[f64], b: &[f64]) -> f64 { for i in 0..full_iters { let base = i * 4 * F64_LANES; acc0 += f64x4::from_slice(&a[base..]) * f64x4::from_slice(&b[base..]); - acc1 += - f64x4::from_slice(&a[base + F64_LANES..]) * f64x4::from_slice(&b[base + F64_LANES..]); - acc2 += f64x4::from_slice(&a[base + 2 * F64_LANES..]) - * f64x4::from_slice(&b[base + 2 * F64_LANES..]); - acc3 += f64x4::from_slice(&a[base + 3 * F64_LANES..]) - * f64x4::from_slice(&b[base + 3 * F64_LANES..]); + acc1 += f64x4::from_slice(&a[base + F64_LANES..]) * f64x4::from_slice(&b[base + F64_LANES..]); + acc2 += f64x4::from_slice(&a[base + 2 * F64_LANES..]) * f64x4::from_slice(&b[base + 2 * F64_LANES..]); + acc3 += f64x4::from_slice(&a[base + 3 * F64_LANES..]) * f64x4::from_slice(&b[base + 3 * F64_LANES..]); } for i in (full_iters * 4)..chunks { @@ -328,20 +322,10 @@ pub fn hamming_batch(query: &[u8], database: &[u8], num_rows: usize, row_bytes: let full = num_rows / 4; for i in 0..full { let base = i * 4; - distances[base] = - hamming_distance(query, &database[base * row_bytes..(base + 1) * row_bytes]); - distances[base + 1] = hamming_distance( - query, - &database[(base + 1) * row_bytes..(base + 2) * row_bytes], - ); - distances[base + 2] = hamming_distance( - query, - &database[(base + 2) * row_bytes..(base + 3) * row_bytes], - ); - distances[base + 3] = hamming_distance( - query, - &database[(base + 3) * row_bytes..(base + 4) * row_bytes], - ); + distances[base] = hamming_distance(query, &database[base * row_bytes..(base + 1) * row_bytes]); + distances[base + 1] = hamming_distance(query, &database[(base + 1) * row_bytes..(base + 2) * row_bytes]); + distances[base + 2] = hamming_distance(query, &database[(base + 2) * row_bytes..(base + 3) * row_bytes]); + distances[base + 3] = hamming_distance(query, &database[(base + 3) * row_bytes..(base + 4) * row_bytes]); } for i in (full * 4)..num_rows { distances[i] = hamming_distance(query, &database[i * row_bytes..(i + 1) * row_bytes]); @@ -352,11 +336,7 @@ pub fn hamming_batch(query: &[u8], database: &[u8], num_rows: usize, row_bytes: /// Top-k nearest neighbors by Hamming distance. 
pub fn hamming_top_k( - query: &[u8], - database: &[u8], - num_rows: usize, - row_bytes: usize, - k: usize, + query: &[u8], database: &[u8], num_rows: usize, row_bytes: usize, k: usize, ) -> (Vec, Vec) { let distances = hamming_batch(query, database, num_rows, row_bytes); let k = k.min(num_rows); @@ -380,8 +360,7 @@ pub fn popcount(a: &[u8]) -> u64 { let chunks = len / 32; let low_mask = _mm256_set1_epi8(0x0f); let lookup = _mm256_setr_epi8( - 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, - 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, ); let mut total = _mm256_setzero_si256(); let blocks = chunks / 8; @@ -392,10 +371,7 @@ pub fn popcount(a: &[u8]) -> u64 { let v = _mm256_loadu_si256(a[idx..].as_ptr() as *const __m256i); let lo = _mm256_and_si256(v, low_mask); let hi = _mm256_and_si256(_mm256_srli_epi16(v, 4), low_mask); - let cnt = _mm256_add_epi8( - _mm256_shuffle_epi8(lookup, lo), - _mm256_shuffle_epi8(lookup, hi), - ); + let cnt = _mm256_add_epi8(_mm256_shuffle_epi8(lookup, lo), _mm256_shuffle_epi8(lookup, hi)); local = _mm256_add_epi8(local, cnt); } total = _mm256_add_epi64(total, _mm256_sad_epu8(local, _mm256_setzero_si256())); @@ -407,10 +383,7 @@ pub fn popcount(a: &[u8]) -> u64 { let v = _mm256_loadu_si256(a[idx..].as_ptr() as *const __m256i); let lo = _mm256_and_si256(v, low_mask); let hi = _mm256_and_si256(_mm256_srli_epi16(v, 4), low_mask); - let cnt = _mm256_add_epi8( - _mm256_shuffle_epi8(lookup, lo), - _mm256_shuffle_epi8(lookup, hi), - ); + let cnt = _mm256_add_epi8(_mm256_shuffle_epi8(lookup, lo), _mm256_shuffle_epi8(lookup, hi)); local = _mm256_add_epi8(local, cnt); } total = _mm256_add_epi64(total, _mm256_sad_epu8(local, _mm256_setzero_si256())); @@ -469,7 +442,10 @@ pub fn dot_i8(a: &[u8], b: &[u8]) -> i64 { } #[cfg(not(target_arch = "x86_64"))] { - a.iter().zip(b.iter()).map(|(&x, &y)| (x as i8 as i64) * (y as i8 as i64)).sum() + a.iter() + .zip(b.iter()) + .map(|(&x, &y)| (x as i8 as i64) * (y as i8 as i64)) + .sum() } } @@ -484,10 +460,7 @@ pub fn dot_i8(a: &[u8], b: &[u8]) -> i64 { /// sufficient as the AVX2 fallback tier. #[allow(clippy::too_many_arguments)] pub fn sgemm_blocked( - m: usize, n: usize, k: usize, - alpha: f32, a: &[f32], lda: usize, - b: &[f32], ldb: usize, - c: &mut [f32], ldc: usize, + m: usize, n: usize, k: usize, alpha: f32, a: &[f32], lda: usize, b: &[f32], ldb: usize, c: &mut [f32], ldc: usize, ) { // Scalar fallback: row-by-row dot products for i in 0..m { @@ -504,10 +477,7 @@ pub fn sgemm_blocked( /// AVX2 blocked DGEMM fallback — delegates to scalar implementation. #[allow(clippy::too_many_arguments)] pub fn dgemm_blocked( - m: usize, n: usize, k: usize, - alpha: f64, a: &[f64], lda: usize, - b: &[f64], ldb: usize, - c: &mut [f64], ldc: usize, + m: usize, n: usize, k: usize, alpha: f64, a: &[f64], lda: usize, b: &[f64], ldb: usize, c: &mut [f64], ldc: usize, ) { // Scalar fallback: row-by-row dot products for i in 0..m { @@ -529,8 +499,10 @@ pub fn dgemm_blocked( // ============================================================================ use core::fmt; -use core::ops::{Add, AddAssign, Sub, SubAssign, Mul, MulAssign, Div, DivAssign, Neg, - BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not}; +use core::ops::{ + Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign, Mul, MulAssign, + Neg, Not, Sub, SubAssign, +}; /// 16×f32 via 2× AVX2 F32x8 (__m256). 
Same API as simd_avx512::F32x16. #[derive(Copy, Clone)] @@ -539,77 +511,167 @@ pub struct F32x16(pub f32x8, pub f32x8); impl F32x16 { pub const LANES: usize = 16; - #[inline(always)] pub fn splat(v: f32) -> Self { Self(f32x8::splat(v), f32x8::splat(v)) } - #[inline(always)] pub fn from_slice(s: &[f32]) -> Self { + #[inline(always)] + pub fn splat(v: f32) -> Self { + Self(f32x8::splat(v), f32x8::splat(v)) + } + #[inline(always)] + pub fn from_slice(s: &[f32]) -> Self { assert!(s.len() >= 16); Self(f32x8::from_slice(&s[..8]), f32x8::from_slice(&s[8..16])) } - #[inline(always)] pub fn from_array(a: [f32; 16]) -> Self { + #[inline(always)] + pub fn from_array(a: [f32; 16]) -> Self { Self(f32x8::from_array(a[..8].try_into().unwrap()), f32x8::from_array(a[8..].try_into().unwrap())) } - #[inline(always)] pub fn to_array(self) -> [f32; 16] { + #[inline(always)] + pub fn to_array(self) -> [f32; 16] { let mut out = [0.0f32; 16]; out[..8].copy_from_slice(&self.0.to_array()); out[8..].copy_from_slice(&self.1.to_array()); out } - #[inline(always)] pub fn copy_to_slice(self, s: &mut [f32]) { + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [f32]) { assert!(s.len() >= 16); self.0.copy_to_slice(&mut s[..8]); self.1.copy_to_slice(&mut s[8..16]); } - #[inline(always)] pub fn reduce_sum(self) -> f32 { self.0.reduce_sum() + self.1.reduce_sum() } - #[inline(always)] pub fn reduce_min(self) -> f32 { + #[inline(always)] + pub fn reduce_sum(self) -> f32 { + self.0.reduce_sum() + self.1.reduce_sum() + } + #[inline(always)] + pub fn reduce_min(self) -> f32 { let a = self.to_array(); a.iter().copied().fold(f32::INFINITY, f32::min) } - #[inline(always)] pub fn reduce_max(self) -> f32 { + #[inline(always)] + pub fn reduce_max(self) -> f32 { let a = self.to_array(); a.iter().copied().fold(f32::NEG_INFINITY, f32::max) } - #[inline(always)] pub fn abs(self) -> Self { Self(self.0.abs(), self.1.abs()) } - #[inline(always)] pub fn sqrt(self) -> Self { + #[inline(always)] + pub fn abs(self) -> Self { + Self(self.0.abs(), self.1.abs()) + } + #[inline(always)] + pub fn sqrt(self) -> Self { + let a = self.to_array(); + let mut o = [0.0f32; 16]; + for i in 0..16 { + o[i] = a[i].sqrt(); + } + Self::from_array(o) + } + #[inline(always)] + pub fn round(self) -> Self { let a = self.to_array(); - let mut o = [0.0f32; 16]; for i in 0..16 { o[i] = a[i].sqrt(); } Self::from_array(o) + let mut o = [0.0f32; 16]; + for i in 0..16 { + o[i] = a[i].round(); + } + Self::from_array(o) } - #[inline(always)] pub fn round(self) -> Self { + #[inline(always)] + pub fn floor(self) -> Self { let a = self.to_array(); - let mut o = [0.0f32; 16]; for i in 0..16 { o[i] = a[i].round(); } Self::from_array(o) + let mut o = [0.0f32; 16]; + for i in 0..16 { + o[i] = a[i].floor(); + } + Self::from_array(o) } - #[inline(always)] pub fn floor(self) -> Self { + #[inline(always)] + pub fn mul_add(self, b: Self, c: Self) -> Self { let a = self.to_array(); - let mut o = [0.0f32; 16]; for i in 0..16 { o[i] = a[i].floor(); } Self::from_array(o) + let ba = b.to_array(); + let ca = c.to_array(); + let mut o = [0.0f32; 16]; + for i in 0..16 { + o[i] = a[i].mul_add(ba[i], ca[i]); + } + Self::from_array(o) } - #[inline(always)] pub fn mul_add(self, b: Self, c: Self) -> Self { - let a = self.to_array(); let ba = b.to_array(); let ca = c.to_array(); - let mut o = [0.0f32; 16]; for i in 0..16 { o[i] = a[i].mul_add(ba[i], ca[i]); } Self::from_array(o) + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + let a = self.to_array(); + let b = 
other.to_array(); + let mut o = [0.0f32; 16]; + for i in 0..16 { + o[i] = a[i].min(b[i]); + } + Self::from_array(o) } - #[inline(always)] pub fn simd_min(self, other: Self) -> Self { - let a = self.to_array(); let b = other.to_array(); - let mut o = [0.0f32; 16]; for i in 0..16 { o[i] = a[i].min(b[i]); } Self::from_array(o) + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + let a = self.to_array(); + let b = other.to_array(); + let mut o = [0.0f32; 16]; + for i in 0..16 { + o[i] = a[i].max(b[i]); + } + Self::from_array(o) } - #[inline(always)] pub fn simd_max(self, other: Self) -> Self { - let a = self.to_array(); let b = other.to_array(); - let mut o = [0.0f32; 16]; for i in 0..16 { o[i] = a[i].max(b[i]); } Self::from_array(o) + #[inline(always)] + pub fn simd_clamp(self, lo: Self, hi: Self) -> Self { + self.simd_max(lo).simd_min(hi) } - #[inline(always)] pub fn simd_clamp(self, lo: Self, hi: Self) -> Self { self.simd_max(lo).simd_min(hi) } - #[inline(always)] pub fn simd_lt(self, other: Self) -> F32Mask16 { - let a = self.to_array(); let b = other.to_array(); - let mut bits: u16 = 0; for i in 0..16 { if a[i] < b[i] { bits |= 1 << i; } } F32Mask16(bits) + #[inline(always)] + pub fn simd_lt(self, other: Self) -> F32Mask16 { + let a = self.to_array(); + let b = other.to_array(); + let mut bits: u16 = 0; + for i in 0..16 { + if a[i] < b[i] { + bits |= 1 << i; + } + } + F32Mask16(bits) } - #[inline(always)] pub fn simd_le(self, other: Self) -> F32Mask16 { - let a = self.to_array(); let b = other.to_array(); - let mut bits: u16 = 0; for i in 0..16 { if a[i] <= b[i] { bits |= 1 << i; } } F32Mask16(bits) + #[inline(always)] + pub fn simd_le(self, other: Self) -> F32Mask16 { + let a = self.to_array(); + let b = other.to_array(); + let mut bits: u16 = 0; + for i in 0..16 { + if a[i] <= b[i] { + bits |= 1 << i; + } + } + F32Mask16(bits) } - #[inline(always)] pub fn simd_gt(self, other: Self) -> F32Mask16 { other.simd_lt(self) } - #[inline(always)] pub fn simd_ge(self, other: Self) -> F32Mask16 { other.simd_le(self) } - #[inline(always)] pub fn simd_eq(self, other: Self) -> F32Mask16 { - let a = self.to_array(); let b = other.to_array(); - let mut bits: u16 = 0; for i in 0..16 { if a[i] == b[i] { bits |= 1 << i; } } F32Mask16(bits) + #[inline(always)] + pub fn simd_gt(self, other: Self) -> F32Mask16 { + other.simd_lt(self) } - #[inline(always)] pub fn simd_ne(self, other: Self) -> F32Mask16 { - let a = self.to_array(); let b = other.to_array(); - let mut bits: u16 = 0; for i in 0..16 { if a[i] != b[i] { bits |= 1 << i; } } F32Mask16(bits) + #[inline(always)] + pub fn simd_ge(self, other: Self) -> F32Mask16 { + other.simd_le(self) + } + #[inline(always)] + pub fn simd_eq(self, other: Self) -> F32Mask16 { + let a = self.to_array(); + let b = other.to_array(); + let mut bits: u16 = 0; + for i in 0..16 { + if a[i] == b[i] { + bits |= 1 << i; + } + } + F32Mask16(bits) + } + #[inline(always)] + pub fn simd_ne(self, other: Self) -> F32Mask16 { + let a = self.to_array(); + let b = other.to_array(); + let mut bits: u16 = 0; + for i in 0..16 { + if a[i] != b[i] { + bits |= 1 << i; + } + } + F32Mask16(bits) } /// Gather 16 f32 values from `base_ptr` using 16 i32 indices. 
/// @@ -619,43 +681,130 @@ impl F32x16 { pub unsafe fn gather(indices: I32x16, base_ptr: *const f32) -> Self { let idx = indices.0; let mut o = [0.0f32; 16]; - for i in 0..16 { o[i] = *base_ptr.add(idx[i] as usize); } + for i in 0..16 { + o[i] = *base_ptr.add(idx[i] as usize); + } Self::from_array(o) } - #[inline(always)] pub fn to_bits(self) -> U32x16 { + #[inline(always)] + pub fn to_bits(self) -> U32x16 { let a = self.to_array(); - let mut o = [0u32; 16]; for i in 0..16 { o[i] = a[i].to_bits(); } U32x16(o) + let mut o = [0u32; 16]; + for i in 0..16 { + o[i] = a[i].to_bits(); + } + U32x16(o) } - #[inline(always)] pub fn from_bits(bits: U32x16) -> Self { - let mut o = [0.0f32; 16]; for i in 0..16 { o[i] = f32::from_bits(bits.0[i]); } Self::from_array(o) + #[inline(always)] + pub fn from_bits(bits: U32x16) -> Self { + let mut o = [0.0f32; 16]; + for i in 0..16 { + o[i] = f32::from_bits(bits.0[i]); + } + Self::from_array(o) } - #[inline(always)] pub fn cast_i32(self) -> I32x16 { + #[inline(always)] + pub fn cast_i32(self) -> I32x16 { let a = self.to_array(); - let mut o = [0i32; 16]; for i in 0..16 { o[i] = a[i] as i32; } I32x16(o) + let mut o = [0i32; 16]; + for i in 0..16 { + o[i] = a[i] as i32; + } + I32x16(o) } } -impl Add for F32x16 { type Output = Self; #[inline(always)] fn add(self, rhs: Self) -> Self { Self(self.0 + rhs.0, self.1 + rhs.1) } } -impl Sub for F32x16 { type Output = Self; #[inline(always)] fn sub(self, rhs: Self) -> Self { Self(self.0 - rhs.0, self.1 - rhs.1) } } -impl Mul for F32x16 { type Output = Self; #[inline(always)] fn mul(self, rhs: Self) -> Self { Self(self.0 * rhs.0, self.1 * rhs.1) } } -impl Div for F32x16 { type Output = Self; #[inline(always)] fn div(self, rhs: Self) -> Self { Self(self.0 / rhs.0, self.1 / rhs.1) } } -impl AddAssign for F32x16 { #[inline(always)] fn add_assign(&mut self, rhs: Self) { *self = *self + rhs; } } -impl SubAssign for F32x16 { #[inline(always)] fn sub_assign(&mut self, rhs: Self) { *self = *self - rhs; } } -impl MulAssign for F32x16 { #[inline(always)] fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; } } -impl DivAssign for F32x16 { #[inline(always)] fn div_assign(&mut self, rhs: Self) { *self = *self / rhs; } } -impl Neg for F32x16 { type Output = Self; #[inline(always)] fn neg(self) -> Self { let a = self.to_array(); let mut o = [0.0f32; 16]; for i in 0..16 { o[i] = -a[i]; } Self::from_array(o) } } -impl fmt::Debug for F32x16 { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "F32x16({:?})", self.to_array()) } } -impl PartialEq for F32x16 { fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() } } -impl Default for F32x16 { fn default() -> Self { Self::splat(0.0) } } +impl Add for F32x16 { + type Output = Self; + #[inline(always)] + fn add(self, rhs: Self) -> Self { + Self(self.0 + rhs.0, self.1 + rhs.1) + } +} +impl Sub for F32x16 { + type Output = Self; + #[inline(always)] + fn sub(self, rhs: Self) -> Self { + Self(self.0 - rhs.0, self.1 - rhs.1) + } +} +impl Mul for F32x16 { + type Output = Self; + #[inline(always)] + fn mul(self, rhs: Self) -> Self { + Self(self.0 * rhs.0, self.1 * rhs.1) + } +} +impl Div for F32x16 { + type Output = Self; + #[inline(always)] + fn div(self, rhs: Self) -> Self { + Self(self.0 / rhs.0, self.1 / rhs.1) + } +} +impl AddAssign for F32x16 { + #[inline(always)] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} +impl SubAssign for F32x16 { + #[inline(always)] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} 
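// [Editorial sketch — not part of the patch.] The AVX2 F32x16 above is just a pair of
// 8-lane vectors: elementwise ops forward to each half, while reductions and masks
// recombine the halves so the 16-lane API matches simd_avx512::F32x16. A minimal
// standalone version of that pairing pattern, with plain arrays standing in for f32x8:
#[derive(Copy, Clone)]
struct Pair([f32; 8], [f32; 8]);

impl Pair {
    fn add(self, rhs: Self) -> Self {
        let mut lo = self.0;
        let mut hi = self.1;
        for i in 0..8 {
            lo[i] += rhs.0[i]; // lower half, lanes 0..8
            hi[i] += rhs.1[i]; // upper half, lanes 8..16
        }
        Pair(lo, hi)
    }
    fn reduce_sum(self) -> f32 {
        // Reductions combine the two halves, exactly like F32x16::reduce_sum above.
        self.0.iter().sum::<f32>() + self.1.iter().sum::<f32>()
    }
}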
+impl MulAssign for F32x16 { + #[inline(always)] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} +impl DivAssign for F32x16 { + #[inline(always)] + fn div_assign(&mut self, rhs: Self) { + *self = *self / rhs; + } +} +impl Neg for F32x16 { + type Output = Self; + #[inline(always)] + fn neg(self) -> Self { + let a = self.to_array(); + let mut o = [0.0f32; 16]; + for i in 0..16 { + o[i] = -a[i]; + } + Self::from_array(o) + } +} +impl fmt::Debug for F32x16 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "F32x16({:?})", self.to_array()) + } +} +impl PartialEq for F32x16 { + fn eq(&self, other: &Self) -> bool { + self.to_array() == other.to_array() + } +} +impl Default for F32x16 { + fn default() -> Self { + Self::splat(0.0) + } +} #[derive(Copy, Clone, Debug)] pub struct F32Mask16(pub u16); impl F32Mask16 { #[inline(always)] pub fn select(self, true_val: F32x16, false_val: F32x16) -> F32x16 { - let t = true_val.to_array(); let f = false_val.to_array(); + let t = true_val.to_array(); + let f = false_val.to_array(); let mut o = [0.0f32; 16]; - for i in 0..16 { o[i] = if (self.0 >> i) & 1 == 1 { t[i] } else { f[i] }; } + for i in 0..16 { + o[i] = if (self.0 >> i) & 1 == 1 { t[i] } else { f[i] }; + } F32x16::from_array(o) } } @@ -667,76 +816,251 @@ pub struct F64x8(pub f64x4, pub f64x4); impl F64x8 { pub const LANES: usize = 8; - #[inline(always)] pub fn splat(v: f64) -> Self { Self(f64x4::splat(v), f64x4::splat(v)) } - #[inline(always)] pub fn from_slice(s: &[f64]) -> Self { + #[inline(always)] + pub fn splat(v: f64) -> Self { + Self(f64x4::splat(v), f64x4::splat(v)) + } + #[inline(always)] + pub fn from_slice(s: &[f64]) -> Self { assert!(s.len() >= 8); Self(f64x4::from_slice(&s[..4]), f64x4::from_slice(&s[4..8])) } - #[inline(always)] pub fn from_array(a: [f64; 8]) -> Self { + #[inline(always)] + pub fn from_array(a: [f64; 8]) -> Self { Self(f64x4::from_array(a[..4].try_into().unwrap()), f64x4::from_array(a[4..].try_into().unwrap())) } - #[inline(always)] pub fn to_array(self) -> [f64; 8] { + #[inline(always)] + pub fn to_array(self) -> [f64; 8] { let mut out = [0.0f64; 8]; out[..4].copy_from_slice(&self.0.to_array()); out[4..].copy_from_slice(&self.1.to_array()); out } - #[inline(always)] pub fn copy_to_slice(self, s: &mut [f64]) { + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [f64]) { assert!(s.len() >= 8); self.0.copy_to_slice(&mut s[..4]); self.1.copy_to_slice(&mut s[4..8]); } - #[inline(always)] pub fn reduce_sum(self) -> f64 { self.0.reduce_sum() + self.1.reduce_sum() } - #[inline(always)] pub fn reduce_min(self) -> f64 { let a = self.to_array(); a.iter().copied().fold(f64::INFINITY, f64::min) } - #[inline(always)] pub fn reduce_max(self) -> f64 { let a = self.to_array(); a.iter().copied().fold(f64::NEG_INFINITY, f64::max) } - #[inline(always)] pub fn abs(self) -> Self { let a = self.to_array(); let mut o = [0.0f64; 8]; for i in 0..8 { o[i] = a[i].abs(); } Self::from_array(o) } - #[inline(always)] pub fn sqrt(self) -> Self { let a = self.to_array(); let mut o = [0.0f64; 8]; for i in 0..8 { o[i] = a[i].sqrt(); } Self::from_array(o) } - #[inline(always)] pub fn round(self) -> Self { let a = self.to_array(); let mut o = [0.0f64; 8]; for i in 0..8 { o[i] = a[i].round(); } Self::from_array(o) } - #[inline(always)] pub fn floor(self) -> Self { let a = self.to_array(); let mut o = [0.0f64; 8]; for i in 0..8 { o[i] = a[i].floor(); } Self::from_array(o) } - #[inline(always)] pub fn mul_add(self, b: Self, c: Self) -> Self { - let a = 
self.to_array(); let ba = b.to_array(); let ca = c.to_array(); - let mut o = [0.0f64; 8]; for i in 0..8 { o[i] = a[i].mul_add(ba[i], ca[i]); } Self::from_array(o) - } - #[inline(always)] pub fn simd_min(self, other: Self) -> Self { let a = self.to_array(); let b = other.to_array(); let mut o = [0.0f64; 8]; for i in 0..8 { o[i] = a[i].min(b[i]); } Self::from_array(o) } - #[inline(always)] pub fn simd_max(self, other: Self) -> Self { let a = self.to_array(); let b = other.to_array(); let mut o = [0.0f64; 8]; for i in 0..8 { o[i] = a[i].max(b[i]); } Self::from_array(o) } - #[inline(always)] pub fn simd_clamp(self, lo: Self, hi: Self) -> Self { self.simd_max(lo).simd_min(hi) } - #[inline(always)] pub fn simd_ge(self, other: Self) -> F64Mask8 { - let a = self.to_array(); let b = other.to_array(); - let mut bits: u8 = 0; for i in 0..8 { if a[i] >= b[i] { bits |= 1 << i; } } F64Mask8(bits) - } - #[inline(always)] pub fn simd_le(self, other: Self) -> F64Mask8 { - let a = self.to_array(); let b = other.to_array(); - let mut bits: u8 = 0; for i in 0..8 { if a[i] <= b[i] { bits |= 1 << i; } } F64Mask8(bits) - } - #[inline(always)] pub fn to_bits(self) -> U64x8 { - let a = self.to_array(); let mut o = [0u64; 8]; for i in 0..8 { o[i] = a[i].to_bits(); } U64x8(o) - } - #[inline(always)] pub fn from_bits(bits: U64x8) -> Self { - let mut o = [0.0f64; 8]; for i in 0..8 { o[i] = f64::from_bits(bits.0[i]); } Self::from_array(o) - } -} - -impl Add for F64x8 { type Output = Self; #[inline(always)] fn add(self, rhs: Self) -> Self { Self(self.0 + rhs.0, self.1 + rhs.1) } } -impl Sub for F64x8 { type Output = Self; #[inline(always)] fn sub(self, rhs: Self) -> Self { Self(self.0 - rhs.0, self.1 - rhs.1) } } -impl Mul for F64x8 { type Output = Self; #[inline(always)] fn mul(self, rhs: Self) -> Self { Self(self.0 * rhs.0, self.1 * rhs.1) } } -impl Div for F64x8 { type Output = Self; #[inline(always)] fn div(self, rhs: Self) -> Self { Self(self.0 / rhs.0, self.1 / rhs.1) } } -impl AddAssign for F64x8 { #[inline(always)] fn add_assign(&mut self, rhs: Self) { *self = *self + rhs; } } -impl SubAssign for F64x8 { #[inline(always)] fn sub_assign(&mut self, rhs: Self) { *self = *self - rhs; } } -impl MulAssign for F64x8 { #[inline(always)] fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; } } -impl DivAssign for F64x8 { #[inline(always)] fn div_assign(&mut self, rhs: Self) { *self = *self / rhs; } } -impl Neg for F64x8 { type Output = Self; #[inline(always)] fn neg(self) -> Self { let a = self.to_array(); let mut o = [0.0f64; 8]; for i in 0..8 { o[i] = -a[i]; } Self::from_array(o) } } -impl fmt::Debug for F64x8 { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "F64x8({:?})", self.to_array()) } } -impl PartialEq for F64x8 { fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() } } -impl Default for F64x8 { fn default() -> Self { Self::splat(0.0) } } + #[inline(always)] + pub fn reduce_sum(self) -> f64 { + self.0.reduce_sum() + self.1.reduce_sum() + } + #[inline(always)] + pub fn reduce_min(self) -> f64 { + let a = self.to_array(); + a.iter().copied().fold(f64::INFINITY, f64::min) + } + #[inline(always)] + pub fn reduce_max(self) -> f64 { + let a = self.to_array(); + a.iter().copied().fold(f64::NEG_INFINITY, f64::max) + } + #[inline(always)] + pub fn abs(self) -> Self { + let a = self.to_array(); + let mut o = [0.0f64; 8]; + for i in 0..8 { + o[i] = a[i].abs(); + } + Self::from_array(o) + } + #[inline(always)] + pub fn sqrt(self) -> Self { + let a = self.to_array(); + let mut 
o = [0.0f64; 8]; + for i in 0..8 { + o[i] = a[i].sqrt(); + } + Self::from_array(o) + } + #[inline(always)] + pub fn round(self) -> Self { + let a = self.to_array(); + let mut o = [0.0f64; 8]; + for i in 0..8 { + o[i] = a[i].round(); + } + Self::from_array(o) + } + #[inline(always)] + pub fn floor(self) -> Self { + let a = self.to_array(); + let mut o = [0.0f64; 8]; + for i in 0..8 { + o[i] = a[i].floor(); + } + Self::from_array(o) + } + #[inline(always)] + pub fn mul_add(self, b: Self, c: Self) -> Self { + let a = self.to_array(); + let ba = b.to_array(); + let ca = c.to_array(); + let mut o = [0.0f64; 8]; + for i in 0..8 { + o[i] = a[i].mul_add(ba[i], ca[i]); + } + Self::from_array(o) + } + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + let a = self.to_array(); + let b = other.to_array(); + let mut o = [0.0f64; 8]; + for i in 0..8 { + o[i] = a[i].min(b[i]); + } + Self::from_array(o) + } + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + let a = self.to_array(); + let b = other.to_array(); + let mut o = [0.0f64; 8]; + for i in 0..8 { + o[i] = a[i].max(b[i]); + } + Self::from_array(o) + } + #[inline(always)] + pub fn simd_clamp(self, lo: Self, hi: Self) -> Self { + self.simd_max(lo).simd_min(hi) + } + #[inline(always)] + pub fn simd_ge(self, other: Self) -> F64Mask8 { + let a = self.to_array(); + let b = other.to_array(); + let mut bits: u8 = 0; + for i in 0..8 { + if a[i] >= b[i] { + bits |= 1 << i; + } + } + F64Mask8(bits) + } + #[inline(always)] + pub fn simd_le(self, other: Self) -> F64Mask8 { + let a = self.to_array(); + let b = other.to_array(); + let mut bits: u8 = 0; + for i in 0..8 { + if a[i] <= b[i] { + bits |= 1 << i; + } + } + F64Mask8(bits) + } + #[inline(always)] + pub fn to_bits(self) -> U64x8 { + let a = self.to_array(); + let mut o = [0u64; 8]; + for i in 0..8 { + o[i] = a[i].to_bits(); + } + U64x8(o) + } + #[inline(always)] + pub fn from_bits(bits: U64x8) -> Self { + let mut o = [0.0f64; 8]; + for i in 0..8 { + o[i] = f64::from_bits(bits.0[i]); + } + Self::from_array(o) + } +} + +impl Add for F64x8 { + type Output = Self; + #[inline(always)] + fn add(self, rhs: Self) -> Self { + Self(self.0 + rhs.0, self.1 + rhs.1) + } +} +impl Sub for F64x8 { + type Output = Self; + #[inline(always)] + fn sub(self, rhs: Self) -> Self { + Self(self.0 - rhs.0, self.1 - rhs.1) + } +} +impl Mul for F64x8 { + type Output = Self; + #[inline(always)] + fn mul(self, rhs: Self) -> Self { + Self(self.0 * rhs.0, self.1 * rhs.1) + } +} +impl Div for F64x8 { + type Output = Self; + #[inline(always)] + fn div(self, rhs: Self) -> Self { + Self(self.0 / rhs.0, self.1 / rhs.1) + } +} +impl AddAssign for F64x8 { + #[inline(always)] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} +impl SubAssign for F64x8 { + #[inline(always)] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} +impl MulAssign for F64x8 { + #[inline(always)] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} +impl DivAssign for F64x8 { + #[inline(always)] + fn div_assign(&mut self, rhs: Self) { + *self = *self / rhs; + } +} +impl Neg for F64x8 { + type Output = Self; + #[inline(always)] + fn neg(self) -> Self { + let a = self.to_array(); + let mut o = [0.0f64; 8]; + for i in 0..8 { + o[i] = -a[i]; + } + Self::from_array(o) + } +} +impl fmt::Debug for F64x8 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "F64x8({:?})", self.to_array()) + } +} +impl PartialEq for F64x8 { + fn eq(&self, other: &Self) -> bool { + 
self.to_array() == other.to_array() + } +} +impl Default for F64x8 { + fn default() -> Self { + Self::splat(0.0) + } +} #[derive(Copy, Clone, Debug)] pub struct F64Mask8(pub u8); impl F64Mask8 { #[inline(always)] pub fn select(self, true_val: F64x8, false_val: F64x8) -> F64x8 { - let t = true_val.to_array(); let f = false_val.to_array(); + let t = true_val.to_array(); + let f = false_val.to_array(); let mut o = [0.0f64; 8]; - for i in 0..8 { o[i] = if (self.0 >> i) & 1 == 1 { t[i] } else { f[i] }; } + for i in 0..8 { + o[i] = if (self.0 >> i) & 1 == 1 { t[i] } else { f[i] }; + } F64x8::from_array(o) } } @@ -749,29 +1073,163 @@ macro_rules! avx2_int_type { #[repr(align(64))] pub struct $name(pub [$elem; $lanes]); - impl Default for $name { #[inline(always)] fn default() -> Self { Self([$zero; $lanes]) } } + impl Default for $name { + #[inline(always)] + fn default() -> Self { + Self([$zero; $lanes]) + } + } impl $name { pub const LANES: usize = $lanes; - #[inline(always)] pub fn splat(v: $elem) -> Self { Self([v; $lanes]) } - #[inline(always)] pub fn from_slice(s: &[$elem]) -> Self { assert!(s.len() >= $lanes); let mut a = [$zero; $lanes]; a.copy_from_slice(&s[..$lanes]); Self(a) } - #[inline(always)] pub fn from_array(a: [$elem; $lanes]) -> Self { Self(a) } - #[inline(always)] pub fn to_array(self) -> [$elem; $lanes] { self.0 } - #[inline(always)] pub fn copy_to_slice(self, s: &mut [$elem]) { assert!(s.len() >= $lanes); s[..$lanes].copy_from_slice(&self.0); } - #[inline(always)] pub fn reduce_sum(self) -> $elem { let mut s: $elem = $zero; for i in 0..$lanes { s = s.wrapping_add(self.0[i]); } s } - } - impl Add for $name { type Output = Self; #[inline(always)] fn add(self, r: Self) -> Self { let mut o = [$zero; $lanes]; for i in 0..$lanes { o[i] = self.0[i].wrapping_add(r.0[i]); } Self(o) } } - impl Sub for $name { type Output = Self; #[inline(always)] fn sub(self, r: Self) -> Self { let mut o = [$zero; $lanes]; for i in 0..$lanes { o[i] = self.0[i].wrapping_sub(r.0[i]); } Self(o) } } - impl BitAnd for $name { type Output = Self; #[inline(always)] fn bitand(self, r: Self) -> Self { let mut o = [$zero; $lanes]; for i in 0..$lanes { o[i] = self.0[i] & r.0[i]; } Self(o) } } - impl BitOr for $name { type Output = Self; #[inline(always)] fn bitor(self, r: Self) -> Self { let mut o = [$zero; $lanes]; for i in 0..$lanes { o[i] = self.0[i] | r.0[i]; } Self(o) } } - impl BitXor for $name { type Output = Self; #[inline(always)] fn bitxor(self, r: Self) -> Self { let mut o = [$zero; $lanes]; for i in 0..$lanes { o[i] = self.0[i] ^ r.0[i]; } Self(o) } } - impl BitAndAssign for $name { #[inline(always)] fn bitand_assign(&mut self, r: Self) { for i in 0..$lanes { self.0[i] &= r.0[i]; } } } - impl BitOrAssign for $name { #[inline(always)] fn bitor_assign(&mut self, r: Self) { for i in 0..$lanes { self.0[i] |= r.0[i]; } } } - impl BitXorAssign for $name { #[inline(always)] fn bitxor_assign(&mut self, r: Self) { for i in 0..$lanes { self.0[i] ^= r.0[i]; } } } - impl Not for $name { type Output = Self; #[inline(always)] fn not(self) -> Self { let mut o = [$zero; $lanes]; for i in 0..$lanes { o[i] = !self.0[i]; } Self(o) } } - impl AddAssign for $name { #[inline(always)] fn add_assign(&mut self, r: Self) { for i in 0..$lanes { self.0[i] = self.0[i].wrapping_add(r.0[i]); } } } - impl SubAssign for $name { #[inline(always)] fn sub_assign(&mut self, r: Self) { for i in 0..$lanes { self.0[i] = self.0[i].wrapping_sub(r.0[i]); } } } - impl fmt::Debug for $name { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> 
fmt::Result { write!(f, concat!(stringify!($name), "({:?})"), &self.0[..]) } } - impl PartialEq for $name { fn eq(&self, other: &Self) -> bool { self.0 == other.0 } } + #[inline(always)] + pub fn splat(v: $elem) -> Self { + Self([v; $lanes]) + } + #[inline(always)] + pub fn from_slice(s: &[$elem]) -> Self { + assert!(s.len() >= $lanes); + let mut a = [$zero; $lanes]; + a.copy_from_slice(&s[..$lanes]); + Self(a) + } + #[inline(always)] + pub fn from_array(a: [$elem; $lanes]) -> Self { + Self(a) + } + #[inline(always)] + pub fn to_array(self) -> [$elem; $lanes] { + self.0 + } + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [$elem]) { + assert!(s.len() >= $lanes); + s[..$lanes].copy_from_slice(&self.0); + } + #[inline(always)] + pub fn reduce_sum(self) -> $elem { + let mut s: $elem = $zero; + for i in 0..$lanes { + s = s.wrapping_add(self.0[i]); + } + s + } + } + impl Add for $name { + type Output = Self; + #[inline(always)] + fn add(self, r: Self) -> Self { + let mut o = [$zero; $lanes]; + for i in 0..$lanes { + o[i] = self.0[i].wrapping_add(r.0[i]); + } + Self(o) + } + } + impl Sub for $name { + type Output = Self; + #[inline(always)] + fn sub(self, r: Self) -> Self { + let mut o = [$zero; $lanes]; + for i in 0..$lanes { + o[i] = self.0[i].wrapping_sub(r.0[i]); + } + Self(o) + } + } + impl BitAnd for $name { + type Output = Self; + #[inline(always)] + fn bitand(self, r: Self) -> Self { + let mut o = [$zero; $lanes]; + for i in 0..$lanes { + o[i] = self.0[i] & r.0[i]; + } + Self(o) + } + } + impl BitOr for $name { + type Output = Self; + #[inline(always)] + fn bitor(self, r: Self) -> Self { + let mut o = [$zero; $lanes]; + for i in 0..$lanes { + o[i] = self.0[i] | r.0[i]; + } + Self(o) + } + } + impl BitXor for $name { + type Output = Self; + #[inline(always)] + fn bitxor(self, r: Self) -> Self { + let mut o = [$zero; $lanes]; + for i in 0..$lanes { + o[i] = self.0[i] ^ r.0[i]; + } + Self(o) + } + } + impl BitAndAssign for $name { + #[inline(always)] + fn bitand_assign(&mut self, r: Self) { + for i in 0..$lanes { + self.0[i] &= r.0[i]; + } + } + } + impl BitOrAssign for $name { + #[inline(always)] + fn bitor_assign(&mut self, r: Self) { + for i in 0..$lanes { + self.0[i] |= r.0[i]; + } + } + } + impl BitXorAssign for $name { + #[inline(always)] + fn bitxor_assign(&mut self, r: Self) { + for i in 0..$lanes { + self.0[i] ^= r.0[i]; + } + } + } + impl Not for $name { + type Output = Self; + #[inline(always)] + fn not(self) -> Self { + let mut o = [$zero; $lanes]; + for i in 0..$lanes { + o[i] = !self.0[i]; + } + Self(o) + } + } + impl AddAssign for $name { + #[inline(always)] + fn add_assign(&mut self, r: Self) { + for i in 0..$lanes { + self.0[i] = self.0[i].wrapping_add(r.0[i]); + } + } + } + impl SubAssign for $name { + #[inline(always)] + fn sub_assign(&mut self, r: Self) { + for i in 0..$lanes { + self.0[i] = self.0[i].wrapping_sub(r.0[i]); + } + } + } + impl fmt::Debug for $name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, concat!(stringify!($name), "({:?})"), &self.0[..]) + } + } + impl PartialEq for $name { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } + } }; } @@ -782,38 +1240,98 @@ avx2_int_type!(I16x32, i16, 32, 0i16); // I8x64 / I16x32: AVX2 scalar polyfill — methods matching the AVX-512BW API impl I8x64 { #[inline(always)] - pub fn zero() -> Self { Self([0i8; 64]) } + pub fn zero() -> Self { + Self([0i8; 64]) + } #[inline(always)] - pub fn add(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = 
self.0[i].wrapping_add(other.0[i]); } Self(o) } + pub fn add(self, other: Self) -> Self { + let mut o = [0i8; 64]; + for i in 0..64 { + o[i] = self.0[i].wrapping_add(other.0[i]); + } + Self(o) + } #[inline(always)] - pub fn sub(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].wrapping_sub(other.0[i]); } Self(o) } + pub fn sub(self, other: Self) -> Self { + let mut o = [0i8; 64]; + for i in 0..64 { + o[i] = self.0[i].wrapping_sub(other.0[i]); + } + Self(o) + } #[inline(always)] - pub fn min(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].min(other.0[i]); } Self(o) } + pub fn min(self, other: Self) -> Self { + let mut o = [0i8; 64]; + for i in 0..64 { + o[i] = self.0[i].min(other.0[i]); + } + Self(o) + } #[inline(always)] - pub fn max(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].max(other.0[i]); } Self(o) } + pub fn max(self, other: Self) -> Self { + let mut o = [0i8; 64]; + for i in 0..64 { + o[i] = self.0[i].max(other.0[i]); + } + Self(o) + } #[inline(always)] pub fn cmp_gt(self, other: Self) -> u64 { let mut m: u64 = 0; - for i in 0..64 { if self.0[i] > other.0[i] { m |= 1u64 << i; } } + for i in 0..64 { + if self.0[i] > other.0[i] { + m |= 1u64 << i; + } + } m } } impl I16x32 { #[inline(always)] - pub fn zero() -> Self { Self([0i16; 32]) } + pub fn zero() -> Self { + Self([0i16; 32]) + } #[inline(always)] - pub fn add(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].wrapping_add(other.0[i]); } Self(o) } + pub fn add(self, other: Self) -> Self { + let mut o = [0i16; 32]; + for i in 0..32 { + o[i] = self.0[i].wrapping_add(other.0[i]); + } + Self(o) + } #[inline(always)] - pub fn sub(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].wrapping_sub(other.0[i]); } Self(o) } + pub fn sub(self, other: Self) -> Self { + let mut o = [0i16; 32]; + for i in 0..32 { + o[i] = self.0[i].wrapping_sub(other.0[i]); + } + Self(o) + } #[inline(always)] - pub fn min(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].min(other.0[i]); } Self(o) } + pub fn min(self, other: Self) -> Self { + let mut o = [0i16; 32]; + for i in 0..32 { + o[i] = self.0[i].min(other.0[i]); + } + Self(o) + } #[inline(always)] - pub fn max(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].max(other.0[i]); } Self(o) } + pub fn max(self, other: Self) -> Self { + let mut o = [0i16; 32]; + for i in 0..32 { + o[i] = self.0[i].max(other.0[i]); + } + Self(o) + } #[inline(always)] pub fn cmp_gt(self, other: Self) -> u32 { let mut m: u32 = 0; - for i in 0..32 { if self.0[i] > other.0[i] { m |= 1u32 << i; } } + for i in 0..32 { + if self.0[i] > other.0[i] { + m |= 1u32 << i; + } + } m } } @@ -825,7 +1343,11 @@ impl U8x64 { #[inline(always)] pub fn cmpeq_mask(self, other: Self) -> u64 { let mut mask = 0u64; - for i in 0..64 { if self.0[i] == other.0[i] { mask |= 1u64 << i; } } + for i in 0..64 { + if self.0[i] == other.0[i] { + mask |= 1u64 << i; + } + } mask } @@ -847,7 +1369,9 @@ impl U8x64 { #[inline(always)] pub fn saturating_sub(self, other: Self) -> Self { let mut out = [0u8; 64]; - for i in 0..64 { out[i] = self.0[i].saturating_sub(other.0[i]); } + for i in 0..64 { + out[i] = self.0[i].saturating_sub(other.0[i]); + } Self(out) } @@ -855,41 +1379,75 @@ impl U8x64 { #[inline(always)] pub fn pairwise_avg(self, other: Self) -> Self { - let mut out = [0u8; 64]; for i in 0..64 { out[i] 
= ((self.0[i] as u16 + other.0[i] as u16 + 1) >> 1) as u8; } Self(out) + let mut out = [0u8; 64]; + for i in 0..64 { + out[i] = ((self.0[i] as u16 + other.0[i] as u16 + 1) >> 1) as u8; + } + Self(out) } #[inline(always)] pub fn cmpgt_mask(self, other: Self) -> u64 { - let mut m: u64 = 0; for i in 0..64 { if self.0[i] > other.0[i] { m |= 1 << i; } } m + let mut m: u64 = 0; + for i in 0..64 { + if self.0[i] > other.0[i] { + m |= 1 << i; + } + } + m } #[inline(always)] pub fn mask_blend(mask: u64, a: Self, b: Self) -> Self { - let mut out = [0u8; 64]; for i in 0..64 { out[i] = if mask & (1 << i) != 0 { b.0[i] } else { a.0[i] }; } Self(out) + let mut out = [0u8; 64]; + for i in 0..64 { + out[i] = if mask & (1 << i) != 0 { b.0[i] } else { a.0[i] }; + } + Self(out) } #[inline(always)] pub fn shl_epi16(self, imm: u32) -> Self { let mut out = [0u8; 64]; for i in (0..64).step_by(2) { - let v = u16::from_le_bytes([self.0[i], self.0[i+1]]); + let v = u16::from_le_bytes([self.0[i], self.0[i + 1]]); let s = if imm < 16 { v << imm } else { 0 }; - let b = s.to_le_bytes(); out[i] = b[0]; out[i+1] = b[1]; + let b = s.to_le_bytes(); + out[i] = b[0]; + out[i + 1] = b[1]; } Self(out) } #[inline(always)] pub unsafe fn mask_store(self, ptr: *mut u8, mask: u64) { - for i in 0..64 { if mask & (1 << i) != 0 { *ptr.add(i) = self.0[i]; } } + for i in 0..64 { + if mask & (1 << i) != 0 { + *ptr.add(i) = self.0[i]; + } + } } #[inline(always)] pub fn saturating_add(self, other: Self) -> Self { - let mut out = [0u8; 64]; for i in 0..64 { out[i] = self.0[i].saturating_add(other.0[i]); } Self(out) + let mut out = [0u8; 64]; + for i in 0..64 { + out[i] = self.0[i].saturating_add(other.0[i]); + } + Self(out) } #[inline(always)] pub fn permute_bytes(self, idx: Self) -> Self { - let mut out = [0u8; 64]; for i in 0..64 { out[i] = self.0[(idx.0[i] & 63) as usize]; } Self(out) + let mut out = [0u8; 64]; + for i in 0..64 { + out[i] = self.0[(idx.0[i] & 63) as usize]; + } + Self(out) } #[inline(always)] pub fn movemask(self) -> u64 { - let mut m: u64 = 0; for i in 0..64 { if self.0[i] & 0x80 != 0 { m |= 1 << i; } } m + let mut m: u64 = 0; + for i in 0..64 { + if self.0[i] & 0x80 != 0 { + m |= 1 << i; + } + } + m } /// Interleave low bytes within each 128-bit lane. @@ -922,10 +1480,30 @@ impl U8x64 { } /// Reduce min/max (not in macro). - #[inline(always)] pub fn reduce_min(self) -> u8 { *self.0.iter().min().unwrap() } - #[inline(always)] pub fn reduce_max(self) -> u8 { *self.0.iter().max().unwrap() } - #[inline(always)] pub fn simd_min(self, other: Self) -> Self { let mut o = [0u8; 64]; for i in 0..64 { o[i] = self.0[i].min(other.0[i]); } Self(o) } - #[inline(always)] pub fn simd_max(self, other: Self) -> Self { let mut o = [0u8; 64]; for i in 0..64 { o[i] = self.0[i].max(other.0[i]); } Self(o) } + #[inline(always)] + pub fn reduce_min(self) -> u8 { + *self.0.iter().min().unwrap() + } + #[inline(always)] + pub fn reduce_max(self) -> u8 { + *self.0.iter().max().unwrap() + } + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + let mut o = [0u8; 64]; + for i in 0..64 { + o[i] = self.0[i].min(other.0[i]); + } + Self(o) + } + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + let mut o = [0u8; 64]; + for i in 0..64 { + o[i] = self.0[i].max(other.0[i]); + } + Self(o) + } /// Byte-wise shuffle: use `self` as a LUT, `idx` selects bytes within each 16-byte lane. #[inline(always)] @@ -949,9 +1527,11 @@ impl U8x64 { /// Build a nibble-popcount lookup table (replicated across 4 x 16-byte lanes). 
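/// Typical use: mask the low nibble of every byte, shift the high nibble down with
/// a 4-bit right shift per 16-bit lane, look both up in this table via the byte
/// shuffle, and add the two results to get a per-byte popcount.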
#[inline(always)] pub fn nibble_popcount_lut() -> Self { - let lane: [u8; 16] = [0,1,1,2,1,2,2,3,1,2,2,3,2,3,3,4]; + let lane: [u8; 16] = [0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4]; let mut arr = [0u8; 64]; - for l in 0..4 { arr[l*16..(l+1)*16].copy_from_slice(&lane); } + for l in 0..4 { + arr[l * 16..(l + 1) * 16].copy_from_slice(&lane); + } Self(arr) } } @@ -966,47 +1546,107 @@ avx2_int_type!(U64x8, u64, 8, 0u64); impl U16x32 { #[inline(always)] pub fn from_u8x64_lo(v: U8x64) -> Self { - let mut out = [0u16; 32]; for i in 0..32 { out[i] = v.0[i] as u16; } Self(out) + let mut out = [0u16; 32]; + for i in 0..32 { + out[i] = v.0[i] as u16; + } + Self(out) } #[inline(always)] pub fn from_u8x64_hi(v: U8x64) -> Self { - let mut out = [0u16; 32]; for i in 0..32 { out[i] = v.0[32 + i] as u16; } Self(out) + let mut out = [0u16; 32]; + for i in 0..32 { + out[i] = v.0[32 + i] as u16; + } + Self(out) } #[inline(always)] pub fn pack_saturate_u8(self, other: Self) -> U8x64 { let mut out = [0u8; 64]; - for i in 0..32 { out[i] = self.0[i].min(255) as u8; } - for i in 0..32 { out[32 + i] = other.0[i].min(255) as u8; } + for i in 0..32 { + out[i] = self.0[i].min(255) as u8; + } + for i in 0..32 { + out[32 + i] = other.0[i].min(255) as u8; + } U8x64(out) } #[inline(always)] pub fn shr(self, imm: u32) -> Self { - let mut out = [0u16; 32]; for i in 0..32 { out[i] = if imm < 16 { self.0[i] >> imm } else { 0 }; } Self(out) + let mut out = [0u16; 32]; + for i in 0..32 { + out[i] = if imm < 16 { self.0[i] >> imm } else { 0 }; + } + Self(out) } #[inline(always)] pub fn shl(self, imm: u32) -> Self { - let mut out = [0u16; 32]; for i in 0..32 { out[i] = if imm < 16 { self.0[i] << imm } else { 0 }; } Self(out) + let mut out = [0u16; 32]; + for i in 0..32 { + out[i] = if imm < 16 { self.0[i] << imm } else { 0 }; + } + Self(out) } #[inline(always)] pub fn mullo(self, other: Self) -> Self { - let mut out = [0u16; 32]; for i in 0..32 { out[i] = self.0[i].wrapping_mul(other.0[i]); } Self(out) + let mut out = [0u16; 32]; + for i in 0..32 { + out[i] = self.0[i].wrapping_mul(other.0[i]); + } + Self(out) } } impl I32x16 { - #[inline(always)] pub fn reduce_min(self) -> i32 { *self.0.iter().min().unwrap() } - #[inline(always)] pub fn reduce_max(self) -> i32 { *self.0.iter().max().unwrap() } - #[inline(always)] pub fn simd_min(self, other: Self) -> Self { let mut o = [0i32; 16]; for i in 0..16 { o[i] = self.0[i].min(other.0[i]); } Self(o) } - #[inline(always)] pub fn simd_max(self, other: Self) -> Self { let mut o = [0i32; 16]; for i in 0..16 { o[i] = self.0[i].max(other.0[i]); } Self(o) } - #[inline(always)] pub fn cast_f32(self) -> F32x16 { let mut o = [0.0f32; 16]; for i in 0..16 { o[i] = self.0[i] as f32; } F32x16::from_array(o) } - #[inline(always)] pub fn abs(self) -> Self { let mut o = [0i32; 16]; for i in 0..16 { o[i] = self.0[i].abs(); } Self(o) } + #[inline(always)] + pub fn reduce_min(self) -> i32 { + *self.0.iter().min().unwrap() + } + #[inline(always)] + pub fn reduce_max(self) -> i32 { + *self.0.iter().max().unwrap() + } + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + let mut o = [0i32; 16]; + for i in 0..16 { + o[i] = self.0[i].min(other.0[i]); + } + Self(o) + } + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + let mut o = [0i32; 16]; + for i in 0..16 { + o[i] = self.0[i].max(other.0[i]); + } + Self(o) + } + #[inline(always)] + pub fn cast_f32(self) -> F32x16 { + let mut o = [0.0f32; 16]; + for i in 0..16 { + o[i] = self.0[i] as f32; + } + F32x16::from_array(o) + } + 
#[inline(always)] + pub fn abs(self) -> Self { + let mut o = [0i32; 16]; + for i in 0..16 { + o[i] = self.0[i].abs(); + } + Self(o) + } /// Load 16 × i16, sign-extend to 16 × i32. #[inline(always)] pub fn from_i16_slice(s: &[i16]) -> Self { assert!(s.len() >= 16); let mut o = [0i32; 16]; - for i in 0..16 { o[i] = s[i] as i32; } + for i in 0..16 { + o[i] = s[i] as i32; + } Self(o) } @@ -1014,7 +1654,9 @@ impl I32x16 { #[inline(always)] pub fn to_i16_array(self) -> [i16; 16] { let mut o = [0i16; 16]; - for i in 0..16 { o[i] = self.0[i] as i16; } + for i in 0..16 { + o[i] = self.0[i] as i16; + } o } @@ -1022,19 +1664,68 @@ impl I32x16 { #[inline(always)] pub fn cmpge_zero_mask(self) -> u16 { let mut mask = 0u16; - for i in 0..16 { if self.0[i] >= 0 { mask |= 1 << i; } } + for i in 0..16 { + if self.0[i] >= 0 { + mask |= 1 << i; + } + } mask } } -impl Mul for I32x16 { type Output = Self; #[inline(always)] fn mul(self, r: Self) -> Self { let mut o = [0i32; 16]; for i in 0..16 { o[i] = self.0[i].wrapping_mul(r.0[i]); } Self(o) } } -impl MulAssign for I32x16 { #[inline(always)] fn mul_assign(&mut self, r: Self) { *self = *self * r; } } -impl Neg for I32x16 { type Output = Self; #[inline(always)] fn neg(self) -> Self { let mut o = [0i32; 16]; for i in 0..16 { o[i] = -self.0[i]; } Self(o) } } +impl Mul for I32x16 { + type Output = Self; + #[inline(always)] + fn mul(self, r: Self) -> Self { + let mut o = [0i32; 16]; + for i in 0..16 { + o[i] = self.0[i].wrapping_mul(r.0[i]); + } + Self(o) + } +} +impl MulAssign for I32x16 { + #[inline(always)] + fn mul_assign(&mut self, r: Self) { + *self = *self * r; + } +} +impl Neg for I32x16 { + type Output = Self; + #[inline(always)] + fn neg(self) -> Self { + let mut o = [0i32; 16]; + for i in 0..16 { + o[i] = -self.0[i]; + } + Self(o) + } +} impl I64x8 { - #[inline(always)] pub fn reduce_min(self) -> i64 { *self.0.iter().min().unwrap() } - #[inline(always)] pub fn reduce_max(self) -> i64 { *self.0.iter().max().unwrap() } - #[inline(always)] pub fn simd_min(self, other: Self) -> Self { let mut o = [0i64; 8]; for i in 0..8 { o[i] = self.0[i].min(other.0[i]); } Self(o) } - #[inline(always)] pub fn simd_max(self, other: Self) -> Self { let mut o = [0i64; 8]; for i in 0..8 { o[i] = self.0[i].max(other.0[i]); } Self(o) } + #[inline(always)] + pub fn reduce_min(self) -> i64 { + *self.0.iter().min().unwrap() + } + #[inline(always)] + pub fn reduce_max(self) -> i64 { + *self.0.iter().max().unwrap() + } + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + let mut o = [0i64; 8]; + for i in 0..8 { + o[i] = self.0[i].min(other.0[i]); + } + Self(o) + } + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + let mut o = [0i64; 8]; + for i in 0..8 { + o[i] = self.0[i].max(other.0[i]); + } + Self(o) + } } /// Lowercase aliases (std::simd convention). 
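A minimal usage sketch (not part of the patch) of the scalar F64x8 / F64Mask8 polyfill reformatted above: a branchless thresholded sum. The helper name `sum_above` is hypothetical; only methods visible in this diff (`splat`, `from_array`, `simd_ge`, `select`, `AddAssign`, `reduce_sum`) are assumed.

fn sum_above(values: &[f64], threshold: f64) -> f64 {
    let thr = F64x8::splat(threshold);
    let zero = F64x8::splat(0.0);
    let mut acc = F64x8::splat(0.0);
    for chunk in values.chunks_exact(8) {
        let mut lanes = [0.0f64; 8];
        lanes.copy_from_slice(chunk);
        let v = F64x8::from_array(lanes);
        // Lanes below the threshold contribute 0.0; no per-lane branch.
        acc += v.simd_ge(thr).select(v, zero);
    }
    let mut total = acc.reduce_sum();
    // Scalar tail for the final partial chunk.
    for &v in values.chunks_exact(8).remainder() {
        if v >= threshold {
            total += v;
        }
    }
    total
}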
@@ -1067,12 +1758,7 @@ mod tests { let b: Vec = (0..100).map(|i| (i * 2) as f32).collect(); let result = dot_f32(&a, &b); let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(); - assert!( - (result - expected).abs() < 1.0, - "dot_f32: {} vs {}", - result, - expected - ); + assert!((result - expected).abs() < 1.0, "dot_f32: {} vs {}", result, expected); } #[test] @@ -1187,10 +1873,7 @@ mod tests { // Re-use the exact IEEE 754 scalar functions from simd_avx512 pub use crate::simd_avx512::{ - f16_to_f32_ieee754, - f32_to_f16_ieee754_rne, - f16_to_f32_batch_ieee754, - f32_to_f16_batch_ieee754_rne, + f16_to_f32_batch_ieee754, f16_to_f32_ieee754, f32_to_f16_batch_ieee754_rne, f32_to_f16_ieee754_rne, }; // ── Trick 1: Double-f16 (Error-Free Split) ────────────────────────────── @@ -1230,7 +1913,7 @@ pub use crate::simd_avx512::{ pub fn f16_double_encode(value: f32) -> (u16, u16) { let hi = f32_to_f16_ieee754_rne(value); let hi_f32 = f16_to_f32_ieee754(hi); // exact (lossless widening) - let residual = value - hi_f32; // exact (f32 subtraction) + let residual = value - hi_f32; // exact (f32 subtraction) let lo = f32_to_f16_ieee754_rne(residual); (hi, lo) } @@ -1375,22 +2058,35 @@ impl F16Scaler { assert!(min_val < max_val, "min must be less than max"); let abs_max = min_val.abs().max(max_val.abs()); if abs_max < f32::EPSILON { - return Self { scale: 1.0, inv_scale: 1.0 }; + return Self { + scale: 1.0, + inv_scale: 1.0, + }; } let scale = 1.0 / abs_max; - Self { scale, inv_scale: abs_max } + Self { + scale, + inv_scale: abs_max, + } } /// Create by scanning data for min/max. pub fn from_data(data: &[f32]) -> Self { if data.is_empty() { - return Self { scale: 1.0, inv_scale: 1.0 }; + return Self { + scale: 1.0, + inv_scale: 1.0, + }; } let mut min = f32::INFINITY; let mut max = f32::NEG_INFINITY; for &v in data { - if v < min { min = v; } - if v > max { max = v; } + if v < min { + min = v; + } + if v > max { + max = v; + } } Self::from_range(min, max) } @@ -1452,10 +2148,17 @@ mod f16_precision_tests { let (hi, lo) = f16_double_encode(value); let double_err = (value - f16_double_decode(hi, lo)).abs(); - assert!(double_err < single_err, - "double should be better: single={:.8} double={:.8}", single_err, double_err); - assert!(double_err < single_err / 100.0, - "double should be >100× better: ratio={:.0}", single_err / double_err); + assert!( + double_err < single_err, + "double should be better: single={:.8} double={:.8}", + single_err, + double_err + ); + assert!( + double_err < single_err / 100.0, + "double should be >100× better: ratio={:.0}", + single_err / double_err + ); } #[test] @@ -1487,15 +2190,18 @@ mod f16_precision_tests { #[test] fn kahan_dot_vs_f64_reference() { - let a: Vec = (0..64).map(|i| f32_to_f16_ieee754_rne(i as f32 * 0.1)).collect(); - let b: Vec = (0..64).map(|i| f32_to_f16_ieee754_rne(1.0 - i as f32 * 0.01)).collect(); + let a: Vec = (0..64) + .map(|i| f32_to_f16_ieee754_rne(i as f32 * 0.1)) + .collect(); + let b: Vec = (0..64) + .map(|i| f32_to_f16_ieee754_rne(1.0 - i as f32 * 0.01)) + .collect(); let dot = f16_kahan_dot(&a, &b); let mut ref_sum = 0.0f64; for i in 0..64 { ref_sum += f16_to_f32_ieee754(a[i]) as f64 * f16_to_f32_ieee754(b[i]) as f64; } - assert!((dot as f64 - ref_sum).abs() < 0.01, - "got={} expected={}", dot, ref_sum); + assert!((dot as f64 - ref_sum).abs() < 0.01, "got={} expected={}", dot, ref_sum); } #[test] @@ -1503,19 +2209,29 @@ mod f16_precision_tests { let data: Vec = (0..100).map(|i| 0.001 + (i as f32) * 0.00004).collect(); let 
no_scale: Vec = data.iter().map(|&v| f32_to_f16_ieee754_rne(v)).collect(); - let no_scale_err: f64 = data.iter().enumerate() - .map(|(i, &v)| (v as f64 - f16_to_f32_ieee754(no_scale[i]) as f64).powi(2)).sum(); + let no_scale_err: f64 = data + .iter() + .enumerate() + .map(|(i, &v)| (v as f64 - f16_to_f32_ieee754(no_scale[i]) as f64).powi(2)) + .sum(); let scaler = F16Scaler::from_data(&data); let mut scaled = vec![0u16; 100]; scaler.encode_batch(&data, &mut scaled); let mut back = vec![0.0f32; 100]; scaler.decode_batch(&scaled, &mut back); - let scaled_err: f64 = data.iter().enumerate() - .map(|(i, &v)| (v as f64 - back[i] as f64).powi(2)).sum(); + let scaled_err: f64 = data + .iter() + .enumerate() + .map(|(i, &v)| (v as f64 - back[i] as f64).powi(2)) + .sum(); - assert!(scaled_err < no_scale_err, - "scaling should help: unscaled={:.2e} scaled={:.2e}", no_scale_err, scaled_err); + assert!( + scaled_err < no_scale_err, + "scaling should help: unscaled={:.2e} scaled={:.2e}", + no_scale_err, + scaled_err + ); } #[test] @@ -1528,8 +2244,7 @@ mod f16_precision_tests { scaler.decode_batch(&enc, &mut dec); for i in 0..50 { let err = (data[i] - dec[i]).abs(); - assert!(err < data[i].abs() * 0.01 + 1e-6, - "at {}: {} → {} err={}", i, data[i], dec[i], err); + assert!(err < data[i].abs() * 0.01 + 1e-6, "at {}: {} → {} err={}", i, data[i], dec[i], err); } } } diff --git a/src/simd_avx512.rs b/src/simd_avx512.rs index 1c6b8592..fa064234 100644 --- a/src/simd_avx512.rs +++ b/src/simd_avx512.rs @@ -33,8 +33,8 @@ use core::arch::x86_64::*; use core::fmt; use core::ops::{ - Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign, - Mul, MulAssign, Neg, Not, Shl, Shr, Sub, SubAssign, + Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign, Mul, MulAssign, + Neg, Not, Shl, Shr, Sub, SubAssign, }; // ============================================================================ @@ -175,10 +175,7 @@ impl F32x16 { pub fn abs(self) -> Self { unsafe { let mask = _mm512_set1_epi32(0x7FFF_FFFFi32); - Self(_mm512_castsi512_ps(_mm512_and_si512( - _mm512_castps_si512(self.0), - mask, - ))) + Self(_mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(self.0), mask))) } } @@ -264,10 +261,7 @@ impl Neg for F32x16 { fn neg(self) -> Self { unsafe { let sign = _mm512_set1_epi32(i32::MIN); // 0x80000000 - Self(_mm512_castsi512_ps(_mm512_xor_si512( - _mm512_castps_si512(self.0), - sign, - ))) + Self(_mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(self.0), sign))) } } } @@ -402,10 +396,7 @@ impl F64x8 { pub fn abs(self) -> Self { unsafe { let mask = _mm512_set1_epi64(0x7FFF_FFFF_FFFF_FFFFi64); - Self(_mm512_castsi512_pd(_mm512_and_si512( - _mm512_castpd_si512(self.0), - mask, - ))) + Self(_mm512_castsi512_pd(_mm512_and_si512(_mm512_castpd_si512(self.0), mask))) } } @@ -467,10 +458,7 @@ impl Neg for F64x8 { fn neg(self) -> Self { unsafe { let sign = _mm512_set1_epi64(i64::MIN); // 0x8000000000000000 - Self(_mm512_castsi512_pd(_mm512_xor_si512( - _mm512_castpd_si512(self.0), - sign, - ))) + Self(_mm512_castsi512_pd(_mm512_xor_si512(_mm512_castpd_si512(self.0), sign))) } } } @@ -605,17 +593,19 @@ impl U8x64 { pub fn shr_epi16(self, imm: u32) -> Self { // _mm512_srli_epi16 shifts each 16-bit lane right // Use match for const immediate (intrinsic requires const) - Self(unsafe { match imm { - 1 => _mm512_srli_epi16(self.0, 1), - 2 => _mm512_srli_epi16(self.0, 2), - 3 => _mm512_srli_epi16(self.0, 3), - 4 => _mm512_srli_epi16(self.0, 4), - 5 => 
_mm512_srli_epi16(self.0, 5), - 6 => _mm512_srli_epi16(self.0, 6), - 7 => _mm512_srli_epi16(self.0, 7), - 8 => _mm512_srli_epi16(self.0, 8), - _ => _mm512_setzero_si512(), - }}) + Self(unsafe { + match imm { + 1 => _mm512_srli_epi16(self.0, 1), + 2 => _mm512_srli_epi16(self.0, 2), + 3 => _mm512_srli_epi16(self.0, 3), + 4 => _mm512_srli_epi16(self.0, 4), + 5 => _mm512_srli_epi16(self.0, 5), + 6 => _mm512_srli_epi16(self.0, 6), + 7 => _mm512_srli_epi16(self.0, 7), + 8 => _mm512_srli_epi16(self.0, 8), + _ => _mm512_setzero_si512(), + } + }) } /// Saturating unsigned subtraction: max(a - b, 0) per byte. @@ -655,17 +645,19 @@ impl U8x64 { /// Completes the nibble shift pair with `shr_epi16`. #[inline(always)] pub fn shl_epi16(self, imm: u32) -> Self { - Self(unsafe { match imm { - 1 => _mm512_slli_epi16(self.0, 1), - 2 => _mm512_slli_epi16(self.0, 2), - 3 => _mm512_slli_epi16(self.0, 3), - 4 => _mm512_slli_epi16(self.0, 4), - 5 => _mm512_slli_epi16(self.0, 5), - 6 => _mm512_slli_epi16(self.0, 6), - 7 => _mm512_slli_epi16(self.0, 7), - 8 => _mm512_slli_epi16(self.0, 8), - _ => _mm512_setzero_si512(), - }}) + Self(unsafe { + match imm { + 1 => _mm512_slli_epi16(self.0, 1), + 2 => _mm512_slli_epi16(self.0, 2), + 3 => _mm512_slli_epi16(self.0, 3), + 4 => _mm512_slli_epi16(self.0, 4), + 5 => _mm512_slli_epi16(self.0, 5), + 6 => _mm512_slli_epi16(self.0, 6), + 7 => _mm512_slli_epi16(self.0, 7), + 8 => _mm512_slli_epi16(self.0, 8), + _ => _mm512_setzero_si512(), + } + }) } // ── Tier 2: sprite blit + palette LUT + cross-lane shuffle ──────── @@ -748,12 +740,11 @@ impl U8x64 { #[inline(always)] pub fn nibble_popcount_lut() -> Self { // 0x04030302_03020201_03020201_02010100 replicated ×4 - Self(unsafe { _mm512_set4_epi32( - 0x04030302_u32 as i32, - 0x03020201_u32 as i32, - 0x03020201_u32 as i32, - 0x02010100_u32 as i32, - )}) + Self(unsafe { + _mm512_set4_epi32( + 0x04030302_u32 as i32, 0x03020201_u32 as i32, 0x03020201_u32 as i32, 0x02010100_u32 as i32, + ) + }) } } @@ -789,10 +780,7 @@ impl Mul for U8x64 { let packed_lo = _mm512_cvtepi16_epi8(prod_lo); let packed_hi = _mm512_cvtepi16_epi8(prod_hi); - Self(_mm512_inserti64x4::<1>( - _mm512_castsi256_si512(packed_lo), - packed_hi, - )) + Self(_mm512_inserti64x4::<1>(_mm512_castsi256_si512(packed_lo), packed_hi)) } } } @@ -1252,25 +1240,29 @@ impl U16x32 { /// Shift right each 16-bit lane by immediate. #[inline(always)] pub fn shr(self, imm: u32) -> Self { - Self(unsafe { match imm { - 1 => _mm512_srli_epi16(self.0, 1), - 2 => _mm512_srli_epi16(self.0, 2), - 4 => _mm512_srli_epi16(self.0, 4), - 8 => _mm512_srli_epi16(self.0, 8), - _ => _mm512_setzero_si512(), - }}) + Self(unsafe { + match imm { + 1 => _mm512_srli_epi16(self.0, 1), + 2 => _mm512_srli_epi16(self.0, 2), + 4 => _mm512_srli_epi16(self.0, 4), + 8 => _mm512_srli_epi16(self.0, 8), + _ => _mm512_setzero_si512(), + } + }) } /// Shift left each 16-bit lane by immediate. #[inline(always)] pub fn shl(self, imm: u32) -> Self { - Self(unsafe { match imm { - 1 => _mm512_slli_epi16(self.0, 1), - 2 => _mm512_slli_epi16(self.0, 2), - 4 => _mm512_slli_epi16(self.0, 4), - 8 => _mm512_slli_epi16(self.0, 8), - _ => _mm512_setzero_si512(), - }}) + Self(unsafe { + match imm { + 1 => _mm512_slli_epi16(self.0, 1), + 2 => _mm512_slli_epi16(self.0, 2), + 4 => _mm512_slli_epi16(self.0, 4), + 8 => _mm512_slli_epi16(self.0, 8), + _ => _mm512_setzero_si512(), + } + }) } /// Multiply and keep low 16 bits (wrapping). 
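The nibble-popcount LUT and the 16-bit shift helpers above combine into a per-byte population count without leaving the vector unit. A sketch against the U8x64 surface shown in this patch; `popcount_bytes` is hypothetical, and `splat`, `shr_epi16`, `permute_bytes`, the bitwise AND, and `to_array` are taken from the polyfill side of the diff and assumed to be mirrored here.

fn popcount_bytes(v: U8x64) -> u32 {
    let lut = U8x64::nibble_popcount_lut(); // counts for nibble values 0..=15
    let low_mask = U8x64::splat(0x0F);
    let lo = v & low_mask;              // low nibble of each byte
    let hi = v.shr_epi16(4) & low_mask; // high nibble; the 16-bit shift bleeds
                                        // across byte pairs, masking repairs it
    let counts = lut.permute_bytes(lo) + lut.permute_bytes(hi);
    // Widen before summing: 64 bytes of up to 8 each can exceed u8::MAX.
    counts.to_array().iter().map(|&c| u32::from(c)).sum()
}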
@@ -1291,16 +1283,22 @@ impl U16x32 { impl Add for U16x32 { type Output = Self; #[inline(always)] - fn add(self, rhs: Self) -> Self { Self(unsafe { _mm512_add_epi16(self.0, rhs.0) }) } + fn add(self, rhs: Self) -> Self { + Self(unsafe { _mm512_add_epi16(self.0, rhs.0) }) + } } impl Sub for U16x32 { type Output = Self; #[inline(always)] - fn sub(self, rhs: Self) -> Self { Self(unsafe { _mm512_sub_epi16(self.0, rhs.0) }) } + fn sub(self, rhs: Self) -> Self { + Self(unsafe { _mm512_sub_epi16(self.0, rhs.0) }) + } } impl AddAssign for U16x32 { #[inline(always)] - fn add_assign(&mut self, rhs: Self) { self.0 = unsafe { _mm512_add_epi16(self.0, rhs.0) }; } + fn add_assign(&mut self, rhs: Self) { + self.0 = unsafe { _mm512_add_epi16(self.0, rhs.0) }; + } } impl fmt::Debug for U16x32 { @@ -1310,7 +1308,9 @@ impl fmt::Debug for U16x32 { } impl PartialEq for U16x32 { - fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() } + fn eq(&self, other: &Self) -> bool { + self.to_array() == other.to_array() + } } // ============================================================================ @@ -1593,7 +1593,9 @@ impl fmt::Debug for I8x64 { } } impl PartialEq for I8x64 { - fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() } + fn eq(&self, other: &Self) -> bool { + self.to_array() == other.to_array() + } } // ============================================================================ @@ -1673,20 +1675,28 @@ impl I8x32 { impl Add for I8x32 { type Output = Self; #[inline(always)] - fn add(self, rhs: Self) -> Self { Self(unsafe { _mm256_add_epi8(self.0, rhs.0) }) } + fn add(self, rhs: Self) -> Self { + Self(unsafe { _mm256_add_epi8(self.0, rhs.0) }) + } } impl Sub for I8x32 { type Output = Self; #[inline(always)] - fn sub(self, rhs: Self) -> Self { Self(unsafe { _mm256_sub_epi8(self.0, rhs.0) }) } + fn sub(self, rhs: Self) -> Self { + Self(unsafe { _mm256_sub_epi8(self.0, rhs.0) }) + } } impl AddAssign for I8x32 { #[inline(always)] - fn add_assign(&mut self, rhs: Self) { self.0 = unsafe { _mm256_add_epi8(self.0, rhs.0) }; } + fn add_assign(&mut self, rhs: Self) { + self.0 = unsafe { _mm256_add_epi8(self.0, rhs.0) }; + } } impl SubAssign for I8x32 { #[inline(always)] - fn sub_assign(&mut self, rhs: Self) { self.0 = unsafe { _mm256_sub_epi8(self.0, rhs.0) }; } + fn sub_assign(&mut self, rhs: Self) { + self.0 = unsafe { _mm256_sub_epi8(self.0, rhs.0) }; + } } impl fmt::Debug for I8x32 { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -1694,7 +1704,9 @@ impl fmt::Debug for I8x32 { } } impl PartialEq for I8x32 { - fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() } + fn eq(&self, other: &Self) -> bool { + self.to_array() == other.to_array() + } } // ============================================================================ @@ -1781,7 +1793,9 @@ impl fmt::Debug for I16x32 { } } impl PartialEq for I16x32 { - fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() } + fn eq(&self, other: &Self) -> bool { + self.to_array() == other.to_array() + } } // ============================================================================ @@ -1870,20 +1884,28 @@ impl I16x16 { impl Add for I16x16 { type Output = Self; #[inline(always)] - fn add(self, rhs: Self) -> Self { Self(unsafe { _mm256_add_epi16(self.0, rhs.0) }) } + fn add(self, rhs: Self) -> Self { + Self(unsafe { _mm256_add_epi16(self.0, rhs.0) }) + } } impl Sub for I16x16 { type Output = Self; #[inline(always)] - fn sub(self, rhs: Self) -> Self { Self(unsafe { 
_mm256_sub_epi16(self.0, rhs.0) }) } + fn sub(self, rhs: Self) -> Self { + Self(unsafe { _mm256_sub_epi16(self.0, rhs.0) }) + } } impl AddAssign for I16x16 { #[inline(always)] - fn add_assign(&mut self, rhs: Self) { self.0 = unsafe { _mm256_add_epi16(self.0, rhs.0) }; } + fn add_assign(&mut self, rhs: Self) { + self.0 = unsafe { _mm256_add_epi16(self.0, rhs.0) }; + } } impl SubAssign for I16x16 { #[inline(always)] - fn sub_assign(&mut self, rhs: Self) { self.0 = unsafe { _mm256_sub_epi16(self.0, rhs.0) }; } + fn sub_assign(&mut self, rhs: Self) { + self.0 = unsafe { _mm256_sub_epi16(self.0, rhs.0) }; + } } impl fmt::Debug for I16x16 { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -1891,7 +1913,9 @@ impl fmt::Debug for I16x16 { } } impl PartialEq for I16x16 { - fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() } + fn eq(&self, other: &Self) -> bool { + self.to_array() == other.to_array() + } } // ============================================================================ @@ -2335,11 +2359,11 @@ pub fn bf16_to_f32_batch(input: &[u16], output: &mut [f32]) { #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { - if is_x86_feature_detected!("avx512bf16") - && is_x86_feature_detected!("avx512vl") - { + if is_x86_feature_detected!("avx512bf16") && is_x86_feature_detected!("avx512vl") { // SAFETY: feature detection confirmed avx512bf16 + avx512vl - unsafe { convert_bf16_to_f32_avx512bf16(input, output); } + unsafe { + convert_bf16_to_f32_avx512bf16(input, output); + } return; } } @@ -2356,10 +2380,10 @@ pub fn f32_to_bf16_batch(input: &[f32], output: &mut [u16]) { #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { - if is_x86_feature_detected!("avx512bf16") - && is_x86_feature_detected!("avx512vl") - { - unsafe { convert_f32_to_bf16_avx512bf16(input, output); } + if is_x86_feature_detected!("avx512bf16") && is_x86_feature_detected!("avx512vl") { + unsafe { + convert_f32_to_bf16_avx512bf16(input, output); + } return; } } @@ -2507,21 +2531,16 @@ pub unsafe fn f32_to_bf16_x16_rne(lane: __m512) -> __m256i { let abs_bits = _mm512_and_si512(bits, _mm512_set1_epi32(0x7FFF_FFFFu32 as i32)); let exp_bound = _mm512_set1_epi32(0x0080_0000); let is_sub_or_zero: __mmask16 = _mm512_cmplt_epu32_mask(abs_bits, exp_bound); - let is_nonzero: __mmask16 = - _mm512_cmpgt_epu32_mask(abs_bits, _mm512_setzero_si512()); + let is_nonzero: __mmask16 = _mm512_cmpgt_epu32_mask(abs_bits, _mm512_setzero_si512()); let is_subnormal: __mmask16 = is_sub_or_zero & is_nonzero; - let is_nan: __mmask16 = _mm512_cmpgt_epu32_mask( - abs_bits, - _mm512_set1_epi32(0x7F80_0000u32 as i32), - ); + let is_nan: __mmask16 = _mm512_cmpgt_epu32_mask(abs_bits, _mm512_set1_epi32(0x7F80_0000u32 as i32)); // Blend order: // 1. start from the normal RNE result, // 2. overwrite subnormal lanes with the sign-only zero, // 3. overwrite NaN lanes with the quieted payload. - let with_subnormal = - _mm512_mask_blend_epi32(is_subnormal, normal_out, sign_only); + let with_subnormal = _mm512_mask_blend_epi32(is_subnormal, normal_out, sign_only); let merged = _mm512_mask_blend_epi32(is_nan, with_subnormal, nan_out); // Pack 16 × i32 low-halves into 16 × i16. 
`_mm512_cvtepi32_epi16` is @@ -2592,30 +2611,18 @@ unsafe fn convert_f32_to_bf16_avx512f_rne(input: &[f32], output: &mut [u16]) { let sign_only = _mm512_and_si512(shifted, _mm512_set1_epi32(0x0000_8000)); let nan_out = _mm512_or_si512(shifted, _mm512_set1_epi32(0x0040)); - let abs_bits = - _mm512_and_si512(bits, _mm512_set1_epi32(0x7FFF_FFFFu32 as i32)); + let abs_bits = _mm512_and_si512(bits, _mm512_set1_epi32(0x7FFF_FFFFu32 as i32)); let exp_bound = _mm512_set1_epi32(0x0080_0000); - let is_sub_or_zero: __mmask16 = - _mm512_cmplt_epu32_mask(abs_bits, exp_bound); - let is_nonzero: __mmask16 = - _mm512_cmpgt_epu32_mask(abs_bits, _mm512_setzero_si512()); + let is_sub_or_zero: __mmask16 = _mm512_cmplt_epu32_mask(abs_bits, exp_bound); + let is_nonzero: __mmask16 = _mm512_cmpgt_epu32_mask(abs_bits, _mm512_setzero_si512()); let is_subnormal: __mmask16 = is_sub_or_zero & is_nonzero; - let is_nan: __mmask16 = _mm512_cmpgt_epu32_mask( - abs_bits, - _mm512_set1_epi32(0x7F80_0000u32 as i32), - ); + let is_nan: __mmask16 = _mm512_cmpgt_epu32_mask(abs_bits, _mm512_set1_epi32(0x7F80_0000u32 as i32)); - let with_subnormal = - _mm512_mask_blend_epi32(is_subnormal, normal_out, sign_only); - let merged = - _mm512_mask_blend_epi32(is_nan, with_subnormal, nan_out); + let with_subnormal = _mm512_mask_blend_epi32(is_subnormal, normal_out, sign_only); + let merged = _mm512_mask_blend_epi32(is_nan, with_subnormal, nan_out); // SAFETY: masked store — only lanes [0, rem) are touched. - _mm512_mask_cvtepi32_storeu_epi16( - output.as_mut_ptr().add(i) as *mut _, - mask, - merged, - ); + _mm512_mask_cvtepi32_storeu_epi16(output.as_mut_ptr().add(i) as *mut _, mask, merged); } } @@ -2635,7 +2642,9 @@ mod bf16_tests { #[test] fn batch_conversion_matches_scalar() { - let input: Vec = (0..100).map(|i| f32_to_bf16_scalar(i as f32 * 0.1 - 5.0)).collect(); + let input: Vec = (0..100) + .map(|i| f32_to_bf16_scalar(i as f32 * 0.1 - 5.0)) + .collect(); let mut batch_output = vec![0.0f32; 100]; bf16_to_f32_batch(&input, &mut batch_output); @@ -2707,9 +2716,8 @@ mod bf16_tests { // Normals across the exponent range. for exp_byte in [1u32, 50, 126, 127, 128, 200, 254] { for mant in [ - 0x0000_00u32, - 0x400000, // halfway-below-LSB for even mantissa - 0x7FFFFF, // top of mantissa (rounding into next exponent) + 0x0000_00u32, 0x400000, // halfway-below-LSB for even mantissa + 0x7FFFFF, // top of mantissa (rounding into next exponent) 0x0080_00, // round bit alone 0x00_FFFF, // sticky bits only 0x01_8000, // round + tie, LSB=1 → round up @@ -2779,18 +2787,13 @@ mod bf16_tests { while i < n { let v = _mm512_loadu_ps(corpus.as_ptr().add(i)); let packed = f32_to_bf16_x16_rne(v); - _mm256_storeu_si256( - rne_out.as_mut_ptr().add(i) as *mut __m256i, - packed, - ); + _mm256_storeu_si256(rne_out.as_mut_ptr().add(i) as *mut __m256i, packed); i += 16; } } // Reference: hardware `_mm512_cvtneps_pbh` if available. - if is_x86_feature_detected!("avx512bf16") - && is_x86_feature_detected!("avx512vl") - { + if is_x86_feature_detected!("avx512bf16") && is_x86_feature_detected!("avx512vl") { let mut hw_out: Vec = vec![0; corpus.len()]; unsafe { // SAFETY: feature detection confirmed avx512bf16 + avx512vl. @@ -2811,7 +2814,8 @@ mod bf16_tests { } } assert_eq!( - mismatches, 0, + mismatches, + 0, "byte-equality with _mm512_cvtneps_pbh failed on {} / {} inputs", mismatches, corpus.len() @@ -2823,15 +2827,15 @@ mod bf16_tests { // walking the Intel SDM VCVTNEPS2BF16 pseudocode by hand. Do not // regenerate these — they are the published oracle. 
let reference: &[(u32, u16)] = &[ - (0x0000_0000, 0x0000), // +0 - (0x8000_0000, 0x8000), // -0 - (0x3F80_0000, 0x3F80), // 1.0 - (0xBF80_0000, 0xBF80), // -1.0 - (0x7F80_0000, 0x7F80), // +Inf - (0xFF80_0000, 0xFF80), // -Inf - (0x7FC0_0000, 0x7FC0), // canonical qNaN - (0x7F80_0001, 0x7FC0), // sNaN → qNaN - (0x7FBF_FFFF, 0x7FFF), // sNaN payload → QNaN'd + (0x0000_0000, 0x0000), // +0 + (0x8000_0000, 0x8000), // -0 + (0x3F80_0000, 0x3F80), // 1.0 + (0xBF80_0000, 0xBF80), // -1.0 + (0x7F80_0000, 0x7F80), // +Inf + (0xFF80_0000, 0xFF80), // -Inf + (0x7FC0_0000, 0x7FC0), // canonical qNaN + (0x7F80_0001, 0x7FC0), // sNaN → qNaN + (0x7FBF_FFFF, 0x7FFF), // sNaN payload → QNaN'd // Halfway, LSB=0 → round down (stay even). // f32 bits = 0x3F80_8000 (1 + 2^-8). Kept LSB = 0, ties. (0x3F80_8000, 0x3F80), @@ -2858,8 +2862,7 @@ mod bf16_tests { // And run the SIMD path on a padded batch of those same inputs // so the routine's SIMD code path is actually exercised. - let mut batch: Vec = - reference.iter().map(|&(b, _)| f32::from_bits(b)).collect(); + let mut batch: Vec = reference.iter().map(|&(b, _)| f32::from_bits(b)).collect(); while batch.len() % 16 != 0 { batch.push(0.0); } @@ -2868,10 +2871,7 @@ mod bf16_tests { // SAFETY: avx512f confirmed above. let v = _mm512_loadu_ps(batch.as_ptr()); let packed = f32_to_bf16_x16_rne(v); - _mm256_storeu_si256( - simd_out.as_mut_ptr() as *mut __m256i, - packed, - ); + _mm256_storeu_si256(simd_out.as_mut_ptr() as *mut __m256i, packed); } for (i, &(in_bits, expected)) in reference.iter().enumerate() { assert_eq!( @@ -2917,10 +2917,7 @@ mod bf16_tests { while i < n { let v = _mm512_loadu_ps(cases.as_ptr().add(i)); let packed = f32_to_bf16_x16_rne(v); - _mm256_storeu_si256( - out.as_mut_ptr().add(i) as *mut __m256i, - packed, - ); + _mm256_storeu_si256(out.as_mut_ptr().add(i) as *mut __m256i, packed); i += 16; } } @@ -2932,17 +2929,15 @@ mod bf16_tests { } let bf16_mant_lsb = got & 0x0001; assert_eq!( - bf16_mant_lsb, 0, + bf16_mant_lsb, + 0, "round-to-even failed for input idx={idx} bits=0x{:08X}: bf16=0x{got:04X}", v.to_bits() ); // Also cross-check with the scalar reference. 
let scalar = f32_to_bf16_scalar_rne(v); - assert_eq!( - got, scalar, - "SIMD vs scalar RNE disagree for 0x{:08X}", v.to_bits() - ); + assert_eq!(got, scalar, "SIMD vs scalar RNE disagree for 0x{:08X}", v.to_bits()); } } @@ -2969,11 +2964,7 @@ mod bf16_tests { for (i, &v) in input.iter().enumerate() { let expected = f32_to_bf16_scalar_rne(v); - assert_eq!( - batch_out[i], expected, - "batch RNE mismatch len={len} idx={i} bits=0x{:08X}", - v.to_bits() - ); + assert_eq!(batch_out[i], expected, "batch RNE mismatch len={len} idx={i} bits=0x{:08X}", v.to_bits()); } } } @@ -3040,7 +3031,7 @@ pub fn f16_to_f32_ieee754(bits: u16) -> f32 { // Normalize: find leading 1 in mantissa, adjust exponent let mut m = mant; let mut e: i32 = 1 - 15; // subnormal effective exponent = 1 - bias - // Shift mantissa left until the implicit 1 is in bit 10 + // Shift mantissa left until the implicit 1 is in bit 10 while m & 0x400 == 0 { m <<= 1; e -= 1; @@ -3106,7 +3097,7 @@ pub fn f32_to_f16_ieee754_rne(v: f32) -> u16 { let shift = (-14 - unbiased) as u32; // Add implicit 1 to f32 mantissa, then shift right let full_mant = mant | 0x800000; // 24 bits with implicit 1 - // We need to map 24-bit mantissa to 10-bit with proper shift + // We need to map 24-bit mantissa to 10-bit with proper shift let total_shift = 13 + shift; // 13 to go from 23→10, plus extra for subnormal // Round-to-nearest-even @@ -3118,7 +3109,11 @@ pub fn f32_to_f16_ieee754_rne(v: f32) -> u16 { truncated + 1 } else if remainder == halfway { // Ties to even: round up if truncated is odd - if truncated & 1 != 0 { truncated + 1 } else { truncated } + if truncated & 1 != 0 { + truncated + 1 + } else { + truncated + } } else { truncated }; @@ -3130,7 +3125,7 @@ pub fn f32_to_f16_ieee754_rne(v: f32) -> u16 { } else { // Normal f16 range let h_exp = (unbiased + 15) as u32; // rebias: +15 - // Round mantissa from 23 bits to 10 bits using RNE + // Round mantissa from 23 bits to 10 bits using RNE let truncated = mant >> 13; let remainder = mant & 0x1FFF; // lower 13 bits let halfway = 0x1000; // 2^12 @@ -3138,7 +3133,11 @@ pub fn f32_to_f16_ieee754_rne(v: f32) -> u16 { let rounded = if remainder > halfway { truncated + 1 } else if remainder == halfway { - if truncated & 1 != 0 { truncated + 1 } else { truncated } + if truncated & 1 != 0 { + truncated + 1 + } else { + truncated + } } else { truncated }; @@ -3175,13 +3174,13 @@ pub fn f16_to_f32_batch_ieee754(input: &[u16], output: &mut [f32]) { for c in 0..chunks16 { unsafe { // SAFETY: avx512f + f16c verified above. - let src = _mm256_loadu_si256(input[c*16..].as_ptr() as *const __m256i); + let src = _mm256_loadu_si256(input[c * 16..].as_ptr() as *const __m256i); let dst = _mm512_cvtph_ps(src); - _mm512_storeu_ps(output[c*16..].as_mut_ptr(), dst); + _mm512_storeu_ps(output[c * 16..].as_mut_ptr(), dst); } } // Scalar tail - for i in (chunks16*16)..n { + for i in (chunks16 * 16)..n { output[i] = f16_to_f32_ieee754(input[i]); } return; @@ -3192,12 +3191,12 @@ pub fn f16_to_f32_batch_ieee754(input: &[u16], output: &mut [f32]) { for c in 0..chunks8 { unsafe { // SAFETY: f16c verified above. 
- let src = _mm_loadu_si128(input[c*8..].as_ptr() as *const __m128i); + let src = _mm_loadu_si128(input[c * 8..].as_ptr() as *const __m128i); let dst = _mm256_cvtph_ps(src); - _mm256_storeu_ps(output[c*8..].as_mut_ptr(), dst); + _mm256_storeu_ps(output[c * 8..].as_mut_ptr(), dst); } } - for i in (chunks8*8)..n { + for i in (chunks8 * 8)..n { output[i] = f16_to_f32_ieee754(input[i]); } return; @@ -3225,13 +3224,13 @@ pub fn f32_to_f16_batch_ieee754_rne(input: &[f32], output: &mut [u16]) { for c in 0..chunks16 { unsafe { // SAFETY: avx512f + f16c verified above. - let src = _mm512_loadu_ps(input[c*16..].as_ptr()); + let src = _mm512_loadu_ps(input[c * 16..].as_ptr()); // imm8=0x00: _MM_FROUND_TO_NEAREST_INT (RNE) let dst: __m256i = _mm512_cvtps_ph::<0x00>(src); - _mm256_storeu_si256(output[c*16..].as_mut_ptr() as *mut __m256i, dst); + _mm256_storeu_si256(output[c * 16..].as_mut_ptr() as *mut __m256i, dst); } } - for i in (chunks16*16)..n { + for i in (chunks16 * 16)..n { output[i] = f32_to_f16_ieee754_rne(input[i]); } return; @@ -3242,12 +3241,12 @@ pub fn f32_to_f16_batch_ieee754_rne(input: &[f32], output: &mut [u16]) { for c in 0..chunks8 { unsafe { // SAFETY: f16c verified above. - let src = _mm256_loadu_ps(input[c*8..].as_ptr()); + let src = _mm256_loadu_ps(input[c * 8..].as_ptr()); let dst: __m128i = _mm256_cvtps_ph::<0x00>(src); - _mm_storeu_si128(output[c*8..].as_mut_ptr() as *mut __m128i, dst); + _mm_storeu_si128(output[c * 8..].as_mut_ptr() as *mut __m128i, dst); } } - for i in (chunks8*8)..n { + for i in (chunks8 * 8)..n { output[i] = f32_to_f16_ieee754_rne(input[i]); } return; @@ -3267,17 +3266,17 @@ mod f16_tests { #[test] fn f16_ieee754_exact_values() { // IEEE 754 binary16 exact test vectors - assert_eq!(f16_to_f32_ieee754(0x0000), 0.0); // +0 - assert_eq!(f16_to_f32_ieee754(0x8000), -0.0); // −0 - assert_eq!(f16_to_f32_ieee754(0x3C00), 1.0); // 1.0 - assert_eq!(f16_to_f32_ieee754(0xBC00), -1.0); // −1.0 - assert_eq!(f16_to_f32_ieee754(0x4000), 2.0); // 2.0 - assert_eq!(f16_to_f32_ieee754(0x3800), 0.5); // 0.5 - assert_eq!(f16_to_f32_ieee754(0x7BFF), 65504.0); // max normal - assert!(f16_to_f32_ieee754(0x7C00).is_infinite()); // +Inf - assert!(f16_to_f32_ieee754(0xFC00).is_infinite()); // −Inf - assert!(f16_to_f32_ieee754(0x7C01).is_nan()); // NaN - // Smallest positive subnormal: 2^(−24) ≈ 5.96e-8 + assert_eq!(f16_to_f32_ieee754(0x0000), 0.0); // +0 + assert_eq!(f16_to_f32_ieee754(0x8000), -0.0); // −0 + assert_eq!(f16_to_f32_ieee754(0x3C00), 1.0); // 1.0 + assert_eq!(f16_to_f32_ieee754(0xBC00), -1.0); // −1.0 + assert_eq!(f16_to_f32_ieee754(0x4000), 2.0); // 2.0 + assert_eq!(f16_to_f32_ieee754(0x3800), 0.5); // 0.5 + assert_eq!(f16_to_f32_ieee754(0x7BFF), 65504.0); // max normal + assert!(f16_to_f32_ieee754(0x7C00).is_infinite()); // +Inf + assert!(f16_to_f32_ieee754(0xFC00).is_infinite()); // −Inf + assert!(f16_to_f32_ieee754(0x7C01).is_nan()); // NaN + // Smallest positive subnormal: 2^(−24) ≈ 5.96e-8 let smallest_sub = f16_to_f32_ieee754(0x0001); assert!((smallest_sub - 5.960464e-8).abs() < 1e-14); } @@ -3290,8 +3289,7 @@ mod f16_tests { let h = (exp << 10) | mant; let f = f16_to_f32_ieee754(h); let back = f32_to_f16_ieee754_rne(f); - assert_eq!(h, back, - "roundtrip failed: 0x{:04X} → {} → 0x{:04X}", h, f, back); + assert_eq!(h, back, "roundtrip failed: 0x{:04X} → {} → 0x{:04X}", h, f, back); } } } @@ -3300,15 +3298,13 @@ mod f16_tests { fn f16_exact_representable_values() { // Values that are exactly representable in f16 must roundtrip perfectly let exact_values: &[f32] 
= &[ - 0.0, 1.0, -1.0, 2.0, -2.0, 0.5, -0.5, 0.25, 0.125, - 65504.0, -65504.0, // max f16 + 0.0, 1.0, -1.0, 2.0, -2.0, 0.5, -0.5, 0.25, 0.125, 65504.0, -65504.0, // max f16 0.000061035156, // smallest normal f16 (2^-14) ]; for &v in exact_values { let h = f32_to_f16_ieee754_rne(v); let back = f16_to_f32_ieee754(h); - assert_eq!(v, back, - "exact value roundtrip failed: {} → 0x{:04X} → {}", v, h, back); + assert_eq!(v, back, "exact value roundtrip failed: {} → 0x{:04X} → {}", v, h, back); } } @@ -3321,18 +3317,25 @@ mod f16_tests { #[test] fn f16_batch_matches_scalar() { - let input: Vec = (0..200).map(|i| { - let v = (i as f32 - 100.0) * 0.5; - f32_to_f16_ieee754_rne(v) - }).collect(); + let input: Vec = (0..200) + .map(|i| { + let v = (i as f32 - 100.0) * 0.5; + f32_to_f16_ieee754_rne(v) + }) + .collect(); let mut batch_out = vec![0.0f32; 200]; f16_to_f32_batch_ieee754(&input, &mut batch_out); for (i, &h) in input.iter().enumerate() { let scalar = f16_to_f32_ieee754(h); - assert_eq!(batch_out[i].to_bits(), scalar.to_bits(), + assert_eq!( + batch_out[i].to_bits(), + scalar.to_bits(), "batch/scalar mismatch at {}: batch=0x{:08X} scalar=0x{:08X}", - i, batch_out[i].to_bits(), scalar.to_bits()); + i, + batch_out[i].to_bits(), + scalar.to_bits() + ); } } @@ -3344,9 +3347,11 @@ mod f16_tests { for (i, &v) in input.iter().enumerate() { let scalar = f32_to_f16_ieee754_rne(v); - assert_eq!(batch_out[i], scalar, + assert_eq!( + batch_out[i], scalar, "f32→f16 batch/scalar mismatch at {}: input={} batch=0x{:04X} scalar=0x{:04X}", - i, v, batch_out[i], scalar); + i, v, batch_out[i], scalar + ); } } } @@ -3382,8 +3387,8 @@ mod u8x64_rasterizer_tests { let a = U8x64::splat(10); let b = U8x64::splat(5); assert_eq!(a.cmpgt_mask(b), u64::MAX); // all greater - assert_eq!(b.cmpgt_mask(a), 0); // none greater - assert_eq!(a.cmpgt_mask(a), 0); // equal = not greater + assert_eq!(b.cmpgt_mask(a), 0); // none greater + assert_eq!(a.cmpgt_mask(a), 0); // equal = not greater } #[test] @@ -3409,7 +3414,8 @@ mod u8x64_rasterizer_tests { #[test] fn shl_epi16_shift_4() { let mut data = [0u8; 64]; - data[0] = 0x0F; data[1] = 0x00; // u16 = 0x000F + data[0] = 0x0F; + data[1] = 0x00; // u16 = 0x000F let v = U8x64::from_slice(&data); let shifted = v.shl_epi16(4); let mut out = [0u8; 64]; @@ -3441,11 +3447,15 @@ mod u8x64_rasterizer_tests { #[test] fn permute_bytes_identity() { let mut data = [0u8; 64]; - for i in 0..64 { data[i] = i as u8; } + for i in 0..64 { + data[i] = i as u8; + } let v = U8x64::from_slice(&data); // Identity permutation let mut idx = [0u8; 64]; - for i in 0..64 { idx[i] = i as u8; } + for i in 0..64 { + idx[i] = i as u8; + } let perm = v.permute_bytes(U8x64::from_slice(&idx)); let mut out = [0u8; 64]; perm.copy_to_slice(&mut out); @@ -3455,21 +3465,27 @@ mod u8x64_rasterizer_tests { #[test] fn permute_bytes_reverse() { let mut data = [0u8; 64]; - for i in 0..64 { data[i] = i as u8; } + for i in 0..64 { + data[i] = i as u8; + } let v = U8x64::from_slice(&data); // Reverse permutation let mut idx = [0u8; 64]; - for i in 0..64 { idx[i] = (63 - i) as u8; } + for i in 0..64 { + idx[i] = (63 - i) as u8; + } let perm = v.permute_bytes(U8x64::from_slice(&idx)); let mut out = [0u8; 64]; perm.copy_to_slice(&mut out); - for i in 0..64 { assert_eq!(out[i], (63 - i) as u8); } + for i in 0..64 { + assert_eq!(out[i], (63 - i) as u8); + } } } #[cfg(test)] mod tier3_tests { - use super::{U8x64, U16x32}; + use super::{U16x32, U8x64}; #[test] fn movemask_all_zero() { @@ -3486,8 +3502,8 @@ mod tier3_tests { #[test] fn 
movemask_selective() { let mut data = [0u8; 64]; - data[0] = 0x80; // MSB set → bit 0 - data[3] = 0xFF; // MSB set → bit 3 + data[0] = 0x80; // MSB set → bit 0 + data[3] = 0xFF; // MSB set → bit 3 data[63] = 0x80; // MSB set → bit 63 let v = U8x64::from_slice(&data); let mask = v.movemask(); @@ -3515,21 +3531,29 @@ mod tier3_tests { #[test] fn u16x32_from_u8x64_lo() { let mut data = [0u8; 64]; - for i in 0..32 { data[i] = (i + 1) as u8; } + for i in 0..32 { + data[i] = (i + 1) as u8; + } let v = U8x64::from_slice(&data); let wide = U16x32::from_u8x64_lo(v); let arr = wide.to_array(); - for i in 0..32 { assert_eq!(arr[i], (i + 1) as u16); } + for i in 0..32 { + assert_eq!(arr[i], (i + 1) as u16); + } } #[test] fn u16x32_from_u8x64_hi() { let mut data = [0u8; 64]; - for i in 32..64 { data[i] = i as u8; } + for i in 32..64 { + data[i] = i as u8; + } let v = U8x64::from_slice(&data); let wide = U16x32::from_u8x64_hi(v); let arr = wide.to_array(); - for i in 0..32 { assert_eq!(arr[i], (32 + i) as u16); } + for i in 0..32 { + assert_eq!(arr[i], (32 + i) as u16); + } } #[test] @@ -3592,7 +3616,7 @@ mod tier3_tests { #[cfg(test)] mod int_simd_tests { - use crate::simd::{I8x32, I8x64, I16x16, I16x32}; + use crate::simd::{I16x16, I16x32, I8x32, I8x64}; #[test] fn i8x64_add_pair_to_constant() { @@ -3715,14 +3739,10 @@ mod int_simd_tests { #[test] fn i16x16_add_round_trip_and_min() { - let a = I16x16::from_array([ - -100, -50, 0, 50, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200, - ]); + let a = I16x16::from_array([-100, -50, 0, 50, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200]); let b = I16x16::splat(10); let c = a.add(b); - let exp: [i16; 16] = [ - -90, -40, 10, 60, 110, 210, 310, 410, 510, 610, 710, 810, 910, 1010, 1110, 1210, - ]; + let exp: [i16; 16] = [-90, -40, 10, 60, 110, 210, 310, 410, 510, 610, 710, 810, 910, 1010, 1110, 1210]; assert_eq!(c.to_array(), exp); let mn = a.min(I16x16::splat(0)); diff --git a/src/simd_half.rs b/src/simd_half.rs index 6dd6564f..327f0943 100644 --- a/src/simd_half.rs +++ b/src/simd_half.rs @@ -374,8 +374,12 @@ mod tests { #[test] fn bf16x16_add_matches_scalar() { - let a_vals: Vec = (0..16).map(|i| BF16::from_f32_rounded(i as f32 * 0.5)).collect(); - let b_vals: Vec = (0..16).map(|i| BF16::from_f32_rounded(i as f32 * 0.25 + 1.0)).collect(); + let a_vals: Vec = (0..16) + .map(|i| BF16::from_f32_rounded(i as f32 * 0.5)) + .collect(); + let b_vals: Vec = (0..16) + .map(|i| BF16::from_f32_rounded(i as f32 * 0.25 + 1.0)) + .collect(); let va = BF16x16::from_slice(&a_vals); let vb = BF16x16::from_slice(&b_vals); @@ -392,8 +396,12 @@ mod tests { #[test] fn bf16x16_sub_matches_scalar() { - let a_vals: Vec = (0..16).map(|i| BF16::from_f32_rounded(10.0 + i as f32)).collect(); - let b_vals: Vec = (0..16).map(|i| BF16::from_f32_rounded(i as f32 * 0.5)).collect(); + let a_vals: Vec = (0..16) + .map(|i| BF16::from_f32_rounded(10.0 + i as f32)) + .collect(); + let b_vals: Vec = (0..16) + .map(|i| BF16::from_f32_rounded(i as f32 * 0.5)) + .collect(); let result = BF16x16::from_slice(&a_vals).sub(BF16x16::from_slice(&b_vals)); let mut out = vec![BF16::ZERO; 16]; @@ -407,8 +415,12 @@ mod tests { #[test] fn bf16x16_mul_matches_scalar() { - let a_vals: Vec = (0..16).map(|i| BF16::from_f32_rounded(i as f32 * 0.5 + 0.1)).collect(); - let b_vals: Vec = (0..16).map(|i| BF16::from_f32_rounded(i as f32 * 0.3 + 0.2)).collect(); + let a_vals: Vec = (0..16) + .map(|i| BF16::from_f32_rounded(i as f32 * 0.5 + 0.1)) + .collect(); + let b_vals: Vec = (0..16) + 
.map(|i| BF16::from_f32_rounded(i as f32 * 0.3 + 0.2)) + .collect(); let result = BF16x16::from_slice(&a_vals).mul(BF16x16::from_slice(&b_vals)); let mut out = vec![BF16::ZERO; 16]; @@ -422,25 +434,31 @@ mod tests { #[test] fn bf16x16_fma_matches_scalar() { - let a: Vec = (0..16).map(|i| BF16::from_f32_rounded(i as f32 + 1.0)).collect(); - let b: Vec = (0..16).map(|i| BF16::from_f32_rounded(0.5 * i as f32)).collect(); - let c: Vec = (0..16).map(|i| BF16::from_f32_rounded(i as f32 * 0.1)).collect(); + let a: Vec = (0..16) + .map(|i| BF16::from_f32_rounded(i as f32 + 1.0)) + .collect(); + let b: Vec = (0..16) + .map(|i| BF16::from_f32_rounded(0.5 * i as f32)) + .collect(); + let c: Vec = (0..16) + .map(|i| BF16::from_f32_rounded(i as f32 * 0.1)) + .collect(); let result = BF16x16::from_slice(&a).fma(BF16x16::from_slice(&b), BF16x16::from_slice(&c)); let mut out = vec![BF16::ZERO; 16]; result.copy_to_slice(&mut out); for i in 0..16 { - let expected = BF16::from_f32_rounded( - a[i].to_f32().mul_add(b[i].to_f32(), c[i].to_f32()), - ); + let expected = BF16::from_f32_rounded(a[i].to_f32().mul_add(b[i].to_f32(), c[i].to_f32())); assert_eq!(out[i], expected, "BF16x16 fma mismatch at lane {}", i); } } #[test] fn bf16x16_to_f32x16_roundtrip() { - let vals: Vec = (0..16).map(|i| BF16::from_f32_rounded(i as f32 * 1.5)).collect(); + let vals: Vec = (0..16) + .map(|i| BF16::from_f32_rounded(i as f32 * 1.5)) + .collect(); let v = BF16x16::from_slice(&vals); let f32s = v.to_f32x16(); @@ -463,8 +481,12 @@ mod tests { #[test] fn f16x16_add_matches_scalar() { - let a_vals: Vec = (0..16).map(|i| F16::from_f32_rounded(i as f32 * 0.5)).collect(); - let b_vals: Vec = (0..16).map(|i| F16::from_f32_rounded(i as f32 * 0.25 + 1.0)).collect(); + let a_vals: Vec = (0..16) + .map(|i| F16::from_f32_rounded(i as f32 * 0.5)) + .collect(); + let b_vals: Vec = (0..16) + .map(|i| F16::from_f32_rounded(i as f32 * 0.25 + 1.0)) + .collect(); let result = F16x16::from_slice(&a_vals).add(F16x16::from_slice(&b_vals)); let mut out = vec![F16::ZERO; 16]; @@ -478,8 +500,12 @@ mod tests { #[test] fn f16x16_mul_matches_scalar() { - let a_vals: Vec = (0..16).map(|i| F16::from_f32_rounded(i as f32 * 0.5 + 0.1)).collect(); - let b_vals: Vec = (0..16).map(|i| F16::from_f32_rounded(i as f32 * 0.3 + 0.2)).collect(); + let a_vals: Vec = (0..16) + .map(|i| F16::from_f32_rounded(i as f32 * 0.5 + 0.1)) + .collect(); + let b_vals: Vec = (0..16) + .map(|i| F16::from_f32_rounded(i as f32 * 0.3 + 0.2)) + .collect(); let result = F16x16::from_slice(&a_vals).mul(F16x16::from_slice(&b_vals)); let mut out = vec![F16::ZERO; 16]; @@ -493,8 +519,12 @@ mod tests { #[test] fn f16x16_sub_matches_scalar() { - let a_vals: Vec = (0..16).map(|i| F16::from_f32_rounded(10.0 + i as f32)).collect(); - let b_vals: Vec = (0..16).map(|i| F16::from_f32_rounded(i as f32 * 0.5)).collect(); + let a_vals: Vec = (0..16) + .map(|i| F16::from_f32_rounded(10.0 + i as f32)) + .collect(); + let b_vals: Vec = (0..16) + .map(|i| F16::from_f32_rounded(i as f32 * 0.5)) + .collect(); let result = F16x16::from_slice(&a_vals).sub(F16x16::from_slice(&b_vals)); let mut out = vec![F16::ZERO; 16]; @@ -508,25 +538,31 @@ mod tests { #[test] fn f16x16_fma_matches_scalar() { - let a: Vec = (0..16).map(|i| F16::from_f32_rounded(i as f32 + 1.0)).collect(); - let b: Vec = (0..16).map(|i| F16::from_f32_rounded(0.5 * i as f32)).collect(); - let c: Vec = (0..16).map(|i| F16::from_f32_rounded(i as f32 * 0.1)).collect(); + let a: Vec = (0..16) + .map(|i| F16::from_f32_rounded(i as f32 + 1.0)) + 
.collect(); + let b: Vec = (0..16) + .map(|i| F16::from_f32_rounded(0.5 * i as f32)) + .collect(); + let c: Vec = (0..16) + .map(|i| F16::from_f32_rounded(i as f32 * 0.1)) + .collect(); let result = F16x16::from_slice(&a).fma(F16x16::from_slice(&b), F16x16::from_slice(&c)); let mut out = vec![F16::ZERO; 16]; result.copy_to_slice(&mut out); for i in 0..16 { - let expected = F16::from_f32_rounded( - a[i].to_f32().mul_add(b[i].to_f32(), c[i].to_f32()), - ); + let expected = F16::from_f32_rounded(a[i].to_f32().mul_add(b[i].to_f32(), c[i].to_f32())); assert_eq!(out[i], expected, "F16x16 fma mismatch at lane {}", i); } } #[test] fn f16x16_to_f32x16_roundtrip() { - let vals: Vec = (0..16).map(|i| F16::from_f32_rounded(i as f32 * 1.5)).collect(); + let vals: Vec = (0..16) + .map(|i| F16::from_f32_rounded(i as f32 * 1.5)) + .collect(); let v = F16x16::from_slice(&vals); let f32s = v.to_f32x16(); @@ -551,7 +587,9 @@ mod tests { fn add_bf16_inplace_tail_15() { let n = 15; let mut dst: Vec = (0..n).map(|i| BF16::from_f32_rounded(i as f32)).collect(); - let src: Vec = (0..n).map(|i| BF16::from_f32_rounded(i as f32 * 0.5)).collect(); + let src: Vec = (0..n) + .map(|i| BF16::from_f32_rounded(i as f32 * 0.5)) + .collect(); let expected: Vec = (0..n) .map(|i| BF16::from_f32_rounded(i as f32 + i as f32 * 0.5)) .collect(); @@ -566,7 +604,9 @@ mod tests { fn add_bf16_inplace_tail_17() { let n = 17; let mut dst: Vec = (0..n).map(|i| BF16::from_f32_rounded(i as f32)).collect(); - let src: Vec = (0..n).map(|i| BF16::from_f32_rounded(i as f32 * 0.5)).collect(); + let src: Vec = (0..n) + .map(|i| BF16::from_f32_rounded(i as f32 * 0.5)) + .collect(); let expected: Vec = (0..n) .map(|i| BF16::from_f32_rounded(i as f32 + i as f32 * 0.5)) .collect(); @@ -611,7 +651,9 @@ mod tests { #[test] fn cast_bf16_f32_roundtrip() { - let bf16_vals: Vec = (0..33).map(|i| BF16::from_f32_rounded(i as f32 * 0.75)).collect(); + let bf16_vals: Vec = (0..33) + .map(|i| BF16::from_f32_rounded(i as f32 * 0.75)) + .collect(); let mut f32_buf = vec![0.0f32; 33]; let mut bf16_buf = vec![BF16::ZERO; 33]; @@ -626,7 +668,9 @@ mod tests { #[test] fn cast_f16_f32_roundtrip() { // Use small values to stay within F16 range - let f16_vals: Vec = (0..33).map(|i| F16::from_f32_rounded(i as f32 * 0.5)).collect(); + let f16_vals: Vec = (0..33) + .map(|i| F16::from_f32_rounded(i as f32 * 0.5)) + .collect(); let mut f32_buf = vec![0.0f32; 33]; let mut f16_buf = vec![F16::ZERO; 33]; @@ -643,7 +687,9 @@ mod tests { #[test] fn mul_bf16_inplace_basic() { let n = 17; - let mut dst: Vec = (0..n).map(|i| BF16::from_f32_rounded(i as f32 + 1.0)).collect(); + let mut dst: Vec = (0..n) + .map(|i| BF16::from_f32_rounded(i as f32 + 1.0)) + .collect(); let src: Vec = (0..n).map(|_| BF16::from_f32_rounded(2.0)).collect(); let expected: Vec = (0..n) .map(|i| BF16::from_f32_rounded((i as f32 + 1.0) * 2.0)) @@ -661,7 +707,9 @@ mod tests { fn add_f16_inplace_tail_17() { let n = 17; let mut dst: Vec = (0..n).map(|i| F16::from_f32_rounded(i as f32)).collect(); - let src: Vec = (0..n).map(|i| F16::from_f32_rounded(i as f32 * 0.5)).collect(); + let src: Vec = (0..n) + .map(|i| F16::from_f32_rounded(i as f32 * 0.5)) + .collect(); let expected: Vec = (0..n) .map(|i| F16::from_f32_rounded(i as f32 + i as f32 * 0.5)) .collect(); @@ -677,7 +725,9 @@ mod tests { #[test] fn mul_f16_inplace_basic() { let n = 17; - let mut dst: Vec = (0..n).map(|i| F16::from_f32_rounded(i as f32 + 1.0)).collect(); + let mut dst: Vec = (0..n) + .map(|i| F16::from_f32_rounded(i as f32 + 1.0)) + 
.collect(); let src: Vec = (0..n).map(|_| F16::from_f32_rounded(2.0)).collect(); let expected: Vec = (0..n) .map(|i| F16::from_f32_rounded((i as f32 + 1.0) * 2.0)) diff --git a/src/simd_neon.rs b/src/simd_neon.rs index 9c523f77..e7d36776 100644 --- a/src/simd_neon.rs +++ b/src/simd_neon.rs @@ -77,9 +77,9 @@ pub unsafe fn hamming_u8x16(a: &[u8; 16], b: &[u8; 16]) -> u32 { let xored = veorq_u8(va, vb); let counts = vcntq_u8(xored); // Widen and sum: u8→u16→u32→u64→scalar - let sum16 = vpaddlq_u8(counts); // 8×u16 - let sum32 = vpaddlq_u16(sum16); // 4×u32 - let sum64 = vpaddlq_u32(sum32); // 2×u64 + let sum16 = vpaddlq_u8(counts); // 8×u16 + let sum32 = vpaddlq_u16(sum16); // 4×u32 + let sum64 = vpaddlq_u32(sum32); // 2×u64 vgetq_lane_u64(sum64, 0) as u32 + vgetq_lane_u64(sum64, 1) as u32 } @@ -92,7 +92,7 @@ pub unsafe fn base17_l1_neon(a: &[i16; 17], b: &[i16; 17]) -> i32 { let va0 = vld1q_s16(a.as_ptr()); let vb0 = vld1q_s16(b.as_ptr()); let diff0 = vabdq_s16(va0, vb0); // absolute difference per lane - let sum0 = vpaddlq_s16(diff0); // widen to i32, pairwise add → 4×i32 + let sum0 = vpaddlq_s16(diff0); // widen to i32, pairwise add → 4×i32 // Next 8 elements let va1 = vld1q_s16(a[8..].as_ptr()); @@ -113,10 +113,10 @@ pub unsafe fn base17_l1_neon(a: &[i16; 17], b: &[i16; 17]) -> i32 { /// This is O(N) with NEON FMA — the core of ada-brain inference. #[cfg(target_arch = "aarch64")] pub unsafe fn codebook_gather_f32x4_neon( - centroids: &[f32], // flat array: N_centroids × dim, row-major - indices: &[u8], // which centroids to gather - dim: usize, // must be multiple of 4 - output: &mut [f32], // dim elements, accumulated + centroids: &[f32], // flat array: N_centroids × dim, row-major + indices: &[u8], // which centroids to gather + dim: usize, // must be multiple of 4 + output: &mut [f32], // dim elements, accumulated ) { debug_assert!(dim % 4 == 0); debug_assert!(output.len() >= dim); @@ -144,12 +144,7 @@ pub unsafe fn codebook_gather_f32x4_neon( /// Codebook gather with 2× unroll for A72 dual-pipeline saturation. /// Processes 2 index lookups per iteration to keep both NEON pipes fed. 
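// --- Illustrative sketch (not part of this patch) ---------------------------
// Scalar reference for the codebook gather kernels in this file, handy as a
// test oracle for the NEON / dotprod paths. The semantics are assumed from the
// parameter comments above ("flat array: N_centroids × dim, row-major",
// "dim elements, accumulated"): each selected centroid row is added into
// `output`; any scaling step is not visible in this hunk and is omitted here.
fn codebook_gather_scalar(centroids: &[f32], indices: &[u8], dim: usize, output: &mut [f32]) {
    debug_assert!(output.len() >= dim);
    for &idx in indices {
        let base = idx as usize * dim;
        let row = &centroids[base..base + dim];
        for (o, &c) in output[..dim].iter_mut().zip(row) {
            *o += c; // accumulate one centroid row into the output vector
        }
    }
}
// The 2× unrolled A72 variant below computes the same result; it only
// interleaves two index lookups per loop iteration to keep both NEON pipelines busy.
// -----------------------------------------------------------------------------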
#[cfg(target_arch = "aarch64")] -pub unsafe fn codebook_gather_f32x4_a72( - centroids: &[f32], - indices: &[u8], - dim: usize, - output: &mut [f32], -) { +pub unsafe fn codebook_gather_f32x4_a72(centroids: &[f32], indices: &[u8], dim: usize, output: &mut [f32]) { debug_assert!(dim % 4 == 0); debug_assert!(output.len() >= dim); @@ -207,7 +202,7 @@ pub unsafe fn dot_i8x16_neon(a: &[i8; 16], b: &[i8; 16]) -> i32 { #[cfg(target_arch = "aarch64")] #[target_feature(enable = "dotprod")] pub unsafe fn codebook_gather_i8_dotprod( - centroids_i8: &[i8], // quantized centroids: N × dim (i8) + centroids_i8: &[i8], // quantized centroids: N × dim (i8) indices: &[u8], dim: usize, // must be multiple of 16 output_i32: &mut [i32], // accumulated i32 (dequantize later) @@ -403,9 +398,9 @@ pub fn f16_to_f32_batch(input: &[u16], output: &mut [f32]) { // Pi 5 path: FCVTL (4× f16 → 4× f32 per instruction) let chunks = n / 4; for c in 0..chunks { - let src: &[u16; 4] = input[c*4..c*4+4].try_into().unwrap(); + let src: &[u16; 4] = input[c * 4..c * 4 + 4].try_into().unwrap(); let dst = unsafe { f16x4_to_f32x4(src) }; - output[c*4..c*4+4].copy_from_slice(&dst); + output[c * 4..c * 4 + 4].copy_from_slice(&dst); } // Scalar tail for i in (chunks * 4)..n { @@ -430,9 +425,9 @@ pub fn f32_to_f16_batch(input: &[f32], output: &mut [u16]) { if std::arch::is_aarch64_feature_detected!("fp16") { let chunks = n / 4; for c in 0..chunks { - let src: &[f32; 4] = input[c*4..c*4+4].try_into().unwrap(); + let src: &[f32; 4] = input[c * 4..c * 4 + 4].try_into().unwrap(); let dst = unsafe { f32x4_to_f16x4(src) }; - output[c*4..c*4+4].copy_from_slice(&dst); + output[c * 4..c * 4 + 4].copy_from_slice(&dst); } for i in (chunks * 4)..n { output[i] = f32_to_f16_scalar(input[i]); @@ -468,15 +463,11 @@ pub fn f32_to_f16_batch(input: &[f32], output: &mut [u16]) { pub mod aarch64_simd { use super::*; use core::fmt; - use core::ops::{ - Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign, - }; + use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; // Integer types come from the scalar fallback in simd.rs — they aren't on // the perf-critical f32 BLAS-1 / VML path that this module accelerates. - pub use crate::simd::scalar::{ - I32x16, U32x16, U64x8, - }; + pub use crate::simd::scalar::{I32x16, U32x16, U64x8}; /// 16×f32 backed by 4× NEON `float32x4_t` registers (paired loads). 
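// --- Illustrative sketch (not part of this patch) ---------------------------
// Typical use of the 16-lane wrapper defined below (F32x16): load two slices,
// multiply lane-wise, then reduce. Only API visible in this diff is used
// (from_slice, to_array, the Mul operator); no horizontal-add helper appears in
// this hunk, so the final sum is done in scalar and a dedicated reduction may
// exist elsewhere in the module.
fn dot16_sketch(a: &[f32; 16], b: &[f32; 16]) -> f32 {
    // lane-wise product: 4× vmulq_f32 under the hood (see the Mul impl below)
    let prod = F32x16::from_slice(a) * F32x16::from_slice(b);
    prod.to_array().iter().sum()
}
// -----------------------------------------------------------------------------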
#[derive(Copy, Clone)] @@ -499,17 +490,14 @@ pub mod aarch64_simd { assert!(s.len() >= 16); unsafe { let p = s.as_ptr(); - Self([ - vld1q_f32(p), - vld1q_f32(p.add(4)), - vld1q_f32(p.add(8)), - vld1q_f32(p.add(12)), - ]) + Self([vld1q_f32(p), vld1q_f32(p.add(4)), vld1q_f32(p.add(8)), vld1q_f32(p.add(12))]) } } #[inline(always)] - pub fn from_array(a: [f32; 16]) -> Self { Self::from_slice(&a) } + pub fn from_array(a: [f32; 16]) -> Self { + Self::from_slice(&a) + } #[inline(always)] pub fn to_array(self) -> [f32; 16] { @@ -541,51 +529,43 @@ pub mod aarch64_simd { #[inline(always)] pub fn reduce_min(self) -> f32 { - self.to_array().iter().copied().fold(f32::INFINITY, f32::min) + self.to_array() + .iter() + .copied() + .fold(f32::INFINITY, f32::min) } #[inline(always)] pub fn reduce_max(self) -> f32 { - self.to_array().iter().copied().fold(f32::NEG_INFINITY, f32::max) + self.to_array() + .iter() + .copied() + .fold(f32::NEG_INFINITY, f32::max) } #[inline(always)] pub fn abs(self) -> Self { - unsafe { - Self([ - vabsq_f32(self.0[0]), vabsq_f32(self.0[1]), - vabsq_f32(self.0[2]), vabsq_f32(self.0[3]), - ]) - } + unsafe { Self([vabsq_f32(self.0[0]), vabsq_f32(self.0[1]), vabsq_f32(self.0[2]), vabsq_f32(self.0[3])]) } } #[inline(always)] pub fn sqrt(self) -> Self { unsafe { - Self([ - vsqrtq_f32(self.0[0]), vsqrtq_f32(self.0[1]), - vsqrtq_f32(self.0[2]), vsqrtq_f32(self.0[3]), - ]) + Self([vsqrtq_f32(self.0[0]), vsqrtq_f32(self.0[1]), vsqrtq_f32(self.0[2]), vsqrtq_f32(self.0[3])]) } } #[inline(always)] pub fn round(self) -> Self { unsafe { - Self([ - vrndnq_f32(self.0[0]), vrndnq_f32(self.0[1]), - vrndnq_f32(self.0[2]), vrndnq_f32(self.0[3]), - ]) + Self([vrndnq_f32(self.0[0]), vrndnq_f32(self.0[1]), vrndnq_f32(self.0[2]), vrndnq_f32(self.0[3])]) } } #[inline(always)] pub fn floor(self) -> Self { unsafe { - Self([ - vrndmq_f32(self.0[0]), vrndmq_f32(self.0[1]), - vrndmq_f32(self.0[2]), vrndmq_f32(self.0[3]), - ]) + Self([vrndmq_f32(self.0[0]), vrndmq_f32(self.0[1]), vrndmq_f32(self.0[2]), vrndmq_f32(self.0[3])]) } } @@ -626,53 +606,92 @@ pub mod aarch64_simd { } #[inline(always)] - pub fn simd_clamp(self, lo: Self, hi: Self) -> Self { self.simd_max(lo).simd_min(hi) } + pub fn simd_clamp(self, lo: Self, hi: Self) -> Self { + self.simd_max(lo).simd_min(hi) + } #[inline(always)] pub fn simd_lt(self, other: Self) -> F32Mask16 { - let a = self.to_array(); let b = other.to_array(); + let a = self.to_array(); + let b = other.to_array(); let mut bits: u16 = 0; - for i in 0..16 { if a[i] < b[i] { bits |= 1 << i; } } + for i in 0..16 { + if a[i] < b[i] { + bits |= 1 << i; + } + } F32Mask16(bits) } #[inline(always)] pub fn simd_le(self, other: Self) -> F32Mask16 { - let a = self.to_array(); let b = other.to_array(); + let a = self.to_array(); + let b = other.to_array(); let mut bits: u16 = 0; - for i in 0..16 { if a[i] <= b[i] { bits |= 1 << i; } } + for i in 0..16 { + if a[i] <= b[i] { + bits |= 1 << i; + } + } F32Mask16(bits) } - #[inline(always)] pub fn simd_gt(self, other: Self) -> F32Mask16 { other.simd_lt(self) } - #[inline(always)] pub fn simd_ge(self, other: Self) -> F32Mask16 { other.simd_le(self) } + #[inline(always)] + pub fn simd_gt(self, other: Self) -> F32Mask16 { + other.simd_lt(self) + } + #[inline(always)] + pub fn simd_ge(self, other: Self) -> F32Mask16 { + other.simd_le(self) + } #[inline(always)] pub fn simd_eq(self, other: Self) -> F32Mask16 { - let a = self.to_array(); let b = other.to_array(); + let a = self.to_array(); + let b = other.to_array(); let mut bits: u16 = 0; - for i in 
0..16 { if a[i] == b[i] { bits |= 1 << i; } } + for i in 0..16 { + if a[i] == b[i] { + bits |= 1 << i; + } + } F32Mask16(bits) } #[inline(always)] pub fn simd_ne(self, other: Self) -> F32Mask16 { - let a = self.to_array(); let b = other.to_array(); + let a = self.to_array(); + let b = other.to_array(); let mut bits: u16 = 0; - for i in 0..16 { if a[i] != b[i] { bits |= 1 << i; } } + for i in 0..16 { + if a[i] != b[i] { + bits |= 1 << i; + } + } F32Mask16(bits) } #[inline(always)] pub fn to_bits(self) -> U32x16 { let a = self.to_array(); - let mut o = [0u32; 16]; for i in 0..16 { o[i] = a[i].to_bits(); } U32x16(o) + let mut o = [0u32; 16]; + for i in 0..16 { + o[i] = a[i].to_bits(); + } + U32x16(o) } #[inline(always)] pub fn from_bits(bits: U32x16) -> Self { - let mut o = [0.0f32; 16]; for i in 0..16 { o[i] = f32::from_bits(bits.0[i]); } + let mut o = [0.0f32; 16]; + for i in 0..16 { + o[i] = f32::from_bits(bits.0[i]); + } Self::from_array(o) } #[inline(always)] pub fn cast_i32(self) -> I32x16 { let a = self.to_array(); - let mut o = [0i32; 16]; for i in 0..16 { o[i] = a[i] as i32; } I32x16(o) + let mut o = [0i32; 16]; + for i in 0..16 { + o[i] = a[i] as i32; + } + I32x16(o) } } @@ -682,8 +701,10 @@ pub mod aarch64_simd { fn add(self, rhs: Self) -> Self { unsafe { Self([ - vaddq_f32(self.0[0], rhs.0[0]), vaddq_f32(self.0[1], rhs.0[1]), - vaddq_f32(self.0[2], rhs.0[2]), vaddq_f32(self.0[3], rhs.0[3]), + vaddq_f32(self.0[0], rhs.0[0]), + vaddq_f32(self.0[1], rhs.0[1]), + vaddq_f32(self.0[2], rhs.0[2]), + vaddq_f32(self.0[3], rhs.0[3]), ]) } } @@ -694,8 +715,10 @@ pub mod aarch64_simd { fn sub(self, rhs: Self) -> Self { unsafe { Self([ - vsubq_f32(self.0[0], rhs.0[0]), vsubq_f32(self.0[1], rhs.0[1]), - vsubq_f32(self.0[2], rhs.0[2]), vsubq_f32(self.0[3], rhs.0[3]), + vsubq_f32(self.0[0], rhs.0[0]), + vsubq_f32(self.0[1], rhs.0[1]), + vsubq_f32(self.0[2], rhs.0[2]), + vsubq_f32(self.0[3], rhs.0[3]), ]) } } @@ -706,8 +729,10 @@ pub mod aarch64_simd { fn mul(self, rhs: Self) -> Self { unsafe { Self([ - vmulq_f32(self.0[0], rhs.0[0]), vmulq_f32(self.0[1], rhs.0[1]), - vmulq_f32(self.0[2], rhs.0[2]), vmulq_f32(self.0[3], rhs.0[3]), + vmulq_f32(self.0[0], rhs.0[0]), + vmulq_f32(self.0[1], rhs.0[1]), + vmulq_f32(self.0[2], rhs.0[2]), + vmulq_f32(self.0[3], rhs.0[3]), ]) } } @@ -718,26 +743,43 @@ pub mod aarch64_simd { fn div(self, rhs: Self) -> Self { unsafe { Self([ - vdivq_f32(self.0[0], rhs.0[0]), vdivq_f32(self.0[1], rhs.0[1]), - vdivq_f32(self.0[2], rhs.0[2]), vdivq_f32(self.0[3], rhs.0[3]), + vdivq_f32(self.0[0], rhs.0[0]), + vdivq_f32(self.0[1], rhs.0[1]), + vdivq_f32(self.0[2], rhs.0[2]), + vdivq_f32(self.0[3], rhs.0[3]), ]) } } } - impl AddAssign for F32x16 { #[inline(always)] fn add_assign(&mut self, rhs: Self) { *self = *self + rhs; } } - impl SubAssign for F32x16 { #[inline(always)] fn sub_assign(&mut self, rhs: Self) { *self = *self - rhs; } } - impl MulAssign for F32x16 { #[inline(always)] fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; } } - impl DivAssign for F32x16 { #[inline(always)] fn div_assign(&mut self, rhs: Self) { *self = *self / rhs; } } + impl AddAssign for F32x16 { + #[inline(always)] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } + } + impl SubAssign for F32x16 { + #[inline(always)] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } + } + impl MulAssign for F32x16 { + #[inline(always)] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } + } + impl DivAssign for F32x16 { + #[inline(always)] + fn div_assign(&mut 
self, rhs: Self) { + *self = *self / rhs; + } + } impl Neg for F32x16 { type Output = Self; #[inline(always)] fn neg(self) -> Self { - unsafe { - Self([ - vnegq_f32(self.0[0]), vnegq_f32(self.0[1]), - vnegq_f32(self.0[2]), vnegq_f32(self.0[3]), - ]) - } + unsafe { Self([vnegq_f32(self.0[0]), vnegq_f32(self.0[1]), vnegq_f32(self.0[2]), vnegq_f32(self.0[3])]) } } } impl fmt::Debug for F32x16 { @@ -746,18 +788,27 @@ pub mod aarch64_simd { } } impl PartialEq for F32x16 { - fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() } + fn eq(&self, other: &Self) -> bool { + self.to_array() == other.to_array() + } + } + impl Default for F32x16 { + fn default() -> Self { + Self::splat(0.0) + } } - impl Default for F32x16 { fn default() -> Self { Self::splat(0.0) } } #[derive(Copy, Clone, Debug)] pub struct F32Mask16(pub u16); impl F32Mask16 { #[inline(always)] pub fn select(self, true_val: F32x16, false_val: F32x16) -> F32x16 { - let t = true_val.to_array(); let f = false_val.to_array(); + let t = true_val.to_array(); + let f = false_val.to_array(); let mut o = [0.0f32; 16]; - for i in 0..16 { o[i] = if (self.0 >> i) & 1 == 1 { t[i] } else { f[i] }; } + for i in 0..16 { + o[i] = if (self.0 >> i) & 1 == 1 { t[i] } else { f[i] }; + } F32x16::from_array(o) } } @@ -783,17 +834,14 @@ pub mod aarch64_simd { assert!(s.len() >= 8); unsafe { let p = s.as_ptr(); - Self([ - vld1q_f64(p), - vld1q_f64(p.add(2)), - vld1q_f64(p.add(4)), - vld1q_f64(p.add(6)), - ]) + Self([vld1q_f64(p), vld1q_f64(p.add(2)), vld1q_f64(p.add(4)), vld1q_f64(p.add(6))]) } } #[inline(always)] - pub fn from_array(a: [f64; 8]) -> Self { Self::from_slice(&a) } + pub fn from_array(a: [f64; 8]) -> Self { + Self::from_slice(&a) + } #[inline(always)] pub fn to_array(self) -> [f64; 8] { @@ -825,51 +873,43 @@ pub mod aarch64_simd { #[inline(always)] pub fn reduce_min(self) -> f64 { - self.to_array().iter().copied().fold(f64::INFINITY, f64::min) + self.to_array() + .iter() + .copied() + .fold(f64::INFINITY, f64::min) } #[inline(always)] pub fn reduce_max(self) -> f64 { - self.to_array().iter().copied().fold(f64::NEG_INFINITY, f64::max) + self.to_array() + .iter() + .copied() + .fold(f64::NEG_INFINITY, f64::max) } #[inline(always)] pub fn abs(self) -> Self { - unsafe { - Self([ - vabsq_f64(self.0[0]), vabsq_f64(self.0[1]), - vabsq_f64(self.0[2]), vabsq_f64(self.0[3]), - ]) - } + unsafe { Self([vabsq_f64(self.0[0]), vabsq_f64(self.0[1]), vabsq_f64(self.0[2]), vabsq_f64(self.0[3])]) } } #[inline(always)] pub fn sqrt(self) -> Self { unsafe { - Self([ - vsqrtq_f64(self.0[0]), vsqrtq_f64(self.0[1]), - vsqrtq_f64(self.0[2]), vsqrtq_f64(self.0[3]), - ]) + Self([vsqrtq_f64(self.0[0]), vsqrtq_f64(self.0[1]), vsqrtq_f64(self.0[2]), vsqrtq_f64(self.0[3])]) } } #[inline(always)] pub fn round(self) -> Self { unsafe { - Self([ - vrndnq_f64(self.0[0]), vrndnq_f64(self.0[1]), - vrndnq_f64(self.0[2]), vrndnq_f64(self.0[3]), - ]) + Self([vrndnq_f64(self.0[0]), vrndnq_f64(self.0[1]), vrndnq_f64(self.0[2]), vrndnq_f64(self.0[3])]) } } #[inline(always)] pub fn floor(self) -> Self { unsafe { - Self([ - vrndmq_f64(self.0[0]), vrndmq_f64(self.0[1]), - vrndmq_f64(self.0[2]), vrndmq_f64(self.0[3]), - ]) + Self([vrndmq_f64(self.0[0]), vrndmq_f64(self.0[1]), vrndmq_f64(self.0[2]), vrndmq_f64(self.0[3])]) } } @@ -910,29 +950,50 @@ pub mod aarch64_simd { } #[inline(always)] - pub fn simd_clamp(self, lo: Self, hi: Self) -> Self { self.simd_max(lo).simd_min(hi) } + pub fn simd_clamp(self, lo: Self, hi: Self) -> Self { + self.simd_max(lo).simd_min(hi) + } 
#[inline(always)] pub fn simd_ge(self, other: Self) -> F64Mask8 { - let a = self.to_array(); let b = other.to_array(); - let mut bits: u8 = 0; for i in 0..8 { if a[i] >= b[i] { bits |= 1 << i; } } + let a = self.to_array(); + let b = other.to_array(); + let mut bits: u8 = 0; + for i in 0..8 { + if a[i] >= b[i] { + bits |= 1 << i; + } + } F64Mask8(bits) } #[inline(always)] pub fn simd_le(self, other: Self) -> F64Mask8 { - let a = self.to_array(); let b = other.to_array(); - let mut bits: u8 = 0; for i in 0..8 { if a[i] <= b[i] { bits |= 1 << i; } } + let a = self.to_array(); + let b = other.to_array(); + let mut bits: u8 = 0; + for i in 0..8 { + if a[i] <= b[i] { + bits |= 1 << i; + } + } F64Mask8(bits) } #[inline(always)] pub fn to_bits(self) -> U64x8 { let a = self.to_array(); - let mut o = [0u64; 8]; for i in 0..8 { o[i] = a[i].to_bits(); } U64x8(o) + let mut o = [0u64; 8]; + for i in 0..8 { + o[i] = a[i].to_bits(); + } + U64x8(o) } #[inline(always)] pub fn from_bits(bits: U64x8) -> Self { - let mut o = [0.0f64; 8]; for i in 0..8 { o[i] = f64::from_bits(bits.0[i]); } + let mut o = [0.0f64; 8]; + for i in 0..8 { + o[i] = f64::from_bits(bits.0[i]); + } Self::from_array(o) } } @@ -943,8 +1004,10 @@ pub mod aarch64_simd { fn add(self, rhs: Self) -> Self { unsafe { Self([ - vaddq_f64(self.0[0], rhs.0[0]), vaddq_f64(self.0[1], rhs.0[1]), - vaddq_f64(self.0[2], rhs.0[2]), vaddq_f64(self.0[3], rhs.0[3]), + vaddq_f64(self.0[0], rhs.0[0]), + vaddq_f64(self.0[1], rhs.0[1]), + vaddq_f64(self.0[2], rhs.0[2]), + vaddq_f64(self.0[3], rhs.0[3]), ]) } } @@ -955,8 +1018,10 @@ pub mod aarch64_simd { fn sub(self, rhs: Self) -> Self { unsafe { Self([ - vsubq_f64(self.0[0], rhs.0[0]), vsubq_f64(self.0[1], rhs.0[1]), - vsubq_f64(self.0[2], rhs.0[2]), vsubq_f64(self.0[3], rhs.0[3]), + vsubq_f64(self.0[0], rhs.0[0]), + vsubq_f64(self.0[1], rhs.0[1]), + vsubq_f64(self.0[2], rhs.0[2]), + vsubq_f64(self.0[3], rhs.0[3]), ]) } } @@ -967,8 +1032,10 @@ pub mod aarch64_simd { fn mul(self, rhs: Self) -> Self { unsafe { Self([ - vmulq_f64(self.0[0], rhs.0[0]), vmulq_f64(self.0[1], rhs.0[1]), - vmulq_f64(self.0[2], rhs.0[2]), vmulq_f64(self.0[3], rhs.0[3]), + vmulq_f64(self.0[0], rhs.0[0]), + vmulq_f64(self.0[1], rhs.0[1]), + vmulq_f64(self.0[2], rhs.0[2]), + vmulq_f64(self.0[3], rhs.0[3]), ]) } } @@ -979,26 +1046,43 @@ pub mod aarch64_simd { fn div(self, rhs: Self) -> Self { unsafe { Self([ - vdivq_f64(self.0[0], rhs.0[0]), vdivq_f64(self.0[1], rhs.0[1]), - vdivq_f64(self.0[2], rhs.0[2]), vdivq_f64(self.0[3], rhs.0[3]), + vdivq_f64(self.0[0], rhs.0[0]), + vdivq_f64(self.0[1], rhs.0[1]), + vdivq_f64(self.0[2], rhs.0[2]), + vdivq_f64(self.0[3], rhs.0[3]), ]) } } } - impl AddAssign for F64x8 { #[inline(always)] fn add_assign(&mut self, rhs: Self) { *self = *self + rhs; } } - impl SubAssign for F64x8 { #[inline(always)] fn sub_assign(&mut self, rhs: Self) { *self = *self - rhs; } } - impl MulAssign for F64x8 { #[inline(always)] fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; } } - impl DivAssign for F64x8 { #[inline(always)] fn div_assign(&mut self, rhs: Self) { *self = *self / rhs; } } + impl AddAssign for F64x8 { + #[inline(always)] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } + } + impl SubAssign for F64x8 { + #[inline(always)] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } + } + impl MulAssign for F64x8 { + #[inline(always)] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } + } + impl DivAssign for F64x8 { + #[inline(always)] + fn div_assign(&mut self, 
rhs: Self) { + *self = *self / rhs; + } + } impl Neg for F64x8 { type Output = Self; #[inline(always)] fn neg(self) -> Self { - unsafe { - Self([ - vnegq_f64(self.0[0]), vnegq_f64(self.0[1]), - vnegq_f64(self.0[2]), vnegq_f64(self.0[3]), - ]) - } + unsafe { Self([vnegq_f64(self.0[0]), vnegq_f64(self.0[1]), vnegq_f64(self.0[2]), vnegq_f64(self.0[3])]) } } } impl fmt::Debug for F64x8 { @@ -1007,25 +1091,36 @@ pub mod aarch64_simd { } } impl PartialEq for F64x8 { - fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() } + fn eq(&self, other: &Self) -> bool { + self.to_array() == other.to_array() + } + } + impl Default for F64x8 { + fn default() -> Self { + Self::splat(0.0) + } } - impl Default for F64x8 { fn default() -> Self { Self::splat(0.0) } } #[derive(Copy, Clone, Debug)] pub struct F64Mask8(pub u8); impl F64Mask8 { #[inline(always)] pub fn select(self, true_val: F64x8, false_val: F64x8) -> F64x8 { - let t = true_val.to_array(); let f = false_val.to_array(); + let t = true_val.to_array(); + let f = false_val.to_array(); let mut o = [0.0f64; 8]; - for i in 0..8 { o[i] = if (self.0 >> i) & 1 == 1 { t[i] } else { f[i] }; } + for i in 0..8 { + o[i] = if (self.0 >> i) & 1 == 1 { t[i] } else { f[i] }; + } F64x8::from_array(o) } } // Lowercase aliases (consumer-API parity) - #[allow(non_camel_case_types)] pub type f32x16 = F32x16; - #[allow(non_camel_case_types)] pub type f64x8 = F64x8; + #[allow(non_camel_case_types)] + pub type f32x16 = F32x16; + #[allow(non_camel_case_types)] + pub type f64x8 = F64x8; } #[cfg(all(target_arch = "aarch64", test))] @@ -1052,7 +1147,9 @@ mod neon_pair_tests { let b = F32x16::splat(3.0); let c = F32x16::splat(1.0); let r = a.mul_add(b, c).to_array(); - for &v in &r { assert_eq!(v, 7.0); } + for &v in &r { + assert_eq!(v, 7.0); + } } #[test] @@ -1110,10 +1207,14 @@ impl I8x16 { pub const LANES: usize = 16; #[inline(always)] - pub fn splat(v: i8) -> Self { Self(unsafe { vdupq_n_s8(v) }) } + pub fn splat(v: i8) -> Self { + Self(unsafe { vdupq_n_s8(v) }) + } #[inline(always)] - pub fn zero() -> Self { Self(unsafe { vdupq_n_s8(0) }) } + pub fn zero() -> Self { + Self(unsafe { vdupq_n_s8(0) }) + } #[inline(always)] pub fn from_slice(s: &[i8]) -> Self { @@ -1139,10 +1240,22 @@ impl I8x16 { unsafe { vst1q_s8(s.as_mut_ptr(), self.0) }; } - #[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_s8(self.0, other.0) }) } - #[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_s8(self.0, other.0) }) } - #[inline(always)] pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_s8(self.0, other.0) }) } - #[inline(always)] pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_s8(self.0, other.0) }) } + #[inline(always)] + pub fn add(self, other: Self) -> Self { + Self(unsafe { vaddq_s8(self.0, other.0) }) + } + #[inline(always)] + pub fn sub(self, other: Self) -> Self { + Self(unsafe { vsubq_s8(self.0, other.0) }) + } + #[inline(always)] + pub fn min(self, other: Self) -> Self { + Self(unsafe { vminq_s8(self.0, other.0) }) + } + #[inline(always)] + pub fn max(self, other: Self) -> Self { + Self(unsafe { vmaxq_s8(self.0, other.0) }) + } /// Compare-greater-than: returns 16-bit mask. Bit i set where self[i] > other[i]. 
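// --- Illustrative sketch (not part of this patch) ---------------------------
// What cmp_gt_mask (defined just below) computes, written out in scalar form:
// one bit per lane, set where self[i] > other[i]. Useful as a reference when
// checking the vcgtq_s8 + transmute path.
fn cmp_gt_mask_scalar(a: &[i8; 16], b: &[i8; 16]) -> u16 {
    let mut m = 0u16;
    for i in 0..16 {
        if a[i] > b[i] {
            m |= 1 << i;
        }
    }
    m
}
// Consumers typically use the mask branchlessly, e.g. m.count_ones() for the
// number of greater lanes, or m.trailing_zeros() for the first such lane.
// -----------------------------------------------------------------------------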
#[inline(always)] @@ -1151,7 +1264,11 @@ impl I8x16 { let cmp = vcgtq_s8(self.0, other.0); // uint8x16_t, 0xFF where true let arr: [u8; 16] = core::mem::transmute(cmp); let mut m: u16 = 0; - for i in 0..16 { if arr[i] != 0 { m |= 1u16 << i; } } + for i in 0..16 { + if arr[i] != 0 { + m |= 1u16 << i; + } + } m } } @@ -1165,7 +1282,9 @@ impl core::fmt::Debug for I8x16 { } #[cfg(target_arch = "aarch64")] impl PartialEq for I8x16 { - fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() } + fn eq(&self, other: &Self) -> bool { + self.to_array() == other.to_array() + } } #[cfg(target_arch = "aarch64")] @@ -1178,10 +1297,14 @@ impl I16x8 { pub const LANES: usize = 8; #[inline(always)] - pub fn splat(v: i16) -> Self { Self(unsafe { vdupq_n_s16(v) }) } + pub fn splat(v: i16) -> Self { + Self(unsafe { vdupq_n_s16(v) }) + } #[inline(always)] - pub fn zero() -> Self { Self(unsafe { vdupq_n_s16(0) }) } + pub fn zero() -> Self { + Self(unsafe { vdupq_n_s16(0) }) + } #[inline(always)] pub fn from_slice(s: &[i16]) -> Self { @@ -1207,10 +1330,22 @@ impl I16x8 { unsafe { vst1q_s16(s.as_mut_ptr(), self.0) }; } - #[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_s16(self.0, other.0) }) } - #[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_s16(self.0, other.0) }) } - #[inline(always)] pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_s16(self.0, other.0) }) } - #[inline(always)] pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_s16(self.0, other.0) }) } + #[inline(always)] + pub fn add(self, other: Self) -> Self { + Self(unsafe { vaddq_s16(self.0, other.0) }) + } + #[inline(always)] + pub fn sub(self, other: Self) -> Self { + Self(unsafe { vsubq_s16(self.0, other.0) }) + } + #[inline(always)] + pub fn min(self, other: Self) -> Self { + Self(unsafe { vminq_s16(self.0, other.0) }) + } + #[inline(always)] + pub fn max(self, other: Self) -> Self { + Self(unsafe { vmaxq_s16(self.0, other.0) }) + } /// Compare-greater-than: returns 8-bit mask. Bit i set where self[i] > other[i]. 
#[inline(always)] @@ -1219,7 +1354,11 @@ impl I16x8 { let cmp = vcgtq_s16(self.0, other.0); // uint16x8_t, 0xFFFF where true let arr: [u16; 8] = core::mem::transmute(cmp); let mut m: u8 = 0; - for i in 0..8 { if arr[i] != 0 { m |= 1u8 << i; } } + for i in 0..8 { + if arr[i] != 0 { + m |= 1u8 << i; + } + } m } } @@ -1233,7 +1372,9 @@ impl core::fmt::Debug for I16x8 { } #[cfg(target_arch = "aarch64")] impl PartialEq for I16x8 { - fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() } + fn eq(&self, other: &Self) -> bool { + self.to_array() == other.to_array() + } } // ═══════════════════════════════════════════════════════════════════════════ @@ -1249,22 +1390,50 @@ pub struct U8x16(pub uint8x16_t); #[cfg(target_arch = "aarch64")] impl U8x16 { pub const LANES: usize = 16; - #[inline(always)] pub fn splat(v: u8) -> Self { Self(unsafe { vdupq_n_u8(v) }) } - #[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_u8(0) }) } - #[inline(always)] pub fn from_slice(s: &[u8]) -> Self { - assert!(s.len() >= 16); Self(unsafe { vld1q_u8(s.as_ptr()) }) - } - #[inline(always)] pub fn from_array(arr: [u8; 16]) -> Self { Self(unsafe { vld1q_u8(arr.as_ptr()) }) } - #[inline(always)] pub fn to_array(self) -> [u8; 16] { - let mut arr = [0u8; 16]; unsafe { vst1q_u8(arr.as_mut_ptr(), self.0) }; arr - } - #[inline(always)] pub fn copy_to_slice(self, s: &mut [u8]) { - assert!(s.len() >= 16); unsafe { vst1q_u8(s.as_mut_ptr(), self.0) }; - } - #[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_u8(self.0, other.0) }) } - #[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_u8(self.0, other.0) }) } - #[inline(always)] pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_u8(self.0, other.0) }) } - #[inline(always)] pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_u8(self.0, other.0) }) } + #[inline(always)] + pub fn splat(v: u8) -> Self { + Self(unsafe { vdupq_n_u8(v) }) + } + #[inline(always)] + pub fn zero() -> Self { + Self(unsafe { vdupq_n_u8(0) }) + } + #[inline(always)] + pub fn from_slice(s: &[u8]) -> Self { + assert!(s.len() >= 16); + Self(unsafe { vld1q_u8(s.as_ptr()) }) + } + #[inline(always)] + pub fn from_array(arr: [u8; 16]) -> Self { + Self(unsafe { vld1q_u8(arr.as_ptr()) }) + } + #[inline(always)] + pub fn to_array(self) -> [u8; 16] { + let mut arr = [0u8; 16]; + unsafe { vst1q_u8(arr.as_mut_ptr(), self.0) }; + arr + } + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [u8]) { + assert!(s.len() >= 16); + unsafe { vst1q_u8(s.as_mut_ptr(), self.0) }; + } + #[inline(always)] + pub fn add(self, other: Self) -> Self { + Self(unsafe { vaddq_u8(self.0, other.0) }) + } + #[inline(always)] + pub fn sub(self, other: Self) -> Self { + Self(unsafe { vsubq_u8(self.0, other.0) }) + } + #[inline(always)] + pub fn min(self, other: Self) -> Self { + Self(unsafe { vminq_u8(self.0, other.0) }) + } + #[inline(always)] + pub fn max(self, other: Self) -> Self { + Self(unsafe { vmaxq_u8(self.0, other.0) }) + } } #[cfg(target_arch = "aarch64")] @@ -1275,22 +1444,50 @@ pub struct U16x8(pub uint16x8_t); #[cfg(target_arch = "aarch64")] impl U16x8 { pub const LANES: usize = 8; - #[inline(always)] pub fn splat(v: u16) -> Self { Self(unsafe { vdupq_n_u16(v) }) } - #[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_u16(0) }) } - #[inline(always)] pub fn from_slice(s: &[u16]) -> Self { - assert!(s.len() >= 8); Self(unsafe { vld1q_u16(s.as_ptr()) }) - } - #[inline(always)] pub fn from_array(arr: [u16; 8]) -> Self { 
Self(unsafe { vld1q_u16(arr.as_ptr()) }) } - #[inline(always)] pub fn to_array(self) -> [u16; 8] { - let mut arr = [0u16; 8]; unsafe { vst1q_u16(arr.as_mut_ptr(), self.0) }; arr - } - #[inline(always)] pub fn copy_to_slice(self, s: &mut [u16]) { - assert!(s.len() >= 8); unsafe { vst1q_u16(s.as_mut_ptr(), self.0) }; - } - #[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_u16(self.0, other.0) }) } - #[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_u16(self.0, other.0) }) } - #[inline(always)] pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_u16(self.0, other.0) }) } - #[inline(always)] pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_u16(self.0, other.0) }) } + #[inline(always)] + pub fn splat(v: u16) -> Self { + Self(unsafe { vdupq_n_u16(v) }) + } + #[inline(always)] + pub fn zero() -> Self { + Self(unsafe { vdupq_n_u16(0) }) + } + #[inline(always)] + pub fn from_slice(s: &[u16]) -> Self { + assert!(s.len() >= 8); + Self(unsafe { vld1q_u16(s.as_ptr()) }) + } + #[inline(always)] + pub fn from_array(arr: [u16; 8]) -> Self { + Self(unsafe { vld1q_u16(arr.as_ptr()) }) + } + #[inline(always)] + pub fn to_array(self) -> [u16; 8] { + let mut arr = [0u16; 8]; + unsafe { vst1q_u16(arr.as_mut_ptr(), self.0) }; + arr + } + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [u16]) { + assert!(s.len() >= 8); + unsafe { vst1q_u16(s.as_mut_ptr(), self.0) }; + } + #[inline(always)] + pub fn add(self, other: Self) -> Self { + Self(unsafe { vaddq_u16(self.0, other.0) }) + } + #[inline(always)] + pub fn sub(self, other: Self) -> Self { + Self(unsafe { vsubq_u16(self.0, other.0) }) + } + #[inline(always)] + pub fn min(self, other: Self) -> Self { + Self(unsafe { vminq_u16(self.0, other.0) }) + } + #[inline(always)] + pub fn max(self, other: Self) -> Self { + Self(unsafe { vmaxq_u16(self.0, other.0) }) + } } #[cfg(target_arch = "aarch64")] @@ -1301,22 +1498,50 @@ pub struct U32x4(pub uint32x4_t); #[cfg(target_arch = "aarch64")] impl U32x4 { pub const LANES: usize = 4; - #[inline(always)] pub fn splat(v: u32) -> Self { Self(unsafe { vdupq_n_u32(v) }) } - #[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_u32(0) }) } - #[inline(always)] pub fn from_slice(s: &[u32]) -> Self { - assert!(s.len() >= 4); Self(unsafe { vld1q_u32(s.as_ptr()) }) - } - #[inline(always)] pub fn from_array(arr: [u32; 4]) -> Self { Self(unsafe { vld1q_u32(arr.as_ptr()) }) } - #[inline(always)] pub fn to_array(self) -> [u32; 4] { - let mut arr = [0u32; 4]; unsafe { vst1q_u32(arr.as_mut_ptr(), self.0) }; arr - } - #[inline(always)] pub fn copy_to_slice(self, s: &mut [u32]) { - assert!(s.len() >= 4); unsafe { vst1q_u32(s.as_mut_ptr(), self.0) }; - } - #[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_u32(self.0, other.0) }) } - #[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_u32(self.0, other.0) }) } - #[inline(always)] pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_u32(self.0, other.0) }) } - #[inline(always)] pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_u32(self.0, other.0) }) } + #[inline(always)] + pub fn splat(v: u32) -> Self { + Self(unsafe { vdupq_n_u32(v) }) + } + #[inline(always)] + pub fn zero() -> Self { + Self(unsafe { vdupq_n_u32(0) }) + } + #[inline(always)] + pub fn from_slice(s: &[u32]) -> Self { + assert!(s.len() >= 4); + Self(unsafe { vld1q_u32(s.as_ptr()) }) + } + #[inline(always)] + pub fn from_array(arr: [u32; 4]) -> Self { + Self(unsafe { 
vld1q_u32(arr.as_ptr()) }) + } + #[inline(always)] + pub fn to_array(self) -> [u32; 4] { + let mut arr = [0u32; 4]; + unsafe { vst1q_u32(arr.as_mut_ptr(), self.0) }; + arr + } + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [u32]) { + assert!(s.len() >= 4); + unsafe { vst1q_u32(s.as_mut_ptr(), self.0) }; + } + #[inline(always)] + pub fn add(self, other: Self) -> Self { + Self(unsafe { vaddq_u32(self.0, other.0) }) + } + #[inline(always)] + pub fn sub(self, other: Self) -> Self { + Self(unsafe { vsubq_u32(self.0, other.0) }) + } + #[inline(always)] + pub fn min(self, other: Self) -> Self { + Self(unsafe { vminq_u32(self.0, other.0) }) + } + #[inline(always)] + pub fn max(self, other: Self) -> Self { + Self(unsafe { vmaxq_u32(self.0, other.0) }) + } } #[cfg(target_arch = "aarch64")] @@ -1327,27 +1552,53 @@ pub struct U64x2(pub uint64x2_t); #[cfg(target_arch = "aarch64")] impl U64x2 { pub const LANES: usize = 2; - #[inline(always)] pub fn splat(v: u64) -> Self { Self(unsafe { vdupq_n_u64(v) }) } - #[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_u64(0) }) } - #[inline(always)] pub fn from_slice(s: &[u64]) -> Self { - assert!(s.len() >= 2); Self(unsafe { vld1q_u64(s.as_ptr()) }) + #[inline(always)] + pub fn splat(v: u64) -> Self { + Self(unsafe { vdupq_n_u64(v) }) + } + #[inline(always)] + pub fn zero() -> Self { + Self(unsafe { vdupq_n_u64(0) }) + } + #[inline(always)] + pub fn from_slice(s: &[u64]) -> Self { + assert!(s.len() >= 2); + Self(unsafe { vld1q_u64(s.as_ptr()) }) + } + #[inline(always)] + pub fn from_array(arr: [u64; 2]) -> Self { + Self(unsafe { vld1q_u64(arr.as_ptr()) }) + } + #[inline(always)] + pub fn to_array(self) -> [u64; 2] { + let mut arr = [0u64; 2]; + unsafe { vst1q_u64(arr.as_mut_ptr(), self.0) }; + arr } - #[inline(always)] pub fn from_array(arr: [u64; 2]) -> Self { Self(unsafe { vld1q_u64(arr.as_ptr()) }) } - #[inline(always)] pub fn to_array(self) -> [u64; 2] { - let mut arr = [0u64; 2]; unsafe { vst1q_u64(arr.as_mut_ptr(), self.0) }; arr + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [u64]) { + assert!(s.len() >= 2); + unsafe { vst1q_u64(s.as_mut_ptr(), self.0) }; + } + #[inline(always)] + pub fn add(self, other: Self) -> Self { + Self(unsafe { vaddq_u64(self.0, other.0) }) } - #[inline(always)] pub fn copy_to_slice(self, s: &mut [u64]) { - assert!(s.len() >= 2); unsafe { vst1q_u64(s.as_mut_ptr(), self.0) }; + #[inline(always)] + pub fn sub(self, other: Self) -> Self { + Self(unsafe { vsubq_u64(self.0, other.0) }) } - #[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_u64(self.0, other.0) }) } - #[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_u64(self.0, other.0) }) } // NEON has no vminq_u64 / vmaxq_u64 — scalar fallback - #[inline(always)] pub fn min(self, other: Self) -> Self { - let a = self.to_array(); let b = other.to_array(); + #[inline(always)] + pub fn min(self, other: Self) -> Self { + let a = self.to_array(); + let b = other.to_array(); Self::from_array([a[0].min(b[0]), a[1].min(b[1])]) } - #[inline(always)] pub fn max(self, other: Self) -> Self { - let a = self.to_array(); let b = other.to_array(); + #[inline(always)] + pub fn max(self, other: Self) -> Self { + let a = self.to_array(); + let b = other.to_array(); Self::from_array([a[0].max(b[0]), a[1].max(b[1])]) } } @@ -1360,22 +1611,50 @@ pub struct I32x4(pub int32x4_t); #[cfg(target_arch = "aarch64")] impl I32x4 { pub const LANES: usize = 4; - #[inline(always)] pub fn splat(v: i32) -> Self { Self(unsafe { 
vdupq_n_s32(v) }) } - #[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_s32(0) }) } - #[inline(always)] pub fn from_slice(s: &[i32]) -> Self { - assert!(s.len() >= 4); Self(unsafe { vld1q_s32(s.as_ptr()) }) - } - #[inline(always)] pub fn from_array(arr: [i32; 4]) -> Self { Self(unsafe { vld1q_s32(arr.as_ptr()) }) } - #[inline(always)] pub fn to_array(self) -> [i32; 4] { - let mut arr = [0i32; 4]; unsafe { vst1q_s32(arr.as_mut_ptr(), self.0) }; arr - } - #[inline(always)] pub fn copy_to_slice(self, s: &mut [i32]) { - assert!(s.len() >= 4); unsafe { vst1q_s32(s.as_mut_ptr(), self.0) }; - } - #[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_s32(self.0, other.0) }) } - #[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_s32(self.0, other.0) }) } - #[inline(always)] pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_s32(self.0, other.0) }) } - #[inline(always)] pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_s32(self.0, other.0) }) } + #[inline(always)] + pub fn splat(v: i32) -> Self { + Self(unsafe { vdupq_n_s32(v) }) + } + #[inline(always)] + pub fn zero() -> Self { + Self(unsafe { vdupq_n_s32(0) }) + } + #[inline(always)] + pub fn from_slice(s: &[i32]) -> Self { + assert!(s.len() >= 4); + Self(unsafe { vld1q_s32(s.as_ptr()) }) + } + #[inline(always)] + pub fn from_array(arr: [i32; 4]) -> Self { + Self(unsafe { vld1q_s32(arr.as_ptr()) }) + } + #[inline(always)] + pub fn to_array(self) -> [i32; 4] { + let mut arr = [0i32; 4]; + unsafe { vst1q_s32(arr.as_mut_ptr(), self.0) }; + arr + } + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [i32]) { + assert!(s.len() >= 4); + unsafe { vst1q_s32(s.as_mut_ptr(), self.0) }; + } + #[inline(always)] + pub fn add(self, other: Self) -> Self { + Self(unsafe { vaddq_s32(self.0, other.0) }) + } + #[inline(always)] + pub fn sub(self, other: Self) -> Self { + Self(unsafe { vsubq_s32(self.0, other.0) }) + } + #[inline(always)] + pub fn min(self, other: Self) -> Self { + Self(unsafe { vminq_s32(self.0, other.0) }) + } + #[inline(always)] + pub fn max(self, other: Self) -> Self { + Self(unsafe { vmaxq_s32(self.0, other.0) }) + } } #[cfg(target_arch = "aarch64")] @@ -1386,32 +1665,57 @@ pub struct I64x2(pub int64x2_t); #[cfg(target_arch = "aarch64")] impl I64x2 { pub const LANES: usize = 2; - #[inline(always)] pub fn splat(v: i64) -> Self { Self(unsafe { vdupq_n_s64(v) }) } - #[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_s64(0) }) } - #[inline(always)] pub fn from_slice(s: &[i64]) -> Self { - assert!(s.len() >= 2); Self(unsafe { vld1q_s64(s.as_ptr()) }) + #[inline(always)] + pub fn splat(v: i64) -> Self { + Self(unsafe { vdupq_n_s64(v) }) } - #[inline(always)] pub fn from_array(arr: [i64; 2]) -> Self { Self(unsafe { vld1q_s64(arr.as_ptr()) }) } - #[inline(always)] pub fn to_array(self) -> [i64; 2] { - let mut arr = [0i64; 2]; unsafe { vst1q_s64(arr.as_mut_ptr(), self.0) }; arr + #[inline(always)] + pub fn zero() -> Self { + Self(unsafe { vdupq_n_s64(0) }) + } + #[inline(always)] + pub fn from_slice(s: &[i64]) -> Self { + assert!(s.len() >= 2); + Self(unsafe { vld1q_s64(s.as_ptr()) }) + } + #[inline(always)] + pub fn from_array(arr: [i64; 2]) -> Self { + Self(unsafe { vld1q_s64(arr.as_ptr()) }) + } + #[inline(always)] + pub fn to_array(self) -> [i64; 2] { + let mut arr = [0i64; 2]; + unsafe { vst1q_s64(arr.as_mut_ptr(), self.0) }; + arr } - #[inline(always)] pub fn copy_to_slice(self, s: &mut [i64]) { - assert!(s.len() >= 2); unsafe { 
vst1q_s64(s.as_mut_ptr(), self.0) }; + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [i64]) { + assert!(s.len() >= 2); + unsafe { vst1q_s64(s.as_mut_ptr(), self.0) }; + } + #[inline(always)] + pub fn add(self, other: Self) -> Self { + Self(unsafe { vaddq_s64(self.0, other.0) }) + } + #[inline(always)] + pub fn sub(self, other: Self) -> Self { + Self(unsafe { vsubq_s64(self.0, other.0) }) } - #[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_s64(self.0, other.0) }) } - #[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_s64(self.0, other.0) }) } // NEON has no vminq_s64 / vmaxq_s64 — scalar fallback - #[inline(always)] pub fn min(self, other: Self) -> Self { - let a = self.to_array(); let b = other.to_array(); + #[inline(always)] + pub fn min(self, other: Self) -> Self { + let a = self.to_array(); + let b = other.to_array(); Self::from_array([a[0].min(b[0]), a[1].min(b[1])]) } - #[inline(always)] pub fn max(self, other: Self) -> Self { - let a = self.to_array(); let b = other.to_array(); + #[inline(always)] + pub fn max(self, other: Self) -> Self { + let a = self.to_array(); + let b = other.to_array(); Self::from_array([a[0].max(b[0]), a[1].max(b[1])]) } } - // ── Polyfills for wider lanes (scalar arrays) ───────────────────────────── #[allow(unused_macros)] @@ -1423,40 +1727,74 @@ macro_rules! neon_int_polyfill { impl $name { pub const LANES: usize = $lanes; - #[inline(always)] pub fn splat(v: $elem) -> Self { Self([v; $lanes]) } - #[inline(always)] pub fn zero() -> Self { Self([$zero; $lanes]) } - #[inline(always)] pub fn from_slice(s: &[$elem]) -> Self { + #[inline(always)] + pub fn splat(v: $elem) -> Self { + Self([v; $lanes]) + } + #[inline(always)] + pub fn zero() -> Self { + Self([$zero; $lanes]) + } + #[inline(always)] + pub fn from_slice(s: &[$elem]) -> Self { assert!(s.len() >= $lanes); - let mut a = [$zero; $lanes]; a.copy_from_slice(&s[..$lanes]); Self(a) + let mut a = [$zero; $lanes]; + a.copy_from_slice(&s[..$lanes]); + Self(a) + } + #[inline(always)] + pub fn from_array(a: [$elem; $lanes]) -> Self { + Self(a) + } + #[inline(always)] + pub fn to_array(self) -> [$elem; $lanes] { + self.0 } - #[inline(always)] pub fn from_array(a: [$elem; $lanes]) -> Self { Self(a) } - #[inline(always)] pub fn to_array(self) -> [$elem; $lanes] { self.0 } - #[inline(always)] pub fn copy_to_slice(self, s: &mut [$elem]) { - assert!(s.len() >= $lanes); s[..$lanes].copy_from_slice(&self.0); + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [$elem]) { + assert!(s.len() >= $lanes); + s[..$lanes].copy_from_slice(&self.0); } - #[inline(always)] pub fn add(self, other: Self) -> Self { + #[inline(always)] + pub fn add(self, other: Self) -> Self { let mut o = [$zero; $lanes]; - for i in 0..$lanes { o[i] = self.0[i].wrapping_add(other.0[i]); } + for i in 0..$lanes { + o[i] = self.0[i].wrapping_add(other.0[i]); + } Self(o) } - #[inline(always)] pub fn sub(self, other: Self) -> Self { + #[inline(always)] + pub fn sub(self, other: Self) -> Self { let mut o = [$zero; $lanes]; - for i in 0..$lanes { o[i] = self.0[i].wrapping_sub(other.0[i]); } + for i in 0..$lanes { + o[i] = self.0[i].wrapping_sub(other.0[i]); + } Self(o) } - #[inline(always)] pub fn min(self, other: Self) -> Self { + #[inline(always)] + pub fn min(self, other: Self) -> Self { let mut o = [$zero; $lanes]; - for i in 0..$lanes { o[i] = self.0[i].min(other.0[i]); } + for i in 0..$lanes { + o[i] = self.0[i].min(other.0[i]); + } Self(o) } - #[inline(always)] pub fn max(self, other: 
Self) -> Self { + #[inline(always)] + pub fn max(self, other: Self) -> Self { let mut o = [$zero; $lanes]; - for i in 0..$lanes { o[i] = self.0[i].max(other.0[i]); } + for i in 0..$lanes { + o[i] = self.0[i].max(other.0[i]); + } Self(o) } - #[inline(always)] pub fn cmp_gt(self, other: Self) -> $mask { + #[inline(always)] + pub fn cmp_gt(self, other: Self) -> $mask { let mut m: $mask = 0; - for i in 0..$lanes { if self.0[i] > other.0[i] { m |= (1 as $mask) << i; } } + for i in 0..$lanes { + if self.0[i] > other.0[i] { + m |= (1 as $mask) << i; + } + } m } } @@ -1466,22 +1804,40 @@ macro_rules! neon_int_polyfill { } } impl PartialEq for $name { - fn eq(&self, other: &Self) -> bool { self.0 == other.0 } + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } } }; } -#[cfg(target_arch = "aarch64")] neon_int_polyfill!(I8x32, i8, 32, 0i8, u32); -#[cfg(target_arch = "aarch64")] neon_int_polyfill!(I8x64, i8, 64, 0i8, u64); -#[cfg(target_arch = "aarch64")] neon_int_polyfill!(I16x16, i16, 16, 0i16, u16); -#[cfg(target_arch = "aarch64")] neon_int_polyfill!(I16x32, i16, 32, 0i16, u32); +#[cfg(target_arch = "aarch64")] +neon_int_polyfill!(I8x32, i8, 32, 0i8, u32); +#[cfg(target_arch = "aarch64")] +neon_int_polyfill!(I8x64, i8, 64, 0i8, u64); +#[cfg(target_arch = "aarch64")] +neon_int_polyfill!(I16x16, i16, 16, 0i16, u16); +#[cfg(target_arch = "aarch64")] +neon_int_polyfill!(I16x32, i16, 32, 0i16, u32); -#[cfg(target_arch = "aarch64")] #[allow(non_camel_case_types)] pub type i8x16 = I8x16; -#[cfg(target_arch = "aarch64")] #[allow(non_camel_case_types)] pub type i16x8 = I16x8; -#[cfg(target_arch = "aarch64")] #[allow(non_camel_case_types)] pub type i8x32 = I8x32; -#[cfg(target_arch = "aarch64")] #[allow(non_camel_case_types)] pub type i8x64 = I8x64; -#[cfg(target_arch = "aarch64")] #[allow(non_camel_case_types)] pub type i16x16 = I16x16; -#[cfg(target_arch = "aarch64")] #[allow(non_camel_case_types)] pub type i16x32 = I16x32; +#[cfg(target_arch = "aarch64")] +#[allow(non_camel_case_types)] +pub type i8x16 = I8x16; +#[cfg(target_arch = "aarch64")] +#[allow(non_camel_case_types)] +pub type i16x8 = I16x8; +#[cfg(target_arch = "aarch64")] +#[allow(non_camel_case_types)] +pub type i8x32 = I8x32; +#[cfg(target_arch = "aarch64")] +#[allow(non_camel_case_types)] +pub type i8x64 = I8x64; +#[cfg(target_arch = "aarch64")] +#[allow(non_camel_case_types)] +pub type i16x16 = I16x16; +#[cfg(target_arch = "aarch64")] +#[allow(non_camel_case_types)] +pub type i16x32 = I16x32; // ═══════════════════════════════════════════════════════════════════════════ // Tests (run on x86 as compile-check, actual NEON tests need aarch64) @@ -1498,8 +1854,7 @@ mod tests { let h = f32_to_f16_scalar(v); let back = f16_to_f32_scalar(h); let err = (v - back).abs() / v.abs().max(1e-10); - assert!(err < 0.01 || v == 0.0, - "f16 roundtrip failed for {}: got {}, err={:.4}", v, back, err); + assert!(err < 0.01 || v == 0.0, "f16 roundtrip failed for {}: got {}, err={:.4}", v, back, err); } } @@ -1520,14 +1875,19 @@ mod tests { #[test] fn f16_batch_matches_scalar() { - let input: Vec = (0..100).map(|i| f32_to_f16_scalar(i as f32 * 0.1 - 5.0)).collect(); + let input: Vec = (0..100) + .map(|i| f32_to_f16_scalar(i as f32 * 0.1 - 5.0)) + .collect(); let mut batch_out = vec![0.0f32; 100]; f16_to_f32_batch(&input, &mut batch_out); for (i, &h) in input.iter().enumerate() { let scalar = f16_to_f32_scalar(h); - assert_eq!(batch_out[i], scalar, - "batch/scalar mismatch at {}: batch={} scalar={}", i, batch_out[i], scalar); + assert_eq!( + 
batch_out[i], scalar, + "batch/scalar mismatch at {}: batch={} scalar={}", + i, batch_out[i], scalar + ); } } @@ -1543,8 +1903,15 @@ mod tests { for i in 0..50 { let err = (input[i] - f32_back[i]).abs(); // f16 has ~3 decimal digits of precision - assert!(err < 0.1 || input[i].abs() < 0.001, - "roundtrip error at {}: {} → {} → {}, err={}", i, input[i], f16_out[i], f32_back[i], err); + assert!( + err < 0.1 || input[i].abs() < 0.001, + "roundtrip error at {}: {} → {} → {}, err={}", + i, + input[i], + f16_out[i], + f32_back[i], + err + ); } } } diff --git a/src/slice.rs b/src/slice.rs index e2ce1e72..8f0abccb 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -39,8 +39,7 @@ use std::ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, Rang /// reverse order. It can also be created with `Slice::from(a..).step_by(-1)`. /// The Python equivalent is `[a::-1]`. #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] -pub struct Slice -{ +pub struct Slice { /// start index; negative are counted from the back of the axis pub start: isize, /// end index; negative are counted from the back of the axis; when not present @@ -50,8 +49,7 @@ pub struct Slice pub step: isize, } -impl Slice -{ +impl Slice { /// Create a new `Slice` with the given extents. /// /// See also the `From` impls, converting from ranges; for example @@ -59,8 +57,7 @@ impl Slice /// /// `step` must be nonzero. /// (This method checks with a debug assertion that `step` is not zero.) - pub fn new(start: isize, end: Option, step: isize) -> Slice - { + pub fn new(start: isize, end: Option, step: isize) -> Slice { debug_assert_ne!(step, 0, "Slice::new: step must be nonzero"); Slice { start, end, step } } @@ -71,8 +68,7 @@ impl Slice /// `step` must be nonzero. /// (This method checks with a debug assertion that `step` is not zero.) #[inline] - pub fn step_by(self, step: isize) -> Self - { + pub fn step_by(self, step: isize) -> Self { debug_assert_ne!(step, 0, "Slice::step_by: step must be nonzero"); Slice { step: self.step * step, @@ -116,13 +112,11 @@ pub struct NewAxis; /// with `SliceInfoElem::from(NewAxis)`. The Python equivalent is /// `[np.newaxis]`. The macro equivalent is `s![NewAxis]`. #[derive(Debug, PartialEq, Eq, Hash)] -pub enum SliceInfoElem -{ +pub enum SliceInfoElem { /// A range with step size. `end` is an exclusive index. Negative `start` /// or `end` indexes are counted from the back of the axis. If `end` is /// `None`, the slice extends to the end of the axis. - Slice - { + Slice { /// start index; negative are counted from the back of the axis start: isize, /// end index; negative are counted from the back of the axis; when not present @@ -139,31 +133,25 @@ pub enum SliceInfoElem copy_and_clone! {SliceInfoElem} -impl SliceInfoElem -{ +impl SliceInfoElem { /// Returns `true` if `self` is a `Slice` value. - pub fn is_slice(&self) -> bool - { + pub fn is_slice(&self) -> bool { matches!(self, SliceInfoElem::Slice { .. }) } /// Returns `true` if `self` is an `Index` value. - pub fn is_index(&self) -> bool - { + pub fn is_index(&self) -> bool { matches!(self, SliceInfoElem::Index(_)) } /// Returns `true` if `self` is a `NewAxis` value. 
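// --- Illustrative sketch (not part of this patch) ---------------------------
// How the pieces reformatted in this hunk fit together: ranges, Slice and
// NewAxis all convert into SliceInfoElem, which is what the slicing machinery
// (and the s![] macro) ultimately consumes. This only exercises conversions
// and predicates named in the surrounding doc comments.
fn slice_elem_examples() {
    let every_other = Slice::from(..).step_by(2);  // Python: [::2]
    let reversed = Slice::from(1..).step_by(-1);   // Python: [1::-1]
    let elems = [
        SliceInfoElem::from(every_other),
        SliceInfoElem::from(reversed),
        SliceInfoElem::from(NewAxis),              // Python: [np.newaxis]
        SliceInfoElem::Index(0),
    ];
    assert!(elems[0].is_slice());
    assert!(elems[2].is_new_axis());
    assert!(elems[3].is_index());
}
// -----------------------------------------------------------------------------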
- pub fn is_new_axis(&self) -> bool - { + pub fn is_new_axis(&self) -> bool { matches!(self, SliceInfoElem::NewAxis) } } -impl fmt::Display for SliceInfoElem -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result - { +impl fmt::Display for SliceInfoElem { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { SliceInfoElem::Index(index) => write!(f, "{}", index)?, SliceInfoElem::Slice { start, end, step } => { @@ -251,11 +239,9 @@ impl_slice_variant_from_range!(SliceInfoElem, SliceInfoElem::Slice, isize); impl_slice_variant_from_range!(SliceInfoElem, SliceInfoElem::Slice, usize); impl_slice_variant_from_range!(SliceInfoElem, SliceInfoElem::Slice, i32); -impl From for Slice -{ +impl From for Slice { #[inline] - fn from(_: RangeFull) -> Slice - { + fn from(_: RangeFull) -> Slice { Slice { start: 0, end: None, @@ -264,11 +250,9 @@ impl From for Slice } } -impl From for SliceInfoElem -{ +impl From for SliceInfoElem { #[inline] - fn from(_: RangeFull) -> SliceInfoElem - { + fn from(_: RangeFull) -> SliceInfoElem { SliceInfoElem::Slice { start: 0, end: None, @@ -277,11 +261,9 @@ impl From for SliceInfoElem } } -impl From for SliceInfoElem -{ +impl From for SliceInfoElem { #[inline] - fn from(s: Slice) -> SliceInfoElem - { + fn from(s: Slice) -> SliceInfoElem { SliceInfoElem::Slice { start: s.start, end: s.end, @@ -304,11 +286,9 @@ impl_sliceinfoelem_from_index!(isize); impl_sliceinfoelem_from_index!(usize); impl_sliceinfoelem_from_index!(i32); -impl From for SliceInfoElem -{ +impl From for SliceInfoElem { #[inline] - fn from(_: NewAxis) -> SliceInfoElem - { + fn from(_: NewAxis) -> SliceInfoElem { SliceInfoElem::NewAxis } } @@ -320,8 +300,7 @@ impl From for SliceInfoElem /// consistent with the `&[SliceInfoElem]` returned by `self.as_ref()` and that /// `self.as_ref()` always returns the same value when called multiple times. #[allow(clippy::missing_safety_doc)] // not implementable downstream -pub unsafe trait SliceArg: AsRef<[SliceInfoElem]> -{ +pub unsafe trait SliceArg: AsRef<[SliceInfoElem]> { /// Dimensionality of the output array. type OutDim: Dimension; @@ -341,13 +320,11 @@ where { type OutDim = T::OutDim; - fn in_ndim(&self) -> usize - { + fn in_ndim(&self) -> usize { T::in_ndim(self) } - fn out_ndim(&self) -> usize - { + fn out_ndim(&self) -> usize { T::out_ndim(self) } @@ -391,30 +368,25 @@ where { type OutDim = Dout; - fn in_ndim(&self) -> usize - { + fn in_ndim(&self) -> usize { self.in_ndim() } - fn out_ndim(&self) -> usize - { + fn out_ndim(&self) -> usize { self.out_ndim() } private_impl! 
{} } -unsafe impl SliceArg for [SliceInfoElem] -{ +unsafe impl SliceArg for [SliceInfoElem] { type OutDim = IxDyn; - fn in_ndim(&self) -> usize - { + fn in_ndim(&self) -> usize { self.iter().filter(|s| !s.is_new_axis()).count() } - fn out_ndim(&self) -> usize - { + fn out_ndim(&self) -> usize { self.iter().filter(|s| !s.is_index()).count() } @@ -432,8 +404,7 @@ unsafe impl SliceArg for [SliceInfoElem] /// /// [`.slice()`]: crate::ArrayRef::slice #[derive(Debug)] -pub struct SliceInfo -{ +pub struct SliceInfo { in_dim: PhantomData, out_dim: PhantomData, indices: T, @@ -445,8 +416,7 @@ where Dout: Dimension, { type Target = T; - fn deref(&self) -> &Self::Target - { + fn deref(&self) -> &Self::Target { &self.indices } } @@ -487,8 +457,7 @@ where #[doc(hidden)] pub unsafe fn new_unchecked( indices: T, in_dim: PhantomData, out_dim: PhantomData, - ) -> SliceInfo - { + ) -> SliceInfo { if cfg!(debug_assertions) { check_dims_for_sliceinfo::(indices.as_ref()) .expect("`Din` and `Dout` must be consistent with `indices`."); @@ -510,8 +479,7 @@ where /// /// The caller must ensure `indices.as_ref()` always returns the same value /// when called multiple times. - pub unsafe fn new(indices: T) -> Result, ShapeError> - { + pub unsafe fn new(indices: T) -> Result, ShapeError> { check_dims_for_sliceinfo::(indices.as_ref())?; Ok(SliceInfo { in_dim: PhantomData, @@ -526,8 +494,7 @@ where /// If `Din` is a fixed-size dimension type, then this is equivalent to /// `Din::NDIM.unwrap()`. Otherwise, the value is calculated by iterating /// over the `SliceInfoElem` elements. - pub fn in_ndim(&self) -> usize - { + pub fn in_ndim(&self) -> usize { if let Some(ndim) = Din::NDIM { ndim } else { @@ -542,8 +509,7 @@ where /// If `Dout` is a fixed-size dimension type, then this is equivalent to /// `Dout::NDIM.unwrap()`. Otherwise, the value is calculated by iterating /// over the `SliceInfoElem` elements. - pub fn out_ndim(&self) -> usize - { + pub fn out_ndim(&self) -> usize { if let Some(ndim) = Dout::NDIM { ndim } else { @@ -559,8 +525,7 @@ where { type Error = ShapeError; - fn try_from(indices: &'a [SliceInfoElem]) -> Result, ShapeError> - { + fn try_from(indices: &'a [SliceInfoElem]) -> Result, ShapeError> { unsafe { // This is okay because `&[SliceInfoElem]` always returns the same // value for `.as_ref()`. @@ -576,8 +541,7 @@ where { type Error = ShapeError; - fn try_from(indices: Vec) -> Result, Din, Dout>, ShapeError> - { + fn try_from(indices: Vec) -> Result, Din, Dout>, ShapeError> { unsafe { // This is okay because `Vec` always returns the same value for // `.as_ref()`. @@ -588,8 +552,7 @@ where macro_rules! 
impl_tryfrom_array_for_sliceinfo { ($len:expr) => { - impl TryFrom<[SliceInfoElem; $len]> - for SliceInfo<[SliceInfoElem; $len], Din, Dout> + impl TryFrom<[SliceInfoElem; $len]> for SliceInfo<[SliceInfoElem; $len], Din, Dout> where Din: Dimension, Dout: Dimension, @@ -624,8 +587,7 @@ where Din: Dimension, Dout: Dimension, { - fn as_ref(&self) -> &[SliceInfoElem] - { + fn as_ref(&self) -> &[SliceInfoElem] { self.indices.as_ref() } } @@ -636,8 +598,7 @@ where Din: Dimension, Dout: Dimension, { - fn from(info: &'a SliceInfo) -> SliceInfo<&'a [SliceInfoElem], Din, Dout> - { + fn from(info: &'a SliceInfo) -> SliceInfo<&'a [SliceInfoElem], Din, Dout> { SliceInfo { in_dim: info.in_dim, out_dim: info.out_dim, @@ -660,8 +621,7 @@ where Din: Dimension, Dout: Dimension, { - fn clone(&self) -> Self - { + fn clone(&self) -> Self { SliceInfo { in_dim: PhantomData, out_dim: PhantomData, @@ -672,21 +632,22 @@ where /// Trait for determining dimensionality of input and output for [`s!`] macro. #[doc(hidden)] -pub trait SliceNextDim -{ +pub trait SliceNextDim { /// Number of dimensions that this slicing argument consumes in the input array. type InDim: Dimension; /// Number of dimensions that this slicing argument produces in the output array. type OutDim: Dimension; fn next_in_dim(&self, _: PhantomData) -> PhantomData<>::Output> - where D: Dimension + DimAdd + where + D: Dimension + DimAdd, { PhantomData } fn next_out_dim(&self, _: PhantomData) -> PhantomData<>::Output> - where D: Dimension + DimAdd + where + D: Dimension + DimAdd, { PhantomData } @@ -949,8 +910,7 @@ where { type Output = (ArrayViewMut<'a, A, I0::OutDim>,); - fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output - { + fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output { (view.slice_move(&self.0),) } @@ -1012,8 +972,7 @@ where { type Output = T::Output; - fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output - { + fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output { T::multi_slice_move(self, view) } diff --git a/src/split_at.rs b/src/split_at.rs index 5dee44b6..da29cf2a 100644 --- a/src/split_at.rs +++ b/src/split_at.rs @@ -1,19 +1,19 @@ use crate::imp_prelude::*; /// Arrays and similar that can be split along an axis -pub(crate) trait SplitAt -{ +pub(crate) trait SplitAt { fn split_at(self, axis: Axis, index: usize) -> (Self, Self) - where Self: Sized; + where + Self: Sized; } -pub(crate) trait SplitPreference: SplitAt -{ +pub(crate) trait SplitPreference: SplitAt { #[allow(dead_code)] // used only when Rayon support is enabled fn can_split(&self) -> bool; fn split_preference(&self) -> (Axis, usize); fn split(self) -> (Self, Self) - where Self: Sized + where + Self: Sized, { let (axis, index) = self.split_preference(); self.split_at(axis, index) @@ -21,10 +21,10 @@ pub(crate) trait SplitPreference: SplitAt } impl SplitAt for D -where D: Dimension +where + D: Dimension, { - fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) - { + fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) { let mut d1 = self; let mut d2 = d1.clone(); let i = axis.index(); @@ -36,19 +36,19 @@ where D: Dimension } impl SplitAt for ArrayViewMut<'_, A, D> -where D: Dimension +where + D: Dimension, { - fn split_at(self, axis: Axis, index: usize) -> (Self, Self) - { + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) { self.split_at(axis, index) } } impl SplitAt for RawArrayViewMut -where D: Dimension +where + D: Dimension, { - fn split_at(self, axis: Axis, index: 
usize) -> (Self, Self) - { + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) { self.split_at(axis, index) } } diff --git a/src/tri.rs b/src/tri.rs index 1e10c3b6..a98332fc 100644 --- a/src/tri.rs +++ b/src/tri.rs @@ -12,11 +12,7 @@ use num_traits::Zero; use crate::{ dimension::{is_layout_c, is_layout_f}, - Array, - ArrayRef, - Axis, - Dimension, - Zip, + Array, ArrayRef, Axis, Dimension, Zip, }; impl ArrayRef @@ -49,8 +45,7 @@ where /// ] /// ); /// ``` - pub fn triu(&self, k: isize) -> Array - { + pub fn triu(&self, k: isize) -> Array { if self.ndim() <= 1 { return self.to_owned(); } @@ -114,8 +109,7 @@ where /// ] /// ); /// ``` - pub fn tril(&self, k: isize) -> Array - { + pub fn tril(&self, k: isize) -> Array { if self.ndim() <= 1 { return self.to_owned(); } @@ -157,14 +151,12 @@ where } #[cfg(test)] -mod tests -{ +mod tests { use crate::{array, dimension, Array0, Array1, Array2, Array3, ShapeBuilder}; use alloc::vec; #[test] - fn test_keep_order() - { + fn test_keep_order() { let x = Array2::::ones((3, 3).f()); let res = x.triu(0); assert!(dimension::is_layout_f(&res.parts.dim, &res.parts.strides)); @@ -174,8 +166,7 @@ mod tests } #[test] - fn test_0d() - { + fn test_0d() { let x = Array0::::ones(()); let res = x.triu(0); assert_eq!(res, x); @@ -192,8 +183,7 @@ mod tests } #[test] - fn test_1d() - { + fn test_1d() { let x = array![1, 2, 3]; let res = x.triu(0); assert_eq!(res, x); @@ -210,8 +200,7 @@ mod tests } #[test] - fn test_2d() - { + fn test_2d() { let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; // Upper @@ -234,8 +223,7 @@ mod tests } #[test] - fn test_2d_single() - { + fn test_2d_single() { let x = array![[1]]; assert_eq!(x.triu(0), array![[1]]); @@ -247,8 +235,7 @@ mod tests } #[test] - fn test_3d() - { + fn test_3d() { let x = array![ [[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[10, 11, 12], [13, 14, 15], [16, 17, 18]], @@ -307,8 +294,7 @@ mod tests } #[test] - fn test_off_axis() - { + fn test_off_axis() { let x = array![ [[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[10, 11, 12], [13, 14, 15], [16, 17, 18]], @@ -337,8 +323,7 @@ mod tests } #[test] - fn test_odd_shape() - { + fn test_odd_shape() { let x = array![[1, 2, 3], [4, 5, 6]]; let res = x.triu(0); assert_eq!(res, array![[1, 2, 3], [0, 5, 6]]); @@ -355,8 +340,7 @@ mod tests } #[test] - fn test_odd_k() - { + fn test_odd_k() { let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; let z = Array2::zeros([3, 3]); assert_eq!(x.triu(isize::MIN), x); diff --git a/src/zip/mod.rs b/src/zip/mod.rs index b01ae04f..4ea93a85 100644 --- a/src/zip/mod.rs +++ b/src/zip/mod.rs @@ -39,7 +39,8 @@ macro_rules! fold_while { /// /// See [broadcasting](ArrayBase#broadcasting) for more information. trait Broadcast -where E: IntoDimension +where + E: IntoDimension, { type Output: NdProducer; /// Broadcast the array to the new dimensions `shape`. 
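Illustrative aside (not part of the changeset): a small usage sketch for the `triu`/`tril` methods whose formatting is touched in src/tri.rs above, using the same 3x3 matrix as the tests; expected values follow the usual NumPy-style semantics.

use ndarray::array;

fn main() {
    let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
    // Keep entries on and above the main diagonal (k = 0).
    assert_eq!(x.triu(0), array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);
    // Keep entries on and below the main diagonal.
    assert_eq!(x.tril(0), array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
    // A positive k shifts the band: triu(1) drops the main diagonal as well.
    assert_eq!(x.triu(1), array![[0, 2, 3], [0, 0, 6], [0, 0, 0]]);
}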
@@ -51,8 +52,7 @@ where E: IntoDimension } /// Compute `Layout` hints for array shape dim, strides -fn array_layout(dim: &D, strides: &D) -> Layout -{ +fn array_layout(dim: &D, strides: &D) -> Layout { let n = dim.ndim(); if dimension::is_layout_c(dim, strides) { // effectively one-dimensional => C and F layout compatible @@ -77,10 +77,10 @@ fn array_layout(dim: &D, strides: &D) -> Layout } impl LayoutRef -where D: Dimension +where + D: Dimension, { - pub(crate) fn layout_impl(&self) -> Layout - { + pub(crate) fn layout_impl(&self) -> Layout { array_layout(self._dim(), self._strides()) } } @@ -91,8 +91,7 @@ where D: Dimension, { type Output = ArrayView<'a, A, E::Dim>; - fn broadcast_unwrap(self, shape: E) -> Self::Output - { + fn broadcast_unwrap(self, shape: E) -> Self::Output { #[allow(clippy::needless_borrow)] let res: ArrayView<'_, A, E::Dim> = (*self).broadcast_unwrap(shape.into_dimension()); unsafe { ArrayView::new(res.parts.ptr, res.parts.dim, res.parts.strides) } @@ -100,8 +99,7 @@ where private_impl! {} } -trait ZippableTuple: Sized -{ +trait ZippableTuple: Sized { type Item; type Ptr: OffsetTuple + Copy; type Dim: Dimension; @@ -190,8 +188,7 @@ trait ZippableTuple: Sized /// ``` #[derive(Debug, Clone)] #[must_use = "zipping producers is lazy and does nothing unless consumed"] -pub struct Zip -{ +pub struct Zip { parts: Parts, dimension: D, layout: Layout, @@ -210,7 +207,8 @@ where /// The Zip will take the exact dimension of `p` and all inputs /// must have the same dimensions (or be broadcast to them). pub fn from(p: IP) -> Self - where IP: IntoNdProducer + where + IP: IntoNdProducer, { let array = p.into_producer(); let dim = array.raw_dim(); @@ -235,7 +233,8 @@ where /// /// *Note:* Indexed zip has overhead. pub fn indexed(p: IP) -> Self - where IP: IntoNdProducer + where + IP: IntoNdProducer, { let array = p.into_producer(); let dim = array.raw_dim(); @@ -258,11 +257,11 @@ where } impl Zip -where D: Dimension +where + D: Dimension, { /// Return a the number of element tuples in the Zip - pub fn size(&self) -> usize - { + pub fn size(&self) -> usize { self.dimension.size() } @@ -270,21 +269,18 @@ where D: Dimension /// /// ***Panics*** if `axis` is out of bounds. #[track_caller] - fn len_of(&self, axis: Axis) -> usize - { + fn len_of(&self, axis: Axis) -> usize { self.dimension[axis.index()] } - fn prefer_f(&self) -> bool - { + fn prefer_f(&self) -> bool { !self.layout.is(Layout::CORDER) && (self.layout.is(Layout::FORDER) || self.layout_tendency < 0) } /// Return an *approximation* to the max stride axis; if /// component arrays disagree, there may be no choice better than the /// others. - fn max_stride_axis(&self) -> Axis - { + fn max_stride_axis(&self) -> Axis { let i = if self.prefer_f() { self.dimension .slice() @@ -304,7 +300,8 @@ where D: Dimension } impl Zip -where D: Dimension +where + D: Dimension, { fn for_each_core(&mut self, acc: Acc, mut function: F) -> FoldWhile where @@ -425,8 +422,7 @@ where D: Dimension #[cfg(feature = "rayon")] #[allow(dead_code)] - pub(crate) fn uninitialized_for_current_layout(&self) -> Array, D> - { + pub(crate) fn uninitialized_for_current_layout(&self) -> Array, D> { let is_f = self.prefer_f(); Array::uninit(self.dimension.clone().set_f(is_f)) } @@ -441,13 +437,17 @@ where /// Debug assert traversal order is like c (including 1D case) // Method placement: only used for binary Zip at the moment. 
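Illustrative aside (not part of the changeset): a minimal sketch of the public `Zip`/`FoldWhile` combination whose internals are reformatted in src/zip/mod.rs above; it assumes ndarray's `Zip::from(...).and(...).fold_while(...)` API.

use ndarray::{array, FoldWhile, Zip};

fn main() {
    let a = array![1, 2, 3, 4];
    let b = array![10, 20, 30, 40];
    // Short-circuiting fold: stop as soon as the running sum exceeds 25.
    let result = Zip::from(&a).and(&b).fold_while(0, |acc, &x, &y| {
        let s = acc + x + y;
        if s > 25 { FoldWhile::Done(s) } else { FoldWhile::Continue(s) }
    });
    assert!(result.is_done());
    assert_eq!(result.into_inner(), 33); // stopped after the second element pair
}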
#[inline] - pub(crate) fn debug_assert_c_order(self) -> Self - { - debug_assert!(self.layout.is(Layout::CORDER) || self.layout_tendency >= 0 || - self.dimension.slice().iter().filter(|&&d| d > 1).count() <= 1, - "Assertion failed: traversal is not c-order or 1D for \ + pub(crate) fn debug_assert_c_order(self) -> Self { + debug_assert!( + self.layout.is(Layout::CORDER) + || self.layout_tendency >= 0 + || self.dimension.slice().iter().filter(|&&d| d > 1).count() <= 1, + "Assertion failed: traversal is not c-order or 1D for \ layout {:?}, tendency {}, dimension {:?}", - self.layout, self.layout_tendency, self.dimension); + self.layout, + self.layout_tendency, + self.dimension + ); self } } @@ -467,17 +467,14 @@ impl Offset for *mut T { } */ -trait OffsetTuple -{ +trait OffsetTuple { type Args; unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self; } -impl OffsetTuple for *mut T -{ +impl OffsetTuple for *mut T { type Args = isize; - unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self - { + unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self { self.offset(index as isize * stride) } } @@ -929,27 +926,23 @@ map_impl! { /// Value controlling the execution of `.fold_while` on `Zip`. #[derive(Debug, Copy, Clone)] -pub enum FoldWhile -{ +pub enum FoldWhile { /// Continue folding with this value Continue(T), /// Fold is complete and will return this value Done(T), } -impl FoldWhile -{ +impl FoldWhile { /// Return the inner value - pub fn into_inner(self) -> T - { + pub fn into_inner(self) -> T { match self { FoldWhile::Continue(x) | FoldWhile::Done(x) => x, } } /// Return true if it is `Done`, false if `Continue` - pub fn is_done(&self) -> bool - { + pub fn is_done(&self) -> bool { match *self { FoldWhile::Continue(_) => false, FoldWhile::Done(_) => true, diff --git a/src/zip/ndproducer.rs b/src/zip/ndproducer.rs index fe666e81..a2471d9e 100644 --- a/src/zip/ndproducer.rs +++ b/src/zip/ndproducer.rs @@ -10,8 +10,7 @@ use alloc::vec::Vec; /// Slices and vectors can be used (equivalent to 1-dimensional array views). /// /// This trait is like `IntoIterator` for `NdProducers` instead of iterators. -pub trait IntoNdProducer -{ +pub trait IntoNdProducer { /// The element produced per iteration. type Item; /// Dimension type of the producer @@ -23,13 +22,13 @@ pub trait IntoNdProducer } impl
<P>
IntoNdProducer for P -where P: NdProducer +where + P: NdProducer, { type Item = P::Item; type Dim = P::Dim; type Output = Self; - fn into_producer(self) -> Self::Output - { + fn into_producer(self) -> Self::Output { self } } @@ -54,8 +53,7 @@ where P: NdProducer /// *producing* multidimensional items). /// /// See also [`IntoNdProducer`] -pub trait NdProducer -{ +pub trait NdProducer { /// The element produced per iteration. type Item; // Internal use / Pointee type @@ -78,8 +76,7 @@ pub trait NdProducer /// Return the shape of the producer. fn raw_dim(&self) -> Self::Dim; #[doc(hidden)] - fn equal_dim(&self, dim: &Self::Dim) -> bool - { + fn equal_dim(&self, dim: &Self::Dim) -> bool { self.raw_dim() == *dim } #[doc(hidden)] @@ -94,33 +91,29 @@ pub trait NdProducer fn contiguous_stride(&self) -> Self::Stride; #[doc(hidden)] fn split_at(self, axis: Axis, index: usize) -> (Self, Self) - where Self: Sized; + where + Self: Sized; private_decl! {} } -pub trait Offset: Copy -{ +pub trait Offset: Copy { type Stride: Copy; unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self; private_decl! {} } -impl Offset for *const T -{ +impl Offset for *const T { type Stride = isize; - unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self - { + unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self { self.offset(s * (index as isize)) } private_impl! {} } -impl Offset for *mut T -{ +impl Offset for *mut T { type Stride = isize; - unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self - { + unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self { self.offset(s * (index as isize)) } private_impl! {} @@ -136,8 +129,7 @@ where type Item = &'a A; type Dim = D; type Output = ArrayView<'a, A, D>; - fn into_producer(self) -> Self::Output - { + fn into_producer(self) -> Self::Output { self.view() } } @@ -152,8 +144,7 @@ where type Item = &'a mut A; type Dim = D; type Output = ArrayViewMut<'a, A, D>; - fn into_producer(self) -> Self::Output - { + fn into_producer(self) -> Self::Output { self.view_mut() } } @@ -161,13 +152,13 @@ where /// An array reference is an n-dimensional producer of element references /// (like ArrayView). impl<'a, A: 'a, D> IntoNdProducer for &'a ArrayRef -where D: Dimension +where + D: Dimension, { type Item = &'a A; type Dim = D; type Output = ArrayView<'a, A, D>; - fn into_producer(self) -> Self::Output - { + fn into_producer(self) -> Self::Output { self.view() } } @@ -175,91 +166,78 @@ where D: Dimension /// A mutable array reference is an n-dimensional producer of mutable element /// references (like ArrayViewMut). 
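Illustrative aside (not part of the changeset): the `IntoNdProducer` impls above are what allow plain slices and `Vec`s to act as one-dimensional producers; a small sketch assuming the public `Zip` API.

use ndarray::Zip;

fn main() {
    let xs = vec![1.0_f64, 2.0, 3.0];
    let mut out = vec![0.0_f64; 3];
    // `&mut [A]` and `&Vec<A>` are both one-dimensional producers,
    // so no explicit ArrayView1 construction is needed here.
    Zip::from(&mut out[..]).and(&xs).for_each(|o, &x| *o = x * 2.0);
    assert_eq!(out, vec![2.0, 4.0, 6.0]);
}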
impl<'a, A: 'a, D> IntoNdProducer for &'a mut ArrayRef -where D: Dimension +where + D: Dimension, { type Item = &'a mut A; type Dim = D; type Output = ArrayViewMut<'a, A, D>; - fn into_producer(self) -> Self::Output - { + fn into_producer(self) -> Self::Output { self.view_mut() } } /// A slice is a one-dimensional producer -impl<'a, A: 'a> IntoNdProducer for &'a [A] -{ +impl<'a, A: 'a> IntoNdProducer for &'a [A] { type Item = ::Item; type Dim = Ix1; type Output = ArrayView1<'a, A>; - fn into_producer(self) -> Self::Output - { + fn into_producer(self) -> Self::Output { <_>::from(self) } } /// A mutable slice is a mutable one-dimensional producer -impl<'a, A: 'a> IntoNdProducer for &'a mut [A] -{ +impl<'a, A: 'a> IntoNdProducer for &'a mut [A] { type Item = ::Item; type Dim = Ix1; type Output = ArrayViewMut1<'a, A>; - fn into_producer(self) -> Self::Output - { + fn into_producer(self) -> Self::Output { <_>::from(self) } } /// A one-dimensional array is a one-dimensional producer -impl<'a, A: 'a, const N: usize> IntoNdProducer for &'a [A; N] -{ +impl<'a, A: 'a, const N: usize> IntoNdProducer for &'a [A; N] { type Item = ::Item; type Dim = Ix1; type Output = ArrayView1<'a, A>; - fn into_producer(self) -> Self::Output - { + fn into_producer(self) -> Self::Output { <_>::from(self) } } /// A mutable one-dimensional array is a mutable one-dimensional producer -impl<'a, A: 'a, const N: usize> IntoNdProducer for &'a mut [A; N] -{ +impl<'a, A: 'a, const N: usize> IntoNdProducer for &'a mut [A; N] { type Item = ::Item; type Dim = Ix1; type Output = ArrayViewMut1<'a, A>; - fn into_producer(self) -> Self::Output - { + fn into_producer(self) -> Self::Output { <_>::from(self) } } /// A Vec is a one-dimensional producer -impl<'a, A: 'a> IntoNdProducer for &'a Vec -{ +impl<'a, A: 'a> IntoNdProducer for &'a Vec { type Item = ::Item; type Dim = Ix1; type Output = ArrayView1<'a, A>; - fn into_producer(self) -> Self::Output - { + fn into_producer(self) -> Self::Output { <_>::from(self) } } /// A mutable Vec is a mutable one-dimensional producer -impl<'a, A: 'a> IntoNdProducer for &'a mut Vec -{ +impl<'a, A: 'a> IntoNdProducer for &'a mut Vec { type Item = ::Item; type Dim = Ix1; type Output = ArrayViewMut1<'a, A>; - fn into_producer(self) -> Self::Output - { + fn into_producer(self) -> Self::Output { <_>::from(self) } } -impl<'a, A, D: Dimension> NdProducer for ArrayView<'a, A, D> -{ +impl<'a, A, D: Dimension> NdProducer for ArrayView<'a, A, D> { type Item = &'a A; type Dim = D; type Ptr = *mut A; @@ -267,57 +245,47 @@ impl<'a, A, D: Dimension> NdProducer for ArrayView<'a, A, D> private_impl! 
{} - fn raw_dim(&self) -> Self::Dim - { + fn raw_dim(&self) -> Self::Dim { (***self).raw_dim() } - fn equal_dim(&self, dim: &Self::Dim) -> bool - { + fn equal_dim(&self, dim: &Self::Dim) -> bool { self._dim().equal(dim) } - fn as_ptr(&self) -> *mut A - { + fn as_ptr(&self) -> *mut A { (**self).as_ptr() as _ } - fn layout(&self) -> Layout - { + fn layout(&self) -> Layout { self.layout_impl() } - unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item - { + unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item { &*ptr } - unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A - { + unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A { self._ptr() .as_ptr() .offset(i.index_unchecked(self._strides())) } - fn stride_of(&self, axis: Axis) -> isize - { + fn stride_of(&self, axis: Axis) -> isize { (**self).stride_of(axis) } #[inline(always)] - fn contiguous_stride(&self) -> Self::Stride - { + fn contiguous_stride(&self) -> Self::Stride { 1 } - fn split_at(self, axis: Axis, index: usize) -> (Self, Self) - { + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) { self.split_at(axis, index) } } -impl<'a, A, D: Dimension> NdProducer for ArrayViewMut<'a, A, D> -{ +impl<'a, A, D: Dimension> NdProducer for ArrayViewMut<'a, A, D> { type Item = &'a mut A; type Dim = D; type Ptr = *mut A; @@ -325,57 +293,47 @@ impl<'a, A, D: Dimension> NdProducer for ArrayViewMut<'a, A, D> private_impl! {} - fn raw_dim(&self) -> Self::Dim - { + fn raw_dim(&self) -> Self::Dim { (***self).raw_dim() } - fn equal_dim(&self, dim: &Self::Dim) -> bool - { + fn equal_dim(&self, dim: &Self::Dim) -> bool { self._dim().equal(dim) } - fn as_ptr(&self) -> *mut A - { + fn as_ptr(&self) -> *mut A { (**self).as_ptr() as _ } - fn layout(&self) -> Layout - { + fn layout(&self) -> Layout { self.layout_impl() } - unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item - { + unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item { &mut *ptr } - unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A - { + unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A { self._ptr() .as_ptr() .offset(i.index_unchecked(self._strides())) } - fn stride_of(&self, axis: Axis) -> isize - { + fn stride_of(&self, axis: Axis) -> isize { (**self).stride_of(axis) } #[inline(always)] - fn contiguous_stride(&self) -> Self::Stride - { + fn contiguous_stride(&self) -> Self::Stride { 1 } - fn split_at(self, axis: Axis, index: usize) -> (Self, Self) - { + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) { self.split_at(axis, index) } } -impl NdProducer for RawArrayView -{ +impl NdProducer for RawArrayView { type Item = *const A; type Dim = D; type Ptr = *const A; @@ -383,58 +341,48 @@ impl NdProducer for RawArrayView private_impl! 
{} - fn raw_dim(&self) -> Self::Dim - { + fn raw_dim(&self) -> Self::Dim { self.raw_dim() } - fn equal_dim(&self, dim: &Self::Dim) -> bool - { + fn equal_dim(&self, dim: &Self::Dim) -> bool { self.parts.dim.equal(dim) } - fn as_ptr(&self) -> *const A - { + fn as_ptr(&self) -> *const A { self.as_ptr() as _ } - fn layout(&self) -> Layout - { + fn layout(&self) -> Layout { AsRef::>::as_ref(self).layout_impl() } - unsafe fn as_ref(&self, ptr: *const A) -> *const A - { + unsafe fn as_ref(&self, ptr: *const A) -> *const A { ptr } - unsafe fn uget_ptr(&self, i: &Self::Dim) -> *const A - { + unsafe fn uget_ptr(&self, i: &Self::Dim) -> *const A { self.parts .ptr .as_ptr() .offset(i.index_unchecked(&self.parts.strides)) } - fn stride_of(&self, axis: Axis) -> isize - { + fn stride_of(&self, axis: Axis) -> isize { self.stride_of(axis) } #[inline(always)] - fn contiguous_stride(&self) -> Self::Stride - { + fn contiguous_stride(&self) -> Self::Stride { 1 } - fn split_at(self, axis: Axis, index: usize) -> (Self, Self) - { + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) { self.split_at(axis, index) } } -impl NdProducer for RawArrayViewMut -{ +impl NdProducer for RawArrayViewMut { type Item = *mut A; type Dim = D; type Ptr = *mut A; @@ -442,52 +390,43 @@ impl NdProducer for RawArrayViewMut private_impl! {} - fn raw_dim(&self) -> Self::Dim - { + fn raw_dim(&self) -> Self::Dim { self.raw_dim() } - fn equal_dim(&self, dim: &Self::Dim) -> bool - { + fn equal_dim(&self, dim: &Self::Dim) -> bool { self.parts.dim.equal(dim) } - fn as_ptr(&self) -> *mut A - { + fn as_ptr(&self) -> *mut A { self.as_ptr() as _ } - fn layout(&self) -> Layout - { + fn layout(&self) -> Layout { AsRef::>::as_ref(self).layout_impl() } - unsafe fn as_ref(&self, ptr: *mut A) -> *mut A - { + unsafe fn as_ref(&self, ptr: *mut A) -> *mut A { ptr } - unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A - { + unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A { self.parts .ptr .as_ptr() .offset(i.index_unchecked(&self.parts.strides)) } - fn stride_of(&self, axis: Axis) -> isize - { + fn stride_of(&self, axis: Axis) -> isize { self.stride_of(axis) } #[inline(always)] - fn contiguous_stride(&self) -> Self::Stride - { + fn contiguous_stride(&self) -> Self::Stride { 1 } - fn split_at(self, axis: Axis, index: usize) -> (Self, Self) - { + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) { self.split_at(axis, index) } } diff --git a/tests/append.rs b/tests/append.rs index 40beb0f9..a10718c2 100644 --- a/tests/append.rs +++ b/tests/append.rs @@ -2,47 +2,34 @@ use ndarray::prelude::*; use ndarray::{ErrorKind, ShapeError}; #[test] -fn push_row() -{ +fn push_row() { let mut a = Array::zeros((0, 4)); a.push_row(aview1(&[0., 1., 2., 3.])).unwrap(); a.push_row(aview1(&[4., 5., 6., 7.])).unwrap(); assert_eq!(a.shape(), &[2, 4]); - assert_eq!(a, - array![[0., 1., 2., 3.], - [4., 5., 6., 7.]]); - - assert_eq!(a.push_row(aview1(&[1.])), - Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); - assert_eq!(a.push_column(aview1(&[1.])), - Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); - assert_eq!(a.push_column(aview1(&[1., 2.])), - Ok(())); - assert_eq!(a, - array![[0., 1., 2., 3., 1.], - [4., 5., 6., 7., 2.]]); + assert_eq!(a, array![[0., 1., 2., 3.], [4., 5., 6., 7.]]); + + assert_eq!(a.push_row(aview1(&[1.])), Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); + assert_eq!(a.push_column(aview1(&[1.])), Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); + assert_eq!(a.push_column(aview1(&[1., 2.])), 
Ok(())); + assert_eq!(a, array![[0., 1., 2., 3., 1.], [4., 5., 6., 7., 2.]]); } #[test] -fn push_row_wrong_layout() -{ +fn push_row_wrong_layout() { let mut a = Array::zeros((0, 4)); a.push_row(aview1(&[0., 1., 2., 3.])).unwrap(); a.push_row(aview1(&[4., 5., 6., 7.])).unwrap(); assert_eq!(a.shape(), &[2, 4]); - assert_eq!(a, - array![[0., 1., 2., 3.], - [4., 5., 6., 7.]]); + assert_eq!(a, array![[0., 1., 2., 3.], [4., 5., 6., 7.]]); assert_eq!(a.strides(), &[4, 1]); // Changing the memory layout to fit the next append let mut a2 = a.clone(); a2.push_column(aview1(&[1., 2.])).unwrap(); - assert_eq!(a2, - array![[0., 1., 2., 3., 1.], - [4., 5., 6., 7., 2.]]); + assert_eq!(a2, array![[0., 1., 2., 3., 1.], [4., 5., 6., 7., 2.]]); assert_eq!(a2.strides(), &[1, 2]); // Clone the array @@ -52,22 +39,17 @@ fn push_row_wrong_layout() let mut b = Array::zeros(dim); b.append(Axis(1), a.view()).unwrap(); assert_eq!(b.push_column(aview1(&[1., 2.])), Ok(())); - assert_eq!(b, - array![[0., 1., 2., 3., 1.], - [4., 5., 6., 7., 2.]]); + assert_eq!(b, array![[0., 1., 2., 3., 1.], [4., 5., 6., 7., 2.]]); } #[test] -fn push_row_neg_stride_1() -{ +fn push_row_neg_stride_1() { let mut a = Array::zeros((0, 4)); a.push_row(aview1(&[0., 1., 2., 3.])).unwrap(); a.push_row(aview1(&[4., 5., 6., 7.])).unwrap(); assert_eq!(a.shape(), &[2, 4]); - assert_eq!(a, - array![[0., 1., 2., 3.], - [4., 5., 6., 7.]]); + assert_eq!(a, array![[0., 1., 2., 3.], [4., 5., 6., 7.]]); assert_eq!(a.strides(), &[4, 1]); a.invert_axis(Axis(0)); @@ -77,41 +59,30 @@ fn push_row_neg_stride_1() println!("a = {:?}", a); println!("a2 = {:?}", a2); a2.push_column(aview1(&[1., 2.])).unwrap(); - assert_eq!(a2, - array![[4., 5., 6., 7., 1.], - [0., 1., 2., 3., 2.]]); + assert_eq!(a2, array![[4., 5., 6., 7., 1.], [0., 1., 2., 3., 2.]]); assert_eq!(a2.strides(), &[1, 2]); a.invert_axis(Axis(1)); let mut a3 = a.clone(); a3.push_row(aview1(&[4., 5., 6., 7.])).unwrap(); - assert_eq!(a3, - array![[7., 6., 5., 4.], - [3., 2., 1., 0.], - [4., 5., 6., 7.]]); + assert_eq!(a3, array![[7., 6., 5., 4.], [3., 2., 1., 0.], [4., 5., 6., 7.]]); assert_eq!(a3.strides(), &[4, 1]); a.invert_axis(Axis(0)); let mut a4 = a.clone(); a4.push_row(aview1(&[4., 5., 6., 7.])).unwrap(); - assert_eq!(a4, - array![[3., 2., 1., 0.], - [7., 6., 5., 4.], - [4., 5., 6., 7.]]); + assert_eq!(a4, array![[3., 2., 1., 0.], [7., 6., 5., 4.], [4., 5., 6., 7.]]); assert_eq!(a4.strides(), &[4, -1]); } #[test] -fn push_row_neg_stride_2() -{ +fn push_row_neg_stride_2() { let mut a = Array::zeros((0, 4)); a.push_row(aview1(&[0., 1., 2., 3.])).unwrap(); a.push_row(aview1(&[4., 5., 6., 7.])).unwrap(); assert_eq!(a.shape(), &[2, 4]); - assert_eq!(a, - array![[0., 1., 2., 3.], - [4., 5., 6., 7.]]); + assert_eq!(a, array![[0., 1., 2., 3.], [4., 5., 6., 7.]]); assert_eq!(a.strides(), &[4, 1]); a.invert_axis(Axis(1)); @@ -121,108 +92,73 @@ fn push_row_neg_stride_2() println!("a = {:?}", a); println!("a2 = {:?}", a2); a2.push_column(aview1(&[1., 2.])).unwrap(); - assert_eq!(a2, - array![[3., 2., 1., 0., 1.], - [7., 6., 5., 4., 2.]]); + assert_eq!(a2, array![[3., 2., 1., 0., 1.], [7., 6., 5., 4., 2.]]); assert_eq!(a2.strides(), &[1, 2]); a.invert_axis(Axis(0)); let mut a3 = a.clone(); a3.push_row(aview1(&[4., 5., 6., 7.])).unwrap(); - assert_eq!(a3, - array![[7., 6., 5., 4.], - [3., 2., 1., 0.], - [4., 5., 6., 7.]]); + assert_eq!(a3, array![[7., 6., 5., 4.], [3., 2., 1., 0.], [4., 5., 6., 7.]]); assert_eq!(a3.strides(), &[4, 1]); a.invert_axis(Axis(1)); let mut a4 = a.clone(); a4.push_row(aview1(&[4., 5., 
6., 7.])).unwrap(); - assert_eq!(a4, - array![[4., 5., 6., 7.], - [0., 1., 2., 3.], - [4., 5., 6., 7.]]); + assert_eq!(a4, array![[4., 5., 6., 7.], [0., 1., 2., 3.], [4., 5., 6., 7.]]); assert_eq!(a4.strides(), &[4, 1]); } #[test] -fn push_row_error() -{ +fn push_row_error() { let mut a = Array::zeros((3, 4)); - assert_eq!(a.push_row(aview1(&[1.])), - Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); - assert_eq!(a.push_column(aview1(&[1.])), - Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); - assert_eq!(a.push_column(aview1(&[1., 2., 3.])), - Ok(())); - assert_eq!(a.t(), - array![[0., 0., 0.], - [0., 0., 0.], - [0., 0., 0.], - [0., 0., 0.], - [1., 2., 3.]]); + assert_eq!(a.push_row(aview1(&[1.])), Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); + assert_eq!(a.push_column(aview1(&[1.])), Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); + assert_eq!(a.push_column(aview1(&[1., 2., 3.])), Ok(())); + assert_eq!(a.t(), array![[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [1., 2., 3.]]); } #[test] -fn push_row_existing() -{ +fn push_row_existing() { let mut a = Array::zeros((1, 4)); a.push_row(aview1(&[0., 1., 2., 3.])).unwrap(); a.push_row(aview1(&[4., 5., 6., 7.])).unwrap(); assert_eq!(a.shape(), &[3, 4]); - assert_eq!(a, - array![[0., 0., 0., 0.], - [0., 1., 2., 3.], - [4., 5., 6., 7.]]); - - assert_eq!(a.push_row(aview1(&[1.])), - Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); - assert_eq!(a.push_column(aview1(&[1.])), - Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); - assert_eq!(a.push_column(aview1(&[1., 2., 3.])), - Ok(())); - assert_eq!(a, - array![[0., 0., 0., 0., 1.], - [0., 1., 2., 3., 2.], - [4., 5., 6., 7., 3.]]); + assert_eq!(a, array![[0., 0., 0., 0.], [0., 1., 2., 3.], [4., 5., 6., 7.]]); + + assert_eq!(a.push_row(aview1(&[1.])), Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); + assert_eq!(a.push_column(aview1(&[1.])), Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); + assert_eq!(a.push_column(aview1(&[1., 2., 3.])), Ok(())); + assert_eq!(a, array![[0., 0., 0., 0., 1.], [0., 1., 2., 3., 2.], [4., 5., 6., 7., 3.]]); } #[test] -fn push_row_col_len_1() -{ +fn push_row_col_len_1() { // Test appending 1 row and then cols from shape 1 x 1 let mut a = Array::zeros((1, 1)); a.push_row(aview1(&[1.])).unwrap(); // shape 2 x 1 a.push_column(aview1(&[2., 3.])).unwrap(); // shape 2 x 2 - assert_eq!(a.push_row(aview1(&[1.])), - Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); + assert_eq!(a.push_row(aview1(&[1.])), Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); //assert_eq!(a.push_row(aview1(&[1., 2.])), Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout))); a.push_column(aview1(&[4., 5.])).unwrap(); // shape 2 x 3 assert_eq!(a.shape(), &[2, 3]); - assert_eq!(a, - array![[0., 2., 4.], - [1., 3., 5.]]); + assert_eq!(a, array![[0., 2., 4.], [1., 3., 5.]]); } #[test] -fn push_column() -{ +fn push_column() { let mut a = Array::zeros((4, 0)); a.push_column(aview1(&[0., 1., 2., 3.])).unwrap(); a.push_column(aview1(&[4., 5., 6., 7.])).unwrap(); assert_eq!(a.shape(), &[4, 2]); - assert_eq!(a.t(), - array![[0., 1., 2., 3.], - [4., 5., 6., 7.]]); + assert_eq!(a.t(), array![[0., 1., 2., 3.], [4., 5., 6., 7.]]); } #[test] -fn append_array1() -{ +fn append_array1() { let mut a = Array::zeros((0, 4)); a.append(Axis(0), aview2(&[[0., 1., 2., 3.]])).unwrap(); println!("{:?}", a); @@ -231,23 +167,16 @@ fn append_array1() //a.push_column(aview1(&[4., 5., 6., 7.])).unwrap(); 
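Illustrative aside (not part of the changeset): a compact sketch of the `push_row`/`push_column`/`append` behaviour that the tests in tests/append.rs exercise; the values are made up for illustration.

use ndarray::{array, aview1, aview2, Array2, Axis};

fn main() {
    // Start from an empty 0 x 3 array and grow it row by row.
    let mut a = Array2::<f64>::zeros((0, 3));
    a.push_row(aview1(&[0., 1., 2.])).unwrap();
    a.push_row(aview1(&[3., 4., 5.])).unwrap();
    assert_eq!(a, array![[0., 1., 2.], [3., 4., 5.]]);

    // A column of the wrong length is rejected with a shape error.
    assert!(a.push_column(aview1(&[9.])).is_err());

    // `append` adds a whole 2-D block along the chosen axis.
    a.append(Axis(0), aview2(&[[6., 7., 8.]])).unwrap();
    assert_eq!(a.shape(), &[3, 3]);
}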
//assert_eq!(a.shape(), &[4, 2]); - assert_eq!(a, - array![[0., 1., 2., 3.], - [4., 5., 6., 7.]]); + assert_eq!(a, array![[0., 1., 2., 3.], [4., 5., 6., 7.]]); a.append(Axis(0), aview2(&[[5., 5., 4., 4.], [3., 3., 2., 2.]])) .unwrap(); println!("{:?}", a); - assert_eq!(a, - array![[0., 1., 2., 3.], - [4., 5., 6., 7.], - [5., 5., 4., 4.], - [3., 3., 2., 2.]]); + assert_eq!(a, array![[0., 1., 2., 3.], [4., 5., 6., 7.], [5., 5., 4., 4.], [3., 3., 2., 2.]]); } #[test] -fn append_array_3d() -{ +fn append_array_3d() { let mut a = Array::zeros((0, 2, 2)); a.append(Axis(0), array![[[0, 1], [2, 3]]].view()).unwrap(); println!("{:?}", a); @@ -270,26 +199,17 @@ fn append_array_3d() println!("Send {:?} to append", av); a.append(Axis(1), av).unwrap(); println!("{:?}", a); - assert_eq!(a, - array![[[0, 1], - [51, 52], - [55, 56], - [71, 72], - [75, 76], - [81, 82], - [85, 86]], - [[2, 3], - [53, 54], - [57, 58], - [73, 74], - [77, 78], - [83, 84], - [87, 88]]]); + assert_eq!( + a, + array![ + [[0, 1], [51, 52], [55, 56], [71, 72], [75, 76], [81, 82], [85, 86]], + [[2, 3], [53, 54], [57, 58], [73, 74], [77, 78], [83, 84], [87, 88]] + ] + ); } #[test] -fn test_append_2d() -{ +fn test_append_2d() { // create an empty array and append let mut a = Array::zeros((0, 4)); let ones = ArrayView::from(&[1.; 12]) @@ -325,8 +245,7 @@ fn test_append_2d() } #[test] -fn test_append_middle_axis() -{ +fn test_append_middle_axis() { // ensure we can append to Axis(1) by letting it become outermost let mut a = Array::::zeros((3, 0, 2)); a.append( @@ -371,8 +290,7 @@ fn test_append_middle_axis() } #[test] -fn test_append_zero_size() -{ +fn test_append_zero_size() { { let mut a = Array::::zeros((0, 0)); a.append(Axis(0), aview2(&[[]])).unwrap(); @@ -393,8 +311,7 @@ fn test_append_zero_size() } #[test] -fn push_row_neg_stride_3() -{ +fn push_row_neg_stride_3() { let mut a = Array::zeros((0, 4)); a.push_row(aview1(&[0., 1., 2., 3.])).unwrap(); a.invert_axis(Axis(1)); @@ -405,8 +322,7 @@ fn push_row_neg_stride_3() } #[test] -fn push_row_ignore_strides_length_one_axes() -{ +fn push_row_ignore_strides_length_one_axes() { let strides = &[0, 1, 10, 20]; for invert in &[vec![], vec![0], vec![1], vec![0, 1]] { for &stride0 in strides { @@ -426,23 +342,20 @@ fn push_row_ignore_strides_length_one_axes() #[test] #[should_panic(expected = "IncompatibleShape")] -fn zero_dimensional_error1() -{ +fn zero_dimensional_error1() { let mut a = Array::zeros(()).into_dyn(); a.append(Axis(0), arr0(0).into_dyn().view()).unwrap(); } #[test] #[should_panic(expected = "IncompatibleShape")] -fn zero_dimensional_error2() -{ +fn zero_dimensional_error2() { let mut a = Array::zeros(()).into_dyn(); a.push(Axis(0), arr0(0).into_dyn().view()).unwrap(); } #[test] -fn zero_dimensional_ok() -{ +fn zero_dimensional_ok() { let mut a = Array::zeros(0); let one = aview0(&1); let two = aview0(&2); diff --git a/tests/array-construct.rs b/tests/array-construct.rs index ec8cedf3..73d92648 100644 --- a/tests/array-construct.rs +++ b/tests/array-construct.rs @@ -6,16 +6,14 @@ use ndarray::prelude::*; use ndarray::Zip; #[test] -fn test_from_shape_fn() -{ +fn test_from_shape_fn() { let step = 3.1; let h = Array::from_shape_fn((5, 5), |(i, j)| f64::sin(i as f64 / step) * f64::cos(j as f64 / step)); assert_eq!(h.shape(), &[5, 5]); } #[test] -fn test_dimension_zero() -{ +fn test_dimension_zero() { let a: Array2 = Array2::from(vec![[], [], []]); assert_eq!((vec![0.; 0], None), a.into_raw_vec_and_offset()); let a: Array3 = Array3::from(vec![[[]], [[]], [[]]]); @@ -24,8 +22,7 @@ 
fn test_dimension_zero() #[test] #[cfg(feature = "approx")] -fn test_arc_into_owned() -{ +fn test_arc_into_owned() { use approx::assert_abs_diff_ne; let a = Array2::from_elem((5, 5), 1.).into_shared(); @@ -38,8 +35,7 @@ fn test_arc_into_owned() } #[test] -fn test_arcarray_thread_safe() -{ +fn test_arcarray_thread_safe() { fn is_send(_t: &T) {} fn is_sync(_t: &T) {} let a = Array2::from_elem((5, 5), 1.).into_shared(); @@ -49,8 +45,7 @@ fn test_arcarray_thread_safe() } #[test] -fn test_from_fn_c0() -{ +fn test_from_fn_c0() { let a = Array::from_shape_fn((), |i| i); assert_eq!(a[()], ()); assert_eq!(a.len(), 1); @@ -58,8 +53,7 @@ fn test_from_fn_c0() } #[test] -fn test_from_fn_c1() -{ +fn test_from_fn_c1() { let a = Array::from_shape_fn(28, |i| i); for (i, elt) in a.indexed_iter() { assert_eq!(i, *elt); @@ -67,8 +61,7 @@ fn test_from_fn_c1() } #[test] -fn test_from_fn_c() -{ +fn test_from_fn_c() { let a = Array::from_shape_fn((4, 7), |i| i); for (i, elt) in a.indexed_iter() { assert_eq!(i, *elt); @@ -76,8 +69,7 @@ fn test_from_fn_c() } #[test] -fn test_from_fn_c3() -{ +fn test_from_fn_c3() { let a = Array::from_shape_fn((4, 3, 7), |i| i); for (i, elt) in a.indexed_iter() { assert_eq!(i, *elt); @@ -85,8 +77,7 @@ fn test_from_fn_c3() } #[test] -fn test_from_fn_f0() -{ +fn test_from_fn_f0() { let a = Array::from_shape_fn(().f(), |i| i); assert_eq!(a[()], ()); assert_eq!(a.len(), 1); @@ -94,8 +85,7 @@ fn test_from_fn_f0() } #[test] -fn test_from_fn_f1() -{ +fn test_from_fn_f1() { let a = Array::from_shape_fn(28.f(), |i| i); for (i, elt) in a.indexed_iter() { assert_eq!(i, *elt); @@ -103,8 +93,7 @@ fn test_from_fn_f1() } #[test] -fn test_from_fn_f() -{ +fn test_from_fn_f() { let a = Array::from_shape_fn((4, 7).f(), |i| i); for (i, elt) in a.indexed_iter() { assert_eq!(i, *elt); @@ -112,8 +101,7 @@ fn test_from_fn_f() } #[test] -fn test_from_fn_f_with_zero() -{ +fn test_from_fn_f_with_zero() { defmac!(test_from_fn_f_with_zero shape => { let a = Array::from_shape_fn(shape.f(), |i| i); assert_eq!(a.len(), 0); @@ -128,8 +116,7 @@ fn test_from_fn_f_with_zero() } #[test] -fn test_from_fn_f3() -{ +fn test_from_fn_f3() { let a = Array::from_shape_fn((4, 2, 7).f(), |i| i); for (i, elt) in a.indexed_iter() { assert_eq!(i, *elt); @@ -137,8 +124,7 @@ fn test_from_fn_f3() } #[test] -fn deny_wraparound_from_vec() -{ +fn deny_wraparound_from_vec() { let five = vec![0; 5]; let five_large = Array::from_shape_vec((3, 7, 29, 36760123, 823996703), five.clone()); println!("{:?}", five_large); @@ -148,8 +134,7 @@ fn deny_wraparound_from_vec() } #[test] -fn test_ones() -{ +fn test_ones() { let mut a = Array::::zeros((2, 3, 4)); a.fill(1.0); let b = Array::::ones((2, 3, 4)); @@ -157,8 +142,7 @@ fn test_ones() } #[test] -fn test_from_shape_empty_with_neg_stride() -{ +fn test_from_shape_empty_with_neg_stride() { // Issue #998, negative strides for an axis where it doesn't matter. let s = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]; let v = s[..12].to_vec(); @@ -169,45 +153,37 @@ fn test_from_shape_empty_with_neg_stride() } #[test] -fn test_from_shape_with_neg_stride() -{ +fn test_from_shape_with_neg_stride() { // Issue #998, negative strides for an axis where it doesn't matter. 
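Illustrative aside (not part of the changeset): a minimal sketch of the `from_shape_vec` constructor with explicit strides that the shape tests in tests/array-construct.rs target; it assumes ndarray's public `ShapeBuilder::strides` API and uses made-up data.

use ndarray::{arr2, Array, ShapeBuilder};

fn main() {
    // Row-major data reinterpreted as a 2 x 3 array with explicit strides (3, 1).
    let v = vec![0, 1, 2, 3, 4, 5];
    let a = Array::from_shape_vec((2, 3).strides((3, 1)), v).unwrap();
    assert_eq!(a, arr2(&[[0, 1, 2], [3, 4, 5]]));
    // A shape/stride pair that needs more elements than provided is rejected.
    assert!(Array::from_shape_vec((2, 4).strides((4, 1)), vec![0; 6]).is_err());
}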
let s = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]; let v = s[..12].to_vec(); let v_ptr = v.as_ptr(); let a = Array::from_shape_vec((2, 1, 2).strides((1, -4isize as usize, 2)), v).unwrap(); - assert_eq!(a, arr3(&[[[0, 2]], - [[1, 3]]])); + assert_eq!(a, arr3(&[[[0, 2]], [[1, 3]]])); assert_eq!(a.as_ptr(), v_ptr); } #[test] -fn test_from_shape_2_2_2_with_neg_stride() -{ +fn test_from_shape_2_2_2_with_neg_stride() { // Issue #998, negative strides for an axis where it doesn't matter. let s = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]; let v = s[..12].to_vec(); let v_ptr = v.as_ptr(); let a = Array::from_shape_vec((2, 2, 2).strides((1, -4isize as usize, 2)), v).unwrap(); - assert_eq!(a, arr3(&[[[4, 6], - [0, 2]], - [[5, 7], - [1, 3]]])); + assert_eq!(a, arr3(&[[[4, 6], [0, 2]], [[5, 7], [1, 3]]])); assert_eq!(a.as_ptr(), v_ptr.wrapping_add(4)); } #[should_panic] #[test] -fn deny_wraparound_zeros() -{ +fn deny_wraparound_zeros() { //2^64 + 5 = 18446744073709551621 = 3×7×29×36760123×823996703 (5 distinct prime factors) let _five_large = Array::::zeros((3, 7, 29, 36760123, 823996703)); } #[should_panic] #[test] -fn deny_wraparound_reshape() -{ +fn deny_wraparound_reshape() { //2^64 + 5 = 18446744073709551621 = 3×7×29×36760123×823996703 (5 distinct prime factors) let five = Array::::zeros(5); let _five_large = five @@ -217,56 +193,48 @@ fn deny_wraparound_reshape() #[should_panic] #[test] -fn deny_wraparound_default() -{ +fn deny_wraparound_default() { let _five_large = Array::::default((3, 7, 29, 36760123, 823996703)); } #[should_panic] #[test] -fn deny_wraparound_from_shape_fn() -{ +fn deny_wraparound_from_shape_fn() { let _five_large = Array::::from_shape_fn((3, 7, 29, 36760123, 823996703), |_| 0.); } #[should_panic] #[test] -fn deny_wraparound_uninit() -{ +fn deny_wraparound_uninit() { let _five_large = Array::::uninit((3, 7, 29, 36760123, 823996703)); } #[should_panic] #[test] -fn deny_slice_with_too_many_rows_to_arrayview2() -{ +fn deny_slice_with_too_many_rows_to_arrayview2() { let _view = ArrayView2::from(&[[0u8; 0]; usize::MAX][..]); } #[should_panic] #[test] -fn deny_slice_with_too_many_zero_sized_elems_to_arrayview2() -{ +fn deny_slice_with_too_many_zero_sized_elems_to_arrayview2() { let _view = ArrayView2::from(&[[(); isize::MAX as usize]; isize::MAX as usize][..]); } #[should_panic] #[test] -fn deny_slice_with_too_many_rows_to_arrayviewmut2() -{ +fn deny_slice_with_too_many_rows_to_arrayviewmut2() { let _view = ArrayViewMut2::from(&mut [[0u8; 0]; usize::MAX][..]); } #[should_panic] #[test] -fn deny_slice_with_too_many_zero_sized_elems_to_arrayviewmut2() -{ +fn deny_slice_with_too_many_zero_sized_elems_to_arrayviewmut2() { let _view = ArrayViewMut2::from(&mut [[(); isize::MAX as usize]; isize::MAX as usize][..]); } #[test] -fn maybe_uninit_1() -{ +fn maybe_uninit_1() { use std::mem::MaybeUninit; unsafe { diff --git a/tests/array.rs b/tests/array.rs index 391f88b9..bab3563d 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -31,8 +31,7 @@ macro_rules! 
assert_panics { } #[test] -fn test_matmul_arcarray() -{ +fn test_matmul_arcarray() { let mut A = ArcArray::::zeros((2, 3)); for (i, elt) in A.iter_mut().enumerate() { *elt = i; @@ -56,20 +55,17 @@ fn test_matmul_arcarray() } #[allow(unused)] -fn arrayview_shrink_lifetime<'a, 'b: 'a>(view: ArrayView1<'b, f64>) -> ArrayView1<'a, f64> -{ +fn arrayview_shrink_lifetime<'a, 'b: 'a>(view: ArrayView1<'b, f64>) -> ArrayView1<'a, f64> { view.reborrow() } #[allow(unused)] -fn arrayviewmut_shrink_lifetime<'a, 'b: 'a>(view: ArrayViewMut1<'b, f64>) -> ArrayViewMut1<'a, f64> -{ +fn arrayviewmut_shrink_lifetime<'a, 'b: 'a>(view: ArrayViewMut1<'b, f64>) -> ArrayViewMut1<'a, f64> { view.reborrow() } #[test] -fn test_mat_mul() -{ +fn test_mat_mul() { // smoke test, a big matrix multiplication of uneven size let (n, m) = (45, 33); let a = Array::from_iter(0..(n * m)) @@ -83,8 +79,7 @@ fn test_mat_mul() #[deny(unsafe_code)] #[test] -fn test_slice() -{ +fn test_slice() { let mut A = ArcArray::::zeros((3, 4, 5)); for (i, elt) in A.iter_mut().enumerate() { *elt = i; @@ -99,15 +94,13 @@ fn test_slice() #[deny(unsafe_code)] #[test] -fn test_slice_ix0() -{ +fn test_slice_ix0() { let arr = arr0(5); assert_eq!(arr.slice(s![]), aview0(&5)); } #[test] -fn test_slice_edge_cases() -{ +fn test_slice_edge_cases() { let mut arr = Array3::::zeros((3, 4, 5)); arr.slice_collapse(s![0..0;-1, .., ..]); assert_eq!(arr.shape(), &[0, 4, 5]); @@ -117,8 +110,7 @@ fn test_slice_edge_cases() } #[test] -fn test_slice_inclusive_range() -{ +fn test_slice_inclusive_range() { let arr = array![[1, 2, 3], [4, 5, 6]]; assert_eq!(arr.slice(s![1..=1, 1..=2]), array![[5, 6]]); assert_eq!(arr.slice(s![1..=-1, -2..=2;-1]), array![[6, 5]]); @@ -132,8 +124,7 @@ fn test_slice_inclusive_range() /// `ArrayView1` and `ArrayView2`, so the compiler needs to determine which /// type is the correct result for the `.slice()` call. 
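Illustrative aside (not part of the changeset): a short sketch of the output-dimension inference for `.slice()` described in the comment above, assuming the public `s!` macro; range arguments keep an axis, integer arguments drop one.

use ndarray::{array, s};

fn main() {
    let b = array![[3., 4.], [5., 6.]];
    // Two range arguments keep both axes: the result is a 2-D view.
    let m = b.slice(s![..-1, ..]);
    assert_eq!(m, array![[3., 4.]]);
    // An integer argument removes an axis, so this is a 1-D view instead;
    // the output dimensionality is encoded in the type of the `s![]` expression.
    let r = b.slice(s![0, ..]);
    assert_eq!(r, array![3., 4.]);
}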
#[test] -fn test_slice_infer() -{ +fn test_slice_infer() { let a = array![1., 2.]; let b = array![[3., 4.], [5., 6.]]; b.slice(s![..-1, ..]).dot(&a); @@ -141,8 +132,7 @@ fn test_slice_infer() } #[test] -fn test_slice_with_many_dim() -{ +fn test_slice_with_many_dim() { let mut A = ArcArray::::zeros(&[3, 1, 4, 1, 3, 2, 1][..]); for (i, elt) in A.iter_mut().enumerate() { *elt = i; @@ -169,16 +159,14 @@ fn test_slice_with_many_dim() } #[test] -fn test_slice_range_variable() -{ +fn test_slice_range_variable() { let range = 1..4; let arr = array![0, 1, 2, 3, 4]; assert_eq!(arr.slice(s![range]), array![1, 2, 3]); } #[test] -fn test_slice_args_eval_range_once() -{ +fn test_slice_args_eval_range_once() { let mut eval_count = 0; { let mut range = || { @@ -192,8 +180,7 @@ fn test_slice_args_eval_range_once() } #[test] -fn test_slice_args_eval_step_once() -{ +fn test_slice_args_eval_step_once() { let mut eval_count = 0; { let mut step = || { @@ -207,8 +194,7 @@ fn test_slice_args_eval_step_once() } #[test] -fn test_slice_array_fixed() -{ +fn test_slice_array_fixed() { let mut arr = Array3::::zeros((5, 2, 5)); let info = s![1.., 1, NewAxis, ..;2]; arr.slice(info); @@ -219,8 +205,7 @@ fn test_slice_array_fixed() } #[test] -fn test_slice_dyninput_array_fixed() -{ +fn test_slice_dyninput_array_fixed() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); let info = s![1.., 1, NewAxis, ..;2]; arr.slice(info); @@ -231,8 +216,7 @@ fn test_slice_dyninput_array_fixed() } #[test] -fn test_slice_array_dyn() -{ +fn test_slice_array_dyn() { let mut arr = Array3::::zeros((5, 2, 5)); let info = SliceInfo::<_, Ix3, IxDyn>::try_from([ SliceInfoElem::from(1..), @@ -254,8 +238,7 @@ fn test_slice_array_dyn() } #[test] -fn test_slice_dyninput_array_dyn() -{ +fn test_slice_dyninput_array_dyn() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); let info = SliceInfo::<_, Ix3, IxDyn>::try_from([ SliceInfoElem::from(1..), @@ -277,8 +260,7 @@ fn test_slice_dyninput_array_dyn() } #[test] -fn test_slice_dyninput_vec_fixed() -{ +fn test_slice_dyninput_vec_fixed() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); let info = &SliceInfo::<_, Ix3, Ix3>::try_from(vec![ SliceInfoElem::from(1..), @@ -300,8 +282,7 @@ fn test_slice_dyninput_vec_fixed() } #[test] -fn test_slice_dyninput_vec_dyn() -{ +fn test_slice_dyninput_vec_dyn() { let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); let info = &SliceInfo::<_, Ix3, IxDyn>::try_from(vec![ SliceInfoElem::from(1..), @@ -323,8 +304,7 @@ fn test_slice_dyninput_vec_dyn() } #[test] -fn test_slice_with_subview_and_new_axis() -{ +fn test_slice_with_subview_and_new_axis() { let mut arr = ArcArray::::zeros((3, 5, 4)); for (i, elt) in arr.iter_mut().enumerate() { *elt = i; @@ -361,8 +341,7 @@ fn test_slice_with_subview_and_new_axis() } #[test] -fn test_slice_collapse_with_indices() -{ +fn test_slice_collapse_with_indices() { let mut arr = ArcArray::::zeros((3, 5, 4)); for (i, elt) in arr.iter_mut().enumerate() { *elt = i; @@ -401,15 +380,13 @@ fn test_slice_collapse_with_indices() #[test] #[should_panic] -fn test_slice_collapse_with_newaxis() -{ +fn test_slice_collapse_with_newaxis() { let mut arr = Array2::::zeros((2, 3)); arr.slice_collapse(s![0, 0, NewAxis]); } #[test] -fn test_multislice() -{ +fn test_multislice() { macro_rules! 
do_test { ($arr:expr, $($s:expr),*) => { { @@ -427,10 +404,7 @@ fn test_multislice() .into_shape_with_order((8, 6)) .unwrap(); - assert_eq!( - (arr.clone().view_mut(),), - arr.multi_slice_mut((s![.., ..],)), - ); + assert_eq!((arr.clone().view_mut(),), arr.multi_slice_mut((s![.., ..],)),); assert_eq!(arr.multi_slice_mut(()), ()); do_test!(&mut arr, s![0, ..]); do_test!(&mut arr, s![0, ..], s![1, ..]); @@ -447,8 +421,7 @@ fn test_multislice() } #[test] -fn test_multislice_intersecting() -{ +fn test_multislice_intersecting() { assert_panics!({ let mut arr = Array2::::zeros((8, 6)); arr.multi_slice_mut((s![3, .., NewAxis], s![3, ..])); @@ -489,39 +462,34 @@ fn test_multislice_intersecting() #[should_panic] #[test] -fn index_out_of_bounds() -{ +fn index_out_of_bounds() { let mut a = Array::::zeros((3, 4)); a[[3, 2]] = 1; } #[should_panic] #[test] -fn slice_oob() -{ +fn slice_oob() { let a = ArcArray::::zeros((3, 4)); let _vi = a.slice(s![..10, ..]); } #[should_panic] #[test] -fn slice_axis_oob() -{ +fn slice_axis_oob() { let a = ArcArray::::zeros((3, 4)); let _vi = a.slice_axis(Axis(0), Slice::new(0, Some(10), 1)); } #[should_panic] #[test] -fn slice_wrong_dim() -{ +fn slice_wrong_dim() { let a = ArcArray::::zeros(vec![3, 4, 5]); let _vi = a.slice(s![.., ..]); } #[test] -fn test_index() -{ +fn test_index() { let mut A = ArcArray::::zeros((2, 3)); for (i, elt) in A.iter_mut().enumerate() { *elt = i; @@ -542,8 +510,7 @@ fn test_index() } #[test] -fn test_index_arrays() -{ +fn test_index_arrays() { let a = Array1::from_iter(0..12); assert_eq!(a[1], a[[1]]); let v = a.view().into_shape_with_order((3, 4)).unwrap(); @@ -554,8 +521,7 @@ fn test_index_arrays() #[test] #[allow(clippy::assign_op_pattern)] -fn test_add() -{ +fn test_add() { let mut A = ArcArray::::zeros((2, 2)); for (i, elt) in A.iter_mut().enumerate() { *elt = i; @@ -570,8 +536,7 @@ fn test_add() } #[test] -fn test_multidim() -{ +fn test_multidim() { let mut mat = ArcArray::zeros(2 * 3 * 4 * 5 * 6) .into_shape_with_order((2, 3, 4, 5, 6)) .unwrap(); @@ -596,8 +561,7 @@ array([[[ 7, 6], [ 9, 8]]]) */ #[test] -fn test_negative_stride_arcarray() -{ +fn test_negative_stride_arcarray() { let mut mat = ArcArray::zeros((2, 4, 2)); mat[[0, 0, 0]] = 1.0f32; for (i, elt) in mat.iter_mut().enumerate() { @@ -623,8 +587,7 @@ fn test_negative_stride_arcarray() } #[test] -fn test_cow() -{ +fn test_cow() { let mut mat = ArcArray::zeros((2, 2)); mat[[0, 0]] = 1; let n = mat.clone(); @@ -656,8 +619,7 @@ fn test_cow() } #[test] -fn test_cow_shrink() -{ +fn test_cow_shrink() { // A test for clone-on-write in the case that // mutation shrinks the array and gives it different strides // @@ -691,8 +653,7 @@ fn test_cow_shrink() } #[test] -fn test_sub() -{ +fn test_sub() { let mat = Array::from_iter(0..16) .into_shape_with_order((2, 4, 2)) .unwrap(); @@ -712,8 +673,7 @@ fn test_sub() #[should_panic] #[test] -fn test_sub_oob_1() -{ +fn test_sub_oob_1() { let mat = Array::from_iter(0..16) .into_shape_with_order((2, 4, 2)) .unwrap(); @@ -722,8 +682,7 @@ fn test_sub_oob_1() #[test] #[cfg(feature = "approx")] -fn test_select() -{ +fn test_select() { use approx::assert_abs_diff_eq; // test for 2-d array @@ -746,8 +705,7 @@ fn test_select() } #[test] -fn test_select_1d() -{ +fn test_select_1d() { let x = arr1(&[0, 1, 2, 3, 4, 5, 6]); let r1 = x.select(Axis(0), &[1, 3, 4, 2, 2, 5]); assert_eq!(r1, arr1(&[1, 3, 4, 2, 2, 5])); @@ -760,8 +718,7 @@ fn test_select_1d() } #[test] -fn diag() -{ +fn diag() { let d = arr2(&[[1., 2., 3.0f32]]).into_diag(); assert_eq!(d.dim(), 
1); let a = arr2(&[[1., 2., 3.0f32], [0., 0., 0.]]); @@ -778,8 +735,7 @@ fn diag() /// Note that this does not check the strides in the "merged" case! #[test] #[allow(clippy::cognitive_complexity)] -fn merge_axes() -{ +fn merge_axes() { macro_rules! assert_merged { ($arr:expr, $slice:expr, $take:expr, $into:expr) => { let mut v = $arr.slice($slice); @@ -867,8 +823,7 @@ fn merge_axes() } #[test] -fn swapaxes() -{ +fn swapaxes() { let mut a = arr2(&[[1., 2.], [3., 4.0f32]]); let b = arr2(&[[1., 3.], [2., 4.0f32]]); assert!(a != b); @@ -881,8 +836,7 @@ fn swapaxes() } #[test] -fn permuted_axes() -{ +fn permuted_axes() { let a = array![1].index_axis_move(Axis(0), 0); let permuted = a.view().permuted_axes([]); assert_eq!(a, permuted); @@ -918,8 +872,7 @@ fn permuted_axes() #[should_panic] #[test] -fn permuted_axes_repeated_axis() -{ +fn permuted_axes_repeated_axis() { let a = Array::from_iter(0..24) .into_shape_with_order((2, 3, 4)) .unwrap(); @@ -928,8 +881,7 @@ fn permuted_axes_repeated_axis() #[should_panic] #[test] -fn permuted_axes_missing_axis() -{ +fn permuted_axes_missing_axis() { let a = Array::from_iter(0..24) .into_shape_with_order((2, 3, 4)) .unwrap() @@ -939,8 +891,7 @@ fn permuted_axes_missing_axis() #[should_panic] #[test] -fn permuted_axes_oob() -{ +fn permuted_axes_oob() { let a = Array::from_iter(0..24) .into_shape_with_order((2, 3, 4)) .unwrap(); @@ -948,8 +899,7 @@ fn permuted_axes_oob() } #[test] -fn standard_layout() -{ +fn standard_layout() { let mut a = arr2(&[[1., 2.], [3., 4.0]]); assert!(a.is_standard_layout()); a.swap_axes(0, 1); @@ -967,8 +917,7 @@ fn standard_layout() } #[test] -fn iter_size_hint() -{ +fn iter_size_hint() { let mut a = arr2(&[[1., 2.], [3., 4.]]); { let mut it = a.iter(); @@ -1003,8 +952,7 @@ fn iter_size_hint() } #[test] -fn zero_axes() -{ +fn zero_axes() { let mut a = arr1::(&[]); if a.iter().next().is_some() { panic!(); @@ -1022,8 +970,7 @@ fn zero_axes() } #[test] -fn equality() -{ +fn equality() { let a = arr2(&[[1., 2.], [3., 4.]]); let mut b = arr2(&[[1., 2.], [2., 4.]]); assert!(a != b); @@ -1036,8 +983,7 @@ fn equality() } #[test] -fn map1() -{ +fn map1() { let a = arr2(&[[1., 2.], [3., 4.]]); let b = a.map(|&x| (x / 3.) as isize); assert_eq!(b, arr2(&[[0, 0], [1, 1]])); @@ -1047,24 +993,21 @@ fn map1() } #[test] -fn mapv_into_any_same_type() -{ +fn mapv_into_any_same_type() { let a: Array = array![[1., 2., 3.], [4., 5., 6.]]; let a_plus_one: Array = array![[2., 3., 4.], [5., 6., 7.]]; assert_eq!(a.mapv_into_any(|a| a + 1.), a_plus_one); } #[test] -fn mapv_into_any_diff_types() -{ +fn mapv_into_any_diff_types() { let a: Array = array![[1., 2., 3.], [4., 5., 6.]]; let a_even: Array = array![[false, true, false], [true, false, true]]; assert_eq!(a.mapv_into_any(|a| a.round() as i32 % 2 == 0), a_even); } #[test] -fn as_slice_memory_order_mut_arcarray() -{ +fn as_slice_memory_order_mut_arcarray() { // Test that mutation breaks sharing for `ArcArray`. let a = rcarr2(&[[1., 2.], [3., 4.0f32]]); let mut b = a.clone(); @@ -1075,8 +1018,7 @@ fn as_slice_memory_order_mut_arcarray() } #[test] -fn as_slice_memory_order_mut_cowarray() -{ +fn as_slice_memory_order_mut_cowarray() { // Test that mutation breaks sharing for `CowArray`. 
let a = arr2(&[[1., 2.], [3., 4.0f32]]); let mut b = CowArray::from(a.view()); @@ -1087,8 +1029,7 @@ fn as_slice_memory_order_mut_cowarray() } #[test] -fn as_slice_memory_order_mut_contiguous_arcarray() -{ +fn as_slice_memory_order_mut_contiguous_arcarray() { // Test that unsharing preserves the strides in the contiguous case for `ArcArray`. let a = rcarr2(&[[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]).reversed_axes(); let mut b = a.clone().slice_move(s![.., ..2]); @@ -1098,8 +1039,7 @@ fn as_slice_memory_order_mut_contiguous_arcarray() } #[test] -fn as_slice_memory_order_mut_contiguous_cowarray() -{ +fn as_slice_memory_order_mut_contiguous_cowarray() { // Test that unsharing preserves the strides in the contiguous case for `CowArray`. let a = arr2(&[[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]).reversed_axes(); let mut b = CowArray::from(a.slice(s![.., ..2])); @@ -1110,8 +1050,7 @@ fn as_slice_memory_order_mut_contiguous_cowarray() } #[test] -fn to_slice_memory_order() -{ +fn to_slice_memory_order() { for shape in [[2, 0, 3, 5], [2, 1, 3, 5], [2, 4, 3, 5]] { let data: Vec = (0..shape.iter().product()).collect(); let mut orig = Array1::from(data.clone()) @@ -1128,8 +1067,7 @@ fn to_slice_memory_order() } #[test] -fn to_slice_memory_order_discontiguous() -{ +fn to_slice_memory_order_discontiguous() { let mut orig = Array3::::zeros([3, 2, 4]); assert!(orig .slice(s![.., 1.., ..]) @@ -1150,8 +1088,7 @@ fn to_slice_memory_order_discontiguous() } #[test] -fn array0_into_scalar() -{ +fn array0_into_scalar() { // With this kind of setup, the `Array`'s pointer is not the same as the // underlying `Vec`'s pointer. let a: Array0 = array![4, 5, 6, 7].index_axis_move(Axis(0), 2); @@ -1169,8 +1106,7 @@ fn array0_into_scalar() } #[test] -fn array_view0_into_scalar() -{ +fn array_view0_into_scalar() { // With this kind of setup, the `Array`'s pointer is not the same as the // underlying `Vec`'s pointer. let a: Array0 = array![4, 5, 6, 7].index_axis_move(Axis(0), 2); @@ -1188,8 +1124,7 @@ fn array_view0_into_scalar() } #[test] -fn array_view_mut0_into_scalar() -{ +fn array_view_mut0_into_scalar() { // With this kind of setup, the `Array`'s pointer is not the same as the // underlying `Vec`'s pointer. 
let a: Array0 = array![4, 5, 6, 7].index_axis_move(Axis(0), 2); @@ -1204,8 +1139,7 @@ fn array_view_mut0_into_scalar() } #[test] -fn array1_into_raw_vec() -{ +fn array1_into_raw_vec() { let data = vec![4, 5, 6, 7]; let array = Array::from(data.clone()); let (raw_vec, offset) = array.into_raw_vec_and_offset(); @@ -1214,8 +1148,7 @@ fn array1_into_raw_vec() } #[test] -fn owned_array1() -{ +fn owned_array1() { let mut a = Array::from(vec![1, 2, 3, 4]); for elt in a.iter_mut() { *elt = 2; @@ -1240,8 +1173,7 @@ fn owned_array1() } #[test] -fn owned_array_with_stride() -{ +fn owned_array_with_stride() { let v: Vec<_> = (0..12).collect(); let dim = (2, 3, 2); let strides = (1, 4, 2); @@ -1251,8 +1183,7 @@ fn owned_array_with_stride() } #[test] -fn owned_array_discontiguous() -{ +fn owned_array_discontiguous() { use std::iter::repeat; let v: Vec<_> = (0..12).flat_map(|x| repeat(x).take(2)).collect(); let dim = (3, 2, 2); @@ -1265,17 +1196,14 @@ fn owned_array_discontiguous() } #[test] -fn owned_array_discontiguous_drop() -{ +fn owned_array_discontiguous_drop() { use std::cell::RefCell; use std::collections::BTreeSet; use std::rc::Rc; struct InsertOnDrop(Rc>>, Option); - impl Drop for InsertOnDrop - { - fn drop(&mut self) - { + impl Drop for InsertOnDrop { + fn drop(&mut self) { let InsertOnDrop(ref set, ref mut value) = *self; set.borrow_mut().insert(value.take().expect("double drop!")); } @@ -1298,34 +1226,26 @@ macro_rules! assert_matches { ($value:expr, $pat:pat) => { match $value { $pat => {} - ref err => panic!( - "assertion failed: `{}` matches `{}` found: {:?}", - stringify!($value), - stringify!($pat), - err - ), + ref err => { + panic!("assertion failed: `{}` matches `{}` found: {:?}", stringify!($value), stringify!($pat), err) + } } }; } #[test] -fn from_vec_dim_stride_empty_1d() -{ +fn from_vec_dim_stride_empty_1d() { let empty: [f32; 0] = []; assert_matches!(Array::from_shape_vec(0.strides(1), empty.to_vec()), Ok(_)); } #[test] -fn from_vec_dim_stride_0d() -{ +fn from_vec_dim_stride_0d() { let empty: [f32; 0] = []; let one = [1.]; let two = [1., 2.]; // too few elements - assert_matches!( - Array::from_shape_vec(().strides(()), empty.to_vec()), - Err(_) - ); + assert_matches!(Array::from_shape_vec(().strides(()), empty.to_vec()), Err(_)); // exact number of elements assert_matches!(Array::from_shape_vec(().strides(()), one.to_vec()), Ok(_)); // too many are ok @@ -1333,8 +1253,7 @@ fn from_vec_dim_stride_0d() } #[test] -fn from_vec_dim_stride_2d_1() -{ +fn from_vec_dim_stride_2d_1() { let two = [1., 2.]; let d = Ix2(2, 1); let s = d.default_strides(); @@ -1342,8 +1261,7 @@ fn from_vec_dim_stride_2d_1() } #[test] -fn from_vec_dim_stride_2d_2() -{ +fn from_vec_dim_stride_2d_2() { let two = [1., 2.]; let d = Ix2(1, 2); let s = d.default_strides(); @@ -1351,44 +1269,31 @@ fn from_vec_dim_stride_2d_2() } #[test] -fn from_vec_dim_stride_2d_3() -{ +fn from_vec_dim_stride_2d_3() { let a = arr3(&[[[1]], [[2]], [[3]]]); let d = a.raw_dim(); let s = d.default_strides(); - assert_matches!( - Array::from_shape_vec(d.strides(s), a.as_slice().unwrap().to_vec()), - Ok(_) - ); + assert_matches!(Array::from_shape_vec(d.strides(s), a.as_slice().unwrap().to_vec()), Ok(_)); } #[test] -fn from_vec_dim_stride_2d_4() -{ +fn from_vec_dim_stride_2d_4() { let a = arr3(&[[[1]], [[2]], [[3]]]); let d = a.raw_dim(); let s = d.fortran_strides(); - assert_matches!( - Array::from_shape_vec(d.strides(s), a.as_slice().unwrap().to_vec()), - Ok(_) - ); + assert_matches!(Array::from_shape_vec(d.strides(s), 
a.as_slice().unwrap().to_vec()), Ok(_)); } #[test] -fn from_vec_dim_stride_2d_5() -{ +fn from_vec_dim_stride_2d_5() { let a = arr3(&[[[1, 2, 3]]]); let d = a.raw_dim(); let s = d.fortran_strides(); - assert_matches!( - Array::from_shape_vec(d.strides(s), a.as_slice().unwrap().to_vec()), - Ok(_) - ); + assert_matches!(Array::from_shape_vec(d.strides(s), a.as_slice().unwrap().to_vec()), Ok(_)); } #[test] -fn from_vec_dim_stride_2d_6() -{ +fn from_vec_dim_stride_2d_6() { let a = [1., 2., 3., 4., 5., 6.]; let d = (2, 1, 1); let s = (2, 2, 1); @@ -1400,8 +1305,7 @@ fn from_vec_dim_stride_2d_6() } #[test] -fn from_vec_dim_stride_2d_7() -{ +fn from_vec_dim_stride_2d_7() { // empty arrays can have 0 strides let a: [f32; 0] = []; // [[]] shape=[4, 0], strides=[0, 1] @@ -1411,8 +1315,7 @@ fn from_vec_dim_stride_2d_7() } #[test] -fn from_vec_dim_stride_2d_8() -{ +fn from_vec_dim_stride_2d_8() { // strides of length 1 axes can be zero let a = [1.]; let d = (1, 1); @@ -1421,8 +1324,7 @@ fn from_vec_dim_stride_2d_8() } #[test] -fn from_vec_dim_stride_2d_rejects() -{ +fn from_vec_dim_stride_2d_rejects() { let two = [1., 2.]; let d = (2, 2); let s = (1, 0); @@ -1434,8 +1336,7 @@ fn from_vec_dim_stride_2d_rejects() } #[test] -fn views() -{ +fn views() { let a = ArcArray::from(vec![1, 2, 3, 4]) .into_shape_with_order((2, 2)) .unwrap(); @@ -1447,15 +1348,11 @@ fn views() a.clone()[(0, 0)] = 99; assert_eq!(b[(0, 0)], 1); - assert_eq!( - a.view().into_iter().cloned().collect::>(), - vec![1, 2, 3, 4] - ); + assert_eq!(a.view().into_iter().cloned().collect::>(), vec![1, 2, 3, 4]); } #[test] -fn view_mut() -{ +fn view_mut() { let mut a = ArcArray::from(vec![1, 2, 3, 4]) .into_shape_with_order((2, 2)) .unwrap(); @@ -1476,8 +1373,7 @@ fn view_mut() } #[test] -fn slice_mut() -{ +fn slice_mut() { let mut a = ArcArray::from(vec![1, 2, 3, 4]) .into_shape_with_order((2, 2)) .unwrap(); @@ -1501,8 +1397,7 @@ fn slice_mut() } #[test] -fn assign_ops() -{ +fn assign_ops() { let mut a = arr2(&[[1., 2.], [3., 4.]]); let b = arr2(&[[1., 3.], [2., 4.]]); (*&mut a.view_mut()) += &b; @@ -1520,8 +1415,7 @@ fn assign_ops() } #[test] -fn aview() -{ +fn aview() { let a = arr2(&[[1., 2., 3.], [4., 5., 6.]]); let data = [[1., 2., 3.], [4., 5., 6.]]; let b = aview2(&data); @@ -1530,8 +1424,7 @@ fn aview() } #[test] -fn aview_mut() -{ +fn aview_mut() { let mut data = [0; 16]; { let mut a = aview_mut1(&mut data).into_shape_with_order((4, 4)).unwrap(); @@ -1544,8 +1437,7 @@ fn aview_mut() } #[test] -fn transpose_view() -{ +fn transpose_view() { let a = arr2(&[[1, 2], [3, 4]]); let at = a.view().reversed_axes(); assert_eq!(at, arr2(&[[1, 3], [2, 4]])); @@ -1556,8 +1448,7 @@ fn transpose_view() } #[test] -fn transpose_view_mut() -{ +fn transpose_view_mut() { let mut a = arr2(&[[1, 2], [3, 4]]); let mut at = a.view_mut().reversed_axes(); at[[0, 1]] = 5; @@ -1571,8 +1462,7 @@ fn transpose_view_mut() #[test] #[allow(clippy::cognitive_complexity)] -fn insert_axis() -{ +fn insert_axis() { defmac!(test_insert orig, index, new => { let res = orig.insert_axis(Axis(index)); assert_eq!(res, new); @@ -1587,155 +1477,63 @@ fn insert_axis() test_insert!(arr1(&[1, 2, 3]), 1, arr2(&[[1], [2], [3]])); assert!(::std::panic::catch_unwind(|| arr1(&[1, 2, 3]).insert_axis(Axis(2))).is_err()); - test_insert!( - arr2(&[[1, 2, 3], [4, 5, 6]]), - 0, - arr3(&[[[1, 2, 3], [4, 5, 6]]]) - ); - test_insert!( - arr2(&[[1, 2, 3], [4, 5, 6]]), - 1, - arr3(&[[[1, 2, 3]], [[4, 5, 6]]]) - ); - test_insert!( - arr2(&[[1, 2, 3], [4, 5, 6]]), - 2, - arr3(&[[[1], [2], [3]], [[4], 
[5], [6]]]) - ); - assert!( - ::std::panic::catch_unwind(|| arr2(&[[1, 2, 3], [4, 5, 6]]).insert_axis(Axis(3))).is_err() - ); + test_insert!(arr2(&[[1, 2, 3], [4, 5, 6]]), 0, arr3(&[[[1, 2, 3], [4, 5, 6]]])); + test_insert!(arr2(&[[1, 2, 3], [4, 5, 6]]), 1, arr3(&[[[1, 2, 3]], [[4, 5, 6]]])); + test_insert!(arr2(&[[1, 2, 3], [4, 5, 6]]), 2, arr3(&[[[1], [2], [3]], [[4], [5], [6]]])); + assert!(::std::panic::catch_unwind(|| arr2(&[[1, 2, 3], [4, 5, 6]]).insert_axis(Axis(3))).is_err()); - test_insert!( - Array3::::zeros((3, 4, 5)), - 0, - Array4::::zeros((1, 3, 4, 5)) - ); - test_insert!( - Array3::::zeros((3, 4, 5)), - 1, - Array4::::zeros((3, 1, 4, 5)) - ); - test_insert!( - Array3::::zeros((3, 4, 5)), - 3, - Array4::::zeros((3, 4, 5, 1)) - ); - assert!( - ::std::panic::catch_unwind(|| Array3::::zeros((3, 4, 5)).insert_axis(Axis(4))).is_err() - ); + test_insert!(Array3::::zeros((3, 4, 5)), 0, Array4::::zeros((1, 3, 4, 5))); + test_insert!(Array3::::zeros((3, 4, 5)), 1, Array4::::zeros((3, 1, 4, 5))); + test_insert!(Array3::::zeros((3, 4, 5)), 3, Array4::::zeros((3, 4, 5, 1))); + assert!(::std::panic::catch_unwind(|| Array3::::zeros((3, 4, 5)).insert_axis(Axis(4))).is_err()); - test_insert!( - Array6::::zeros((2, 3, 4, 3, 2, 3)), - 0, - ArrayD::::zeros(vec![1, 2, 3, 4, 3, 2, 3]) - ); - test_insert!( - Array6::::zeros((2, 3, 4, 3, 2, 3)), - 3, - ArrayD::::zeros(vec![2, 3, 4, 1, 3, 2, 3]) - ); - test_insert!( - Array6::::zeros((2, 3, 4, 3, 2, 3)), - 6, - ArrayD::::zeros(vec![2, 3, 4, 3, 2, 3, 1]) - ); - assert!(::std::panic::catch_unwind( - || Array6::::zeros((2, 3, 4, 3, 2, 3)).insert_axis(Axis(7)) - ) - .is_err()); + test_insert!(Array6::::zeros((2, 3, 4, 3, 2, 3)), 0, ArrayD::::zeros(vec![1, 2, 3, 4, 3, 2, 3])); + test_insert!(Array6::::zeros((2, 3, 4, 3, 2, 3)), 3, ArrayD::::zeros(vec![2, 3, 4, 1, 3, 2, 3])); + test_insert!(Array6::::zeros((2, 3, 4, 3, 2, 3)), 6, ArrayD::::zeros(vec![2, 3, 4, 3, 2, 3, 1])); + assert!(::std::panic::catch_unwind(|| Array6::::zeros((2, 3, 4, 3, 2, 3)).insert_axis(Axis(7))).is_err()); - test_insert!( - ArrayD::::zeros(vec![3, 4, 5]), - 0, - ArrayD::::zeros(vec![1, 3, 4, 5]) - ); - test_insert!( - ArrayD::::zeros(vec![3, 4, 5]), - 1, - ArrayD::::zeros(vec![3, 1, 4, 5]) - ); - test_insert!( - ArrayD::::zeros(vec![3, 4, 5]), - 3, - ArrayD::::zeros(vec![3, 4, 5, 1]) - ); - assert!( - ::std::panic::catch_unwind(|| ArrayD::::zeros(vec![3, 4, 5]).insert_axis(Axis(4))) - .is_err() - ); + test_insert!(ArrayD::::zeros(vec![3, 4, 5]), 0, ArrayD::::zeros(vec![1, 3, 4, 5])); + test_insert!(ArrayD::::zeros(vec![3, 4, 5]), 1, ArrayD::::zeros(vec![3, 1, 4, 5])); + test_insert!(ArrayD::::zeros(vec![3, 4, 5]), 3, ArrayD::::zeros(vec![3, 4, 5, 1])); + assert!(::std::panic::catch_unwind(|| ArrayD::::zeros(vec![3, 4, 5]).insert_axis(Axis(4))).is_err()); } #[test] -fn insert_axis_f() -{ +fn insert_axis_f() { defmac!(test_insert_f orig, index, new => { let res = orig.insert_axis(Axis(index)); assert_eq!(res, new); assert!(res.t().is_standard_layout()); }); - test_insert_f!( - Array0::from_shape_vec(().f(), vec![1]).unwrap(), - 0, - arr1(&[1]) - ); - assert!( - ::std::panic::catch_unwind(|| Array0::from_shape_vec(().f(), vec![1]) - .unwrap() - .insert_axis(Axis(1))) - .is_err() - ); + test_insert_f!(Array0::from_shape_vec(().f(), vec![1]).unwrap(), 0, arr1(&[1])); + assert!(::std::panic::catch_unwind(|| Array0::from_shape_vec(().f(), vec![1]) + .unwrap() + .insert_axis(Axis(1))) + .is_err()); test_insert_f!(Array1::::zeros((3).f()), 0, Array2::::zeros((1, 3))); 
test_insert_f!(Array1::::zeros((3).f()), 1, Array2::::zeros((3, 1))); - assert!( - ::std::panic::catch_unwind(|| Array1::::zeros((3).f()).insert_axis(Axis(2))).is_err() - ); + assert!(::std::panic::catch_unwind(|| Array1::::zeros((3).f()).insert_axis(Axis(2))).is_err()); - test_insert_f!( - Array3::::zeros((3, 4, 5).f()), - 1, - Array4::::zeros((3, 1, 4, 5)) - ); - assert!( - ::std::panic::catch_unwind(|| Array3::::zeros((3, 4, 5).f()).insert_axis(Axis(4))) - .is_err() - ); + test_insert_f!(Array3::::zeros((3, 4, 5).f()), 1, Array4::::zeros((3, 1, 4, 5))); + assert!(::std::panic::catch_unwind(|| Array3::::zeros((3, 4, 5).f()).insert_axis(Axis(4))).is_err()); - test_insert_f!( - ArrayD::::zeros(vec![3, 4, 5].f()), - 1, - ArrayD::::zeros(vec![3, 1, 4, 5]) - ); - assert!(::std::panic::catch_unwind( - || ArrayD::::zeros(vec![3, 4, 5].f()).insert_axis(Axis(4)) - ) - .is_err()); + test_insert_f!(ArrayD::::zeros(vec![3, 4, 5].f()), 1, ArrayD::::zeros(vec![3, 1, 4, 5])); + assert!(::std::panic::catch_unwind(|| ArrayD::::zeros(vec![3, 4, 5].f()).insert_axis(Axis(4))).is_err()); } #[test] -fn insert_axis_view() -{ +fn insert_axis_view() { let a = array![[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]; - assert_eq!( - a.index_axis(Axis(1), 0).insert_axis(Axis(0)), - array![[[1, 2], [5, 6], [9, 10]]] - ); - assert_eq!( - a.index_axis(Axis(1), 0).insert_axis(Axis(1)), - array![[[1, 2]], [[5, 6]], [[9, 10]]] - ); - assert_eq!( - a.index_axis(Axis(1), 0).insert_axis(Axis(2)), - array![[[1], [2]], [[5], [6]], [[9], [10]]] - ); + assert_eq!(a.index_axis(Axis(1), 0).insert_axis(Axis(0)), array![[[1, 2], [5, 6], [9, 10]]]); + assert_eq!(a.index_axis(Axis(1), 0).insert_axis(Axis(1)), array![[[1, 2]], [[5, 6]], [[9, 10]]]); + assert_eq!(a.index_axis(Axis(1), 0).insert_axis(Axis(2)), array![[[1], [2]], [[5], [6]], [[9], [10]]]); } #[test] -fn arithmetic_broadcast() -{ +fn arithmetic_broadcast() { let mut a = arr2(&[[1., 2.], [3., 4.]]); let b = a.clone() * aview0(&1.); assert_eq!(a, b); @@ -1747,14 +1545,8 @@ fn arithmetic_broadcast() let a = arr2(&[[2], [3], [4]]); let b = arr1(&[5, 6, 7]); assert_eq!(&a + &b, arr2(&[[7, 8, 9], [8, 9, 10], [9, 10, 11]])); - assert_eq!( - a.clone() - &b, - arr2(&[[-3, -4, -5], [-2, -3, -4], [-1, -2, -3]]) - ); - assert_eq!( - a.clone() * b.clone(), - arr2(&[[10, 12, 14], [15, 18, 21], [20, 24, 28]]) - ); + assert_eq!(a.clone() - &b, arr2(&[[-3, -4, -5], [-2, -3, -4], [-1, -2, -3]])); + assert_eq!(a.clone() * b.clone(), arr2(&[[10, 12, 14], [15, 18, 21], [20, 24, 28]])); assert_eq!(&b / a, arr2(&[[2, 3, 3], [1, 2, 2], [1, 1, 1]])); // Negative strides and non-contiguous memory @@ -1765,14 +1557,8 @@ fn arithmetic_broadcast() let mut c = s.clone(); c.collapse_axis(Axis(2), 1); let c = c.slice(s![1,..;2,..]); - assert_eq!( - &a.to_owned() + &b, - arr3(&[[[11, 15], [20, 24]], [[10, 14], [19, 23]]]) - ); - assert_eq!( - &a + b.into_owned() + c, - arr3(&[[[15, 19], [32, 36]], [[14, 18], [31, 35]]]) - ); + assert_eq!(&a.to_owned() + &b, arr3(&[[[11, 15], [20, 24]], [[10, 14], [19, 23]]])); + assert_eq!(&a + b.into_owned() + c, arr3(&[[[15, 19], [32, 36]], [[14, 18], [31, 35]]])); // shared array let sa = a.to_shared(); @@ -1781,10 +1567,7 @@ fn arithmetic_broadcast() let sb2 = sb.to_shared(); let sc = c.to_shared(); let sc2 = sc.into_shared(); - assert_eq!( - sa2 + &sb2 + sc2.into_owned(), - arr3(&[[[15, 19], [32, 36]], [[14, 18], [31, 35]]]) - ); + assert_eq!(sa2 + &sb2 + sc2.into_owned(), arr3(&[[[15, 19], [32, 36]], [[14, 18], [31, 35]]])); // Same shape let a = 
s.slice(s![..;-1, ..;2, ..]); @@ -1794,8 +1577,7 @@ fn arithmetic_broadcast() } #[test] -fn char_array() -{ +fn char_array() { // test compilation & basics of non-numerical array let cc = ArcArray::from_iter("alphabet".chars()) .into_shape_with_order((4, 2)) @@ -1804,8 +1586,7 @@ fn char_array() } #[test] -fn scalar_ops() -{ +fn scalar_ops() { let a = Array::::zeros((5, 5)); let b = &a + 1; let c = (&a + &a + 2) - 3; @@ -1842,8 +1623,7 @@ fn scalar_ops() } #[test] -fn split_at() -{ +fn split_at() { let mut a = arr2(&[[1., 2.], [3., 4.]]); { @@ -1883,24 +1663,21 @@ fn split_at() #[test] #[should_panic] -fn deny_split_at_axis_out_of_bounds() -{ +fn deny_split_at_axis_out_of_bounds() { let a = arr2(&[[1., 2.], [3., 4.]]); a.view().split_at(Axis(2), 0); } #[test] #[should_panic] -fn deny_split_at_index_out_of_bounds() -{ +fn deny_split_at_index_out_of_bounds() { let a = arr2(&[[1., 2.], [3., 4.]]); a.view().split_at(Axis(1), 3); } #[test] #[cfg(feature = "std")] -fn test_range() -{ +fn test_range() { let a = Array::range(0., 5., 1.); assert_eq!(a.len(), 5); assert_eq!(a[0], 0.); @@ -1929,8 +1706,7 @@ fn test_range() } #[test] -fn test_f_order() -{ +fn test_f_order() { // Test that arrays are logically equal in every way, // even if the underlying memory order is different let c = arr2(&[[1, 2, 3], [4, 5, 6]]); @@ -1952,8 +1728,7 @@ fn test_f_order() } #[test] -fn to_owned_memory_order() -{ +fn to_owned_memory_order() { // check that .to_owned() makes f-contiguous arrays out of f-contiguous // input. let c = arr2(&[[1, 2, 3], [4, 5, 6]]); @@ -1973,8 +1748,7 @@ fn to_owned_memory_order() } #[test] -fn to_owned_neg_stride() -{ +fn to_owned_neg_stride() { let mut c = arr2(&[[1, 2, 3], [4, 5, 6]]); c.slice_collapse(s![.., ..;-1]); let co = c.to_owned(); @@ -1983,8 +1757,7 @@ fn to_owned_neg_stride() } #[test] -fn discontiguous_owned_to_owned() -{ +fn discontiguous_owned_to_owned() { let mut c = arr2(&[[1, 2, 3], [4, 5, 6]]); c.slice_collapse(s![.., ..;2]); @@ -1995,8 +1768,7 @@ fn discontiguous_owned_to_owned() } #[test] -fn map_memory_order() -{ +fn map_memory_order() { let a = arr3(&[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [0, -1, -2]]]); let mut v = a.view(); v.swap_axes(0, 1); @@ -2006,16 +1778,12 @@ fn map_memory_order() } #[test] -fn map_mut_with_unsharing() -{ +fn map_mut_with_unsharing() { // Fortran-layout `ArcArray`. let a = rcarr2(&[[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]).reversed_axes(); assert_eq!(a.shape(), &[2, 5]); assert_eq!(a.strides(), &[1, 2]); - assert_eq!( - a.as_slice_memory_order(), - Some(&[0, 5, 1, 6, 2, 7, 3, 8, 4, 9][..]) - ); + assert_eq!(a.as_slice_memory_order(), Some(&[0, 5, 1, 6, 2, 7, 3, 8, 4, 9][..])); // Shared reference of a portion of `a`. 
let mut b = a.clone().slice_move(s![.., ..2]); @@ -2034,8 +1802,7 @@ fn map_mut_with_unsharing() } #[test] -fn test_view_from_shape() -{ +fn test_view_from_shape() { let s = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; let a = ArrayView::from_shape((2, 3, 2), &s).unwrap(); let mut answer = Array::from(s.to_vec()) @@ -2058,24 +1825,21 @@ fn test_view_from_shape() } #[test] -fn test_view_from_shape_allow_overlap() -{ +fn test_view_from_shape_allow_overlap() { let data = [0, 1, 2]; let view = ArrayView::from_shape((2, 3).strides((0, 1)), &data).unwrap(); assert_eq!(view, aview2(&[data; 2])); } #[test] -fn test_view_mut_from_shape_deny_overlap() -{ +fn test_view_mut_from_shape_deny_overlap() { let mut data = [0, 1, 2]; let result = ArrayViewMut::from_shape((2, 3).strides((0, 1)), &mut data); assert_matches!(result.map_err(|e| e.kind()), Err(ErrorKind::Unsupported)); } #[test] -fn test_contiguous() -{ +fn test_contiguous() { let c = arr3(&[[[1, 2, 3], [4, 5, 6]], [[4, 5, 6], [7, 7, 7]]]); assert!(c.is_standard_layout()); assert!(c.as_slice_memory_order().is_some()); @@ -2109,8 +1873,7 @@ fn test_contiguous() } #[test] -fn test_contiguous_single_element() -{ +fn test_contiguous_single_element() { assert_matches!(array![1].as_slice_memory_order(), Some(&[1])); let arr1 = array![1, 2, 3]; @@ -2125,21 +1888,14 @@ fn test_contiguous_single_element() } #[test] -fn test_contiguous_neg_strides() -{ +fn test_contiguous_neg_strides() { let s = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]; let a = ArrayView::from_shape((2, 3, 2).strides((1, 4, 2)), &s).unwrap(); - assert_eq!( - a, - arr3(&[[[0, 2], [4, 6], [8, 10]], [[1, 3], [5, 7], [9, 11]]]) - ); + assert_eq!(a, arr3(&[[[0, 2], [4, 6], [8, 10]], [[1, 3], [5, 7], [9, 11]]])); assert!(a.as_slice_memory_order().is_some()); let mut b = a.slice(s![..;1, ..;-1, ..;-1]); - assert_eq!( - b, - arr3(&[[[10, 8], [6, 4], [2, 0]], [[11, 9], [7, 5], [3, 1]]]) - ); + assert_eq!(b, arr3(&[[[10, 8], [6, 4], [2, 0]], [[11, 9], [7, 5], [3, 1]]])); assert!(b.as_slice_memory_order().is_some()); b.swap_axes(1, 2); @@ -2151,10 +1907,7 @@ fn test_contiguous_neg_strides() assert!(b.as_slice_memory_order().is_some()); let mut c = b.reversed_axes(); - assert_eq!( - c, - arr3(&[[[11, 10], [9, 8]], [[7, 6], [5, 4]], [[3, 2], [1, 0]]]) - ); + assert_eq!(c, arr3(&[[[11, 10], [9, 8]], [[7, 6], [5, 4]], [[3, 2], [1, 0]]])); assert!(c.as_slice_memory_order().is_some()); c.merge_axes(Axis(1), Axis(2)); @@ -2184,8 +1937,7 @@ fn test_contiguous_neg_strides() } #[test] -fn test_swap() -{ +fn test_swap() { let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9]]); let b = a.clone(); @@ -2198,8 +1950,7 @@ fn test_swap() } #[test] -fn test_uswap() -{ +fn test_uswap() { let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9]]); let b = a.clone(); @@ -2212,8 +1963,7 @@ fn test_uswap() } #[test] -fn test_shape() -{ +fn test_shape() { let data = [0, 1, 2, 3, 4, 5]; let a = Array::from_shape_vec((1, 2, 3), data.to_vec()).unwrap(); let b = Array::from_shape_vec((1, 2, 3).f(), data.to_vec()).unwrap(); @@ -2227,8 +1977,7 @@ fn test_shape() } #[test] -fn test_view_from_shape_ptr() -{ +fn test_view_from_shape_ptr() { let data = [0, 1, 2, 3, 4, 5]; let view = unsafe { ArrayView::from_shape_ptr((2, 3), data.as_ptr()) }; assert_eq!(view, aview2(&[[0, 1, 2], [3, 4, 5]])); @@ -2244,8 +1993,7 @@ fn test_view_from_shape_ptr() #[should_panic(expected = "Unsupported")] #[cfg(debug_assertions)] #[test] -fn test_view_from_shape_ptr_deny_neg_strides() -{ +fn test_view_from_shape_ptr_deny_neg_strides() { let data = [0, 1, 2, 3, 
4, 5]; let _view = unsafe { ArrayView::from_shape_ptr((2, 3).strides((-3isize as usize, 1)), data.as_ptr()) }; } @@ -2253,8 +2001,7 @@ fn test_view_from_shape_ptr_deny_neg_strides() #[should_panic(expected = "Unsupported")] #[cfg(debug_assertions)] #[test] -fn test_view_mut_from_shape_ptr_deny_neg_strides() -{ +fn test_view_mut_from_shape_ptr_deny_neg_strides() { let mut data = [0, 1, 2, 3, 4, 5]; let _view = unsafe { ArrayViewMut::from_shape_ptr((2, 3).strides((-3isize as usize, 1)), data.as_mut_ptr()) }; } @@ -2262,8 +2009,7 @@ fn test_view_mut_from_shape_ptr_deny_neg_strides() #[should_panic(expected = "Unsupported")] #[cfg(debug_assertions)] #[test] -fn test_raw_view_from_shape_ptr_deny_neg_strides() -{ +fn test_raw_view_from_shape_ptr_deny_neg_strides() { let data = [0, 1, 2, 3, 4, 5]; let _view = unsafe { RawArrayView::from_shape_ptr((2, 3).strides((-3isize as usize, 1)), data.as_ptr()) }; } @@ -2271,15 +2017,13 @@ fn test_raw_view_from_shape_ptr_deny_neg_strides() #[should_panic(expected = "Unsupported")] #[cfg(debug_assertions)] #[test] -fn test_raw_view_mut_from_shape_ptr_deny_neg_strides() -{ +fn test_raw_view_mut_from_shape_ptr_deny_neg_strides() { let mut data = [0, 1, 2, 3, 4, 5]; let _view = unsafe { RawArrayViewMut::from_shape_ptr((2, 3).strides((-3isize as usize, 1)), data.as_mut_ptr()) }; } #[test] -fn test_raw_view_from_shape_allow_overlap() -{ +fn test_raw_view_from_shape_allow_overlap() { let data = [0, 1, 2]; let view; unsafe { @@ -2292,8 +2036,7 @@ fn test_raw_view_from_shape_allow_overlap() #[should_panic(expected = "strides must not allow any element")] #[cfg(debug_assertions)] #[test] -fn test_raw_view_mut_from_shape_deny_overlap() -{ +fn test_raw_view_mut_from_shape_deny_overlap() { let mut data = [0, 1, 2]; unsafe { RawArrayViewMut::from_shape_ptr((2, 3).strides((0, 1)), data.as_mut_ptr()); @@ -2301,8 +2044,7 @@ fn test_raw_view_mut_from_shape_deny_overlap() } #[test] -fn test_default() -{ +fn test_default() { let a = as Default>::default(); assert_eq!(a, aview2(&[[0.0; 0]; 0])); @@ -2313,16 +2055,14 @@ fn test_default() } #[test] -fn test_default_ixdyn() -{ +fn test_default_ixdyn() { let a = as Default>::default(); let b = >::zeros(IxDyn(&[0])); assert_eq!(a, b); } #[test] -fn test_map_axis() -{ +fn test_map_axis() { let a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]); let b = a.map_axis(Axis(0), |view| view.sum()); @@ -2355,8 +2095,7 @@ fn test_map_axis() } #[test] -fn test_accumulate_axis_inplace_noop() -{ +fn test_accumulate_axis_inplace_noop() { let mut a = Array2::::zeros((0, 3)); a.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev); assert_eq!(a, Array2::zeros((0, 3))); @@ -2398,8 +2137,7 @@ fn test_accumulate_axis_inplace_nonstandard_layout() { } #[test] -fn test_to_vec() -{ +fn test_to_vec() { let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]); a.slice_collapse(s![..;-1, ..]); @@ -2410,8 +2148,7 @@ fn test_to_vec() } #[test] -fn test_array_clone_unalias() -{ +fn test_array_clone_unalias() { let a = Array::::zeros((3, 3)); let mut b = a.clone(); b.fill(1); @@ -2420,8 +2157,7 @@ fn test_array_clone_unalias() } #[test] -fn test_array_clone_same_view() -{ +fn test_array_clone_same_view() { let mut a = Array::from_iter(0..9) .into_shape_with_order((3, 3)) .unwrap(); @@ -2431,8 +2167,7 @@ fn test_array_clone_same_view() } #[test] -fn test_array2_from_diag() -{ +fn test_array2_from_diag() { let diag = arr1(&[0, 1, 2]); let x = Array2::from_diag(&diag); let x_exp = arr2(&[[0, 0, 0], [0, 1, 0], [0, 0, 2]]); @@ -2446,8 +2181,7 
@@ fn test_array2_from_diag() } #[test] -fn array_macros() -{ +fn array_macros() { // array let a1 = array![1, 2, 3]; assert_eq!(a1, arr1(&[1, 2, 3])); @@ -2459,10 +2193,7 @@ fn array_macros() assert_eq!(a4, arr3(&[[[1, 2], [3, 4]], [[5, 6], [7, 8]]])); let s = String::from("abc"); - let a2s = array![ - [String::from("w"), s], - [String::from("x"), String::from("y")] - ]; + let a2s = array![[String::from("w"), s], [String::from("x"), String::from("y")]]; assert_eq!(a2s[[0, 0]], "w"); assert_eq!(a2s[[0, 1]], "abc"); assert_eq!(a2s[[1, 0]], "x"); @@ -2475,8 +2206,7 @@ fn array_macros() } #[cfg(test)] -mod as_standard_layout_tests -{ +mod as_standard_layout_tests { use super::*; use ndarray::Data; use std::fmt::Debug; @@ -2495,8 +2225,7 @@ mod as_standard_layout_tests } #[test] - fn test_f_layout() - { + fn test_f_layout() { let shape = (2, 2).f(); let arr = Array::::from_shape_vec(shape, vec![1, 2, 3, 4]).unwrap(); assert!(!arr.is_standard_layout()); @@ -2504,16 +2233,14 @@ mod as_standard_layout_tests } #[test] - fn test_c_layout() - { + fn test_c_layout() { let arr = Array::::from_shape_vec((2, 2), vec![1, 2, 3, 4]).unwrap(); assert!(arr.is_standard_layout()); test_as_standard_layout_for(arr); } #[test] - fn test_f_layout_view() - { + fn test_f_layout_view() { let shape = (2, 2).f(); let arr = Array::::from_shape_vec(shape, vec![1, 2, 3, 4]).unwrap(); let arr_view = arr.view(); @@ -2522,8 +2249,7 @@ mod as_standard_layout_tests } #[test] - fn test_c_layout_view() - { + fn test_c_layout_view() { let arr = Array::::from_shape_vec((2, 2), vec![1, 2, 3, 4]).unwrap(); let arr_view = arr.view(); assert!(arr_view.is_standard_layout()); @@ -2531,16 +2257,14 @@ mod as_standard_layout_tests } #[test] - fn test_zero_dimensional_array() - { + fn test_zero_dimensional_array() { let arr_view = ArrayView1::::from(&[]); assert!(arr_view.is_standard_layout()); test_as_standard_layout_for(arr_view); } #[test] - fn test_custom_layout() - { + fn test_custom_layout() { let shape = (1, 2, 3, 2).strides((12, 1, 2, 6)); let arr_data: Vec = (0..12).collect(); let arr = Array::::from_shape_vec(shape, arr_data).unwrap(); @@ -2550,13 +2274,11 @@ mod as_standard_layout_tests } #[cfg(test)] -mod array_cow_tests -{ +mod array_cow_tests { use super::*; #[test] - fn test_is_variant() - { + fn test_is_variant() { let arr: Array = array![[1, 2], [3, 4]]; let arr_cow = CowArray::::from(arr.view()); assert!(arr_cow.is_view()); @@ -2566,8 +2288,7 @@ mod array_cow_tests assert!(!arr_cow.is_view()); } - fn run_with_various_layouts(mut f: impl FnMut(Array2)) - { + fn run_with_various_layouts(mut f: impl FnMut(Array2)) { for all in [ Array2::from_shape_vec((7, 8), (0..7 * 8).collect()).unwrap(), Array2::from_shape_vec((7, 8).f(), (0..7 * 8).collect()).unwrap(), @@ -2585,8 +2306,7 @@ mod array_cow_tests } #[test] - fn test_element_mutation() - { + fn test_element_mutation() { run_with_various_layouts(|arr: Array2| { let mut expected = arr.clone(); expected[(1, 1)] = 2; @@ -2606,8 +2326,7 @@ mod array_cow_tests } #[test] - fn test_clone() - { + fn test_clone() { run_with_various_layouts(|arr: Array2| { let arr_cow = CowArray::::from(arr.view()); let arr_cow_clone = arr_cow.clone(); @@ -2627,10 +2346,8 @@ mod array_cow_tests #[cfg_attr(miri, ignore)] // Very slow on CI/CD machines #[test] - fn test_clone_from() - { - fn assert_eq_contents_and_layout(arr1: &CowArray<'_, i32, Ix2>, arr2: &CowArray<'_, i32, Ix2>) - { + fn test_clone_from() { + fn assert_eq_contents_and_layout(arr1: &CowArray<'_, i32, Ix2>, arr2: &CowArray<'_, i32, Ix2>) 
{ assert_eq!(arr1, arr2); assert_eq!(arr1.dim(), arr2.dim()); assert_eq!(arr1.strides(), arr2.strides()); @@ -2666,8 +2383,7 @@ mod array_cow_tests } #[test] - fn test_into_owned() - { + fn test_into_owned() { run_with_various_layouts(|arr: Array2| { let before = CowArray::::from(arr.view()); let after = before.into_owned(); @@ -2683,79 +2399,57 @@ mod array_cow_tests } #[test] -fn test_remove_index() -{ +fn test_remove_index() { let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]); a.remove_index(Axis(0), 1); a.remove_index(Axis(1), 2); assert_eq!(a.shape(), &[3, 2]); - assert_eq!(a, - array![[1, 2], - [7, 8], - [10,11]]); + assert_eq!(a, array![[1, 2], [7, 8], [10, 11]]); let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]); a.invert_axis(Axis(0)); a.remove_index(Axis(0), 1); a.remove_index(Axis(1), 2); assert_eq!(a.shape(), &[3, 2]); - assert_eq!(a, - array![[10,11], - [4, 5], - [1, 2]]); + assert_eq!(a, array![[10, 11], [4, 5], [1, 2]]); a.remove_index(Axis(1), 1); assert_eq!(a.shape(), &[3, 1]); - assert_eq!(a, - array![[10], - [4], - [1]]); + assert_eq!(a, array![[10], [4], [1]]); a.remove_index(Axis(1), 0); assert_eq!(a.shape(), &[3, 0]); - assert_eq!(a, - array![[], - [], - []]); + assert_eq!(a, array![[], [], []]); } #[should_panic(expected = "must be less")] #[test] -fn test_remove_index_oob1() -{ +fn test_remove_index_oob1() { let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]); a.remove_index(Axis(0), 4); } #[should_panic(expected = "must be less")] #[test] -fn test_remove_index_oob2() -{ +fn test_remove_index_oob2() { let mut a = array![[10], [4], [1]]; a.remove_index(Axis(1), 0); assert_eq!(a.shape(), &[3, 0]); - assert_eq!(a, - array![[], - [], - []]); + assert_eq!(a, array![[], [], []]); a.remove_index(Axis(0), 1); // ok - assert_eq!(a, - array![[], - []]); + assert_eq!(a, array![[], []]); a.remove_index(Axis(1), 0); // oob } #[should_panic(expected = "index out of bounds")] #[test] -fn test_remove_index_oob3() -{ +fn test_remove_index_oob3() { let mut a = array![[10], [4], [1]]; a.remove_index(Axis(2), 0); } #[test] -fn test_split_complex_view() -{ +fn test_split_complex_view() { let a = Array3::from_shape_fn((3, 4, 5), |(i, j, k)| Complex::::new(i as f32 * j as f32, k as f32)); let Complex { re, im } = a.view().split_complex(); assert_relative_eq!(re.sum(), 90.); @@ -2763,8 +2457,7 @@ fn test_split_complex_view() } #[test] -fn test_split_complex_view_roundtrip() -{ +fn test_split_complex_view_roundtrip() { let a_re = Array3::from_shape_fn((3, 1, 5), |(i, j, _k)| i * j); let a_im = Array3::from_shape_fn((3, 1, 5), |(_i, _j, k)| k); let a = Array3::from_shape_fn((3, 1, 5), |(i, j, k)| Complex::new(a_re[[i, j, k]], a_im[[i, j, k]])); @@ -2774,8 +2467,7 @@ fn test_split_complex_view_roundtrip() } #[test] -fn test_split_complex_view_mut() -{ +fn test_split_complex_view_mut() { let eye_scalar = Array2::::eye(4); let eye_complex = Array2::>::eye(4); let mut a = Array2::>::zeros((4, 4)); @@ -2786,8 +2478,7 @@ fn test_split_complex_view_mut() } #[test] -fn test_split_complex_zerod() -{ +fn test_split_complex_zerod() { let mut a = Array0::from_elem((), Complex::new(42, 32)); let Complex { re, im } = a.view().split_complex(); assert_eq!(re.get(()), Some(&42)); @@ -2798,18 +2489,16 @@ fn test_split_complex_zerod() } #[test] -fn test_split_complex_permuted() -{ +fn test_split_complex_permuted() { let a = Array3::from_shape_fn((3, 4, 5), |(i, j, k)| Complex::new(i * k + j, k)); let permuted = a.view().permuted_axes([1, 0, 2]); let Complex { 
re, im } = permuted.split_complex(); - assert_eq!(re.get((3,2,4)).unwrap(), &11); - assert_eq!(im.get((3,2,4)).unwrap(), &4); + assert_eq!(re.get((3, 2, 4)).unwrap(), &11); + assert_eq!(im.get((3, 2, 4)).unwrap(), &4); } #[test] -fn test_split_complex_invert_axis() -{ +fn test_split_complex_invert_axis() { let mut a = Array::from_shape_fn((2, 3, 2), |(i, j, k)| Complex::new(i as f64 + j as f64, i as f64 + k as f64)); a.invert_axis(Axis(1)); let cmplx = a.view().split_complex(); @@ -2818,16 +2507,14 @@ fn test_split_complex_invert_axis() } #[test] -fn test_slice_assign() -{ +fn test_slice_assign() { let mut a = array![0, 1, 2, 3, 4]; *a.slice_mut(s![1..3]) += 1; assert_eq!(a, array![0, 2, 3, 3, 4]); } #[test] -fn reverse_axes() -{ +fn reverse_axes() { let mut a = arr2(&[[1, 2], [3, 4]]); a.reverse_axes(); assert_eq!(a, arr2(&[[1, 3], [2, 4]])); @@ -2847,8 +2534,7 @@ fn reverse_axes() } #[test] -fn permute_axes() -{ +fn permute_axes() { let mut a = arr2(&[[1, 2], [3, 4]]); a.permute_axes([1, 0]); assert_eq!(a, arr2(&[[1, 3], [2, 4]])); @@ -2874,8 +2560,7 @@ fn permute_axes() #[should_panic] #[test] -fn permute_axes_repeated_axis() -{ +fn permute_axes_repeated_axis() { let mut a = Array::from_iter(0..24) .into_shape_with_order((2, 3, 4)) .unwrap(); @@ -2884,8 +2569,7 @@ fn permute_axes_repeated_axis() #[should_panic] #[test] -fn permute_axes_missing_axis() -{ +fn permute_axes_missing_axis() { let mut a = Array::from_iter(0..24) .into_shape_with_order((2, 3, 4)) .unwrap() @@ -2895,8 +2579,7 @@ fn permute_axes_missing_axis() #[should_panic] #[test] -fn permute_axes_oob() -{ +fn permute_axes_oob() { let mut a = Array::from_iter(0..24) .into_shape_with_order((2, 3, 4)) .unwrap(); diff --git a/tests/assign.rs b/tests/assign.rs index 29a6b851..8205828c 100644 --- a/tests/assign.rs +++ b/tests/assign.rs @@ -3,8 +3,7 @@ use ndarray::prelude::*; use std::sync::atomic::{AtomicUsize, Ordering}; #[test] -fn assign() -{ +fn assign() { let mut a = arr2(&[[1., 2.], [3., 4.]]); let b = arr2(&[[1., 3.], [2., 4.]]); a.assign(&b); @@ -29,8 +28,7 @@ fn assign() } #[test] -fn assign_to() -{ +fn assign_to() { let mut a = arr2(&[[1., 2.], [3., 4.]]); let b = arr2(&[[0., 3.], [2., 0.]]); b.assign_to(&mut a); @@ -38,8 +36,7 @@ fn assign_to() } #[test] -fn move_into_copy() -{ +fn move_into_copy() { let a = arr2(&[[1., 2.], [3., 4.]]); let acopy = a.clone(); let mut b = Array::uninit(a.dim()); @@ -56,8 +53,7 @@ fn move_into_copy() } #[test] -fn move_into_owned() -{ +fn move_into_owned() { // Test various memory layouts and holes while moving String elements. for &use_f_order in &[false, true] { for &invert_axis in &[0b00, 0b01, 0b10, 0b11] { @@ -87,8 +83,7 @@ fn move_into_owned() } #[test] -fn move_into_slicing() -{ +fn move_into_slicing() { // Count correct number of drops when using move_into_uninit and discontiguous arrays (with holes). for &use_f_order in &[false, true] { for &invert_axis in &[0b00, 0b01, 0b10, 0b11] { @@ -122,8 +117,7 @@ fn move_into_slicing() } #[test] -fn move_into_diag() -{ +fn move_into_diag() { // Count correct number of drops when using move_into_uninit and discontiguous arrays (with holes). for &use_f_order in &[false, true] { let counter = DropCounter::default(); @@ -148,8 +142,7 @@ fn move_into_diag() } #[test] -fn move_into_0dim() -{ +fn move_into_0dim() { // Count correct number of drops when using move_into_uninit and discontiguous arrays (with holes). 
for &use_f_order in &[false, true] { let counter = DropCounter::default(); @@ -176,8 +169,7 @@ fn move_into_0dim() } #[test] -fn move_into_empty() -{ +fn move_into_empty() { // Count correct number of drops when using move_into_uninit and discontiguous arrays (with holes). for &use_f_order in &[false, true] { let counter = DropCounter::default(); @@ -203,8 +195,7 @@ fn move_into_empty() } #[test] -fn move_into() -{ +fn move_into() { // Test various memory layouts and holes while moving String elements with move_into for &use_f_order in &[false, true] { for &invert_axis in &[0b00, 0b01, 0b10, 0b11] { @@ -235,34 +226,28 @@ fn move_into() /// This counter can create elements, and then count and verify /// the number of which have actually been dropped again. #[derive(Default)] -struct DropCounter -{ +struct DropCounter { created: AtomicUsize, dropped: AtomicUsize, } struct Element<'a>(&'a AtomicUsize); -impl DropCounter -{ - fn created(&self) -> usize - { +impl DropCounter { + fn created(&self) -> usize { self.created.load(Ordering::Relaxed) } - fn dropped(&self) -> usize - { + fn dropped(&self) -> usize { self.dropped.load(Ordering::Relaxed) } - fn element(&self) -> Element<'_> - { + fn element(&self) -> Element<'_> { self.created.fetch_add(1, Ordering::Relaxed); Element(&self.dropped) } - fn assert_drop_count(&self) - { + fn assert_drop_count(&self) { assert_eq!( self.created(), self.dropped(), @@ -273,10 +258,8 @@ impl DropCounter } } -impl<'a> Drop for Element<'a> -{ - fn drop(&mut self) - { +impl<'a> Drop for Element<'a> { + fn drop(&mut self) { self.0.fetch_add(1, Ordering::Relaxed); } } diff --git a/tests/azip.rs b/tests/azip.rs index f3618cb3..cf60241c 100644 --- a/tests/azip.rs +++ b/tests/azip.rs @@ -10,8 +10,7 @@ use itertools::{assert_equal, cloned}; use std::mem::swap; #[test] -fn test_azip1() -{ +fn test_azip1() { let mut a = Array::zeros(62); let mut x = 0; azip!((a in &mut a) { *a = x; x += 1; }); @@ -19,8 +18,7 @@ fn test_azip1() } #[test] -fn test_azip2() -{ +fn test_azip2() { let mut a = Array::zeros((5, 7)); let b = Array::from_shape_fn(a.dim(), |(i, j)| 1. / (i + 2 * j) as f32); azip!((a in &mut a, &b in &b) *a = b); @@ -28,8 +26,7 @@ fn test_azip2() } #[test] -fn test_azip2_1() -{ +fn test_azip2_1() { let mut a = Array::zeros((5, 7)); let b = Array::from_shape_fn((5, 10), |(i, j)| 1. / (i + 2 * j) as f32); let b = b.slice(s![..;-1, 3..]); @@ -38,8 +35,7 @@ fn test_azip2_1() } #[test] -fn test_azip2_3() -{ +fn test_azip2_3() { let mut b = Array::from_shape_fn((5, 10), |(i, j)| 1. / (i + 2 * j) as f32); let mut c = Array::from_shape_fn((5, 10), |(i, j)| f32::exp((i + j) as f32)); let a = b.clone(); @@ -50,8 +46,7 @@ fn test_azip2_3() #[test] #[cfg(feature = "approx")] -fn test_zip_collect() -{ +fn test_zip_collect() { use approx::assert_abs_diff_eq; // test Zip::map_collect and that it preserves c/f layout. 
@@ -79,8 +74,7 @@ fn test_zip_collect() #[test] #[cfg(feature = "approx")] -fn test_zip_assign_into() -{ +fn test_zip_assign_into() { use approx::assert_abs_diff_eq; let mut a = Array::::zeros((5, 10)); @@ -94,8 +88,7 @@ fn test_zip_assign_into() #[test] #[cfg(feature = "approx")] -fn test_zip_assign_into_cell() -{ +fn test_zip_assign_into_cell() { use approx::assert_abs_diff_eq; use std::cell::Cell; @@ -110,37 +103,30 @@ fn test_zip_assign_into_cell() } #[test] -fn test_zip_collect_drop() -{ +fn test_zip_collect_drop() { use std::cell::RefCell; use std::panic; struct Recorddrop<'a>((usize, usize), &'a RefCell>); - impl Drop for Recorddrop<'_> - { - fn drop(&mut self) - { + impl Drop for Recorddrop<'_> { + fn drop(&mut self) { self.1.borrow_mut().push(self.0); } } #[derive(Copy, Clone)] - enum Config - { + enum Config { CC, CF, FF, } - impl Config - { - fn a_is_f(self) -> bool - { + impl Config { + fn a_is_f(self) -> bool { !matches!(self, Config::CC | Config::CF) } - fn b_is_f(self) -> bool - { + fn b_is_f(self) -> bool { !matches!(self, Config::CC) } } @@ -183,8 +169,7 @@ fn test_zip_collect_drop() } #[test] -fn test_azip_syntax_trailing_comma() -{ +fn test_azip_syntax_trailing_comma() { let mut b = Array::::zeros((5, 5)); let mut c = Array::::ones((5, 5)); let a = b.clone(); @@ -195,8 +180,7 @@ fn test_azip_syntax_trailing_comma() #[test] #[cfg(feature = "approx")] -fn test_azip2_sum() -{ +fn test_azip2_sum() { use approx::assert_abs_diff_eq; let c = Array::from_shape_fn((5, 10), |(i, j)| f32::exp((i + j) as f32)); @@ -210,8 +194,7 @@ fn test_azip2_sum() #[test] #[cfg(all(feature = "approx", feature = "std"))] -fn test_azip3_slices() -{ +fn test_azip3_slices() { use approx::assert_abs_diff_eq; let mut a = [0.; 32]; @@ -231,8 +214,7 @@ fn test_azip3_slices() #[test] #[cfg(feature = "approx")] -fn test_broadcast() -{ +fn test_broadcast() { use approx::assert_abs_diff_eq; let n = 16; @@ -257,8 +239,7 @@ fn test_broadcast() #[should_panic] #[test] -fn test_zip_dim_mismatch_1() -{ +fn test_zip_dim_mismatch_1() { let mut a = Array::zeros((5, 7)); let mut d = a.raw_dim(); d[0] += 1; @@ -270,8 +251,7 @@ fn test_zip_dim_mismatch_1() // Zip::from(A).and(B) // where A is F-contiguous and B contiguous but neither F nor C contiguous. 
#[test] -fn test_contiguous_but_not_c_or_f() -{ +fn test_contiguous_but_not_c_or_f() { let a = Array::from_iter(0..27) .into_shape_with_order((3, 3, 3)) .unwrap(); @@ -297,8 +277,7 @@ fn test_contiguous_but_not_c_or_f() } #[test] -fn test_clone() -{ +fn test_clone() { let a = Array::from_iter(0..27) .into_shape_with_order((3, 3, 3)) .unwrap(); @@ -317,8 +296,7 @@ fn test_clone() } #[test] -fn test_indices_0() -{ +fn test_indices_0() { let a1 = arr0(3); let mut count = 0; @@ -331,8 +309,7 @@ fn test_indices_0() } #[test] -fn test_indices_1() -{ +fn test_indices_1() { let mut a1 = Array::default(12); for (i, elt) in a1.indexed_iter_mut() { *elt = i; @@ -362,8 +339,7 @@ fn test_indices_1() } #[test] -fn test_indices_2() -{ +fn test_indices_2() { let mut a1 = Array::default((10, 12)); for (i, elt) in a1.indexed_iter_mut() { *elt = i; @@ -393,8 +369,7 @@ fn test_indices_2() } #[test] -fn test_indices_3() -{ +fn test_indices_3() { let mut a1 = Array::default((4, 5, 6)); for (i, elt) in a1.indexed_iter_mut() { *elt = i; @@ -424,8 +399,7 @@ fn test_indices_3() } #[test] -fn test_indices_split_1() -{ +fn test_indices_split_1() { for m in (0..4).chain(10..12) { for n in (0..4).chain(10..12) { let a1 = Array::::default((m, n)); @@ -457,8 +431,7 @@ fn test_indices_split_1() } #[test] -fn test_zip_all() -{ +fn test_zip_all() { let a = Array::::zeros(62); let b = Array::::ones(62); let mut c = Array::::ones(62); @@ -469,8 +442,7 @@ fn test_zip_all() } #[test] -fn test_zip_all_empty_array() -{ +fn test_zip_all_empty_array() { let a = Array::::zeros(0); let b = Array::::ones(0); assert!(Zip::from(&a).and(&b).all(|&_x, &_y| true)); @@ -478,8 +450,7 @@ fn test_zip_all_empty_array() } #[test] -fn test_azip9() -{ +fn test_azip9() { let mut a = Array::::zeros(62); let b = Array::from_shape_fn(a.dim(), |j| j as i32); let c = Array::from_shape_fn(a.dim(), |j| (j * 2) as i32); diff --git a/tests/broadcast.rs b/tests/broadcast.rs index eda9babf..7c2c58fd 100644 --- a/tests/broadcast.rs +++ b/tests/broadcast.rs @@ -1,8 +1,7 @@ use ndarray::prelude::*; #[test] -fn broadcast_1() -{ +fn broadcast_1() { let a_dim = Dim([2, 4, 2, 2]); let b_dim = Dim([2, 1, 2, 1]); let a = Array::from_iter(0..a_dim.size()) @@ -33,8 +32,7 @@ fn broadcast_1() } #[test] -fn test_add() -{ +fn test_add() { let a_dim = Dim([2, 4, 2, 2]); let b_dim = Dim([2, 1, 2, 1]); let mut a = Array::from_iter(0..a_dim.size()) @@ -50,8 +48,7 @@ fn test_add() #[test] #[should_panic] -fn test_add_incompat() -{ +fn test_add_incompat() { let a_dim = Dim([2, 4, 2, 2]); let mut a = Array::from_iter(0..a_dim.size()) .into_shape_with_order(a_dim) @@ -61,8 +58,7 @@ fn test_add_incompat() } #[test] -fn test_broadcast() -{ +fn test_broadcast() { let (_, n, k) = (16, 16, 16); let x1 = 1.; // b0 broadcast 1 -> n, k @@ -82,8 +78,7 @@ fn test_broadcast() } #[test] -fn test_broadcast_1d() -{ +fn test_broadcast_1d() { let n = 16; let x1 = 1.; // b0 broadcast 1 -> n diff --git a/tests/clone.rs b/tests/clone.rs index 4a7e50b8..e1914ba7 100644 --- a/tests/clone.rs +++ b/tests/clone.rs @@ -1,8 +1,7 @@ use ndarray::arr2; #[test] -fn test_clone_from() -{ +fn test_clone_from() { let a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9]]); let b = arr2(&[[7, 7, 7]]); let mut c = b.clone(); diff --git a/tests/complex.rs b/tests/complex.rs index 824e296a..e12c17ed 100644 --- a/tests/complex.rs +++ b/tests/complex.rs @@ -3,14 +3,12 @@ use ndarray::{arr1, arr2, Axis}; use num_complex::Complex; use num_traits::Num; -fn c(re: T, im: T) -> Complex -{ +fn c(re: T, im: T) -> Complex { 
Complex::new(re, im) } #[test] -fn complex_mat_mul() -{ +fn complex_mat_mul() { let a = arr2(&[[c(3., 4.), c(2., 0.)], [c(0., -2.), c(3., 0.)]]); let b = (&a * c(3., 0.)).map(|c| 5. * c / c.norm_sqr()); println!("{:>8.2}", b); @@ -18,8 +16,5 @@ fn complex_mat_mul() let r = a.dot(&e); println!("{}", a); assert_eq!(r, a); - assert_eq!( - a.mean_axis(Axis(0)).unwrap(), - arr1(&[c(1.5, 1.), c(2.5, 0.)]) - ); + assert_eq!(a.mean_axis(Axis(0)).unwrap(), arr1(&[c(1.5, 1.), c(2.5, 0.)])); } diff --git a/tests/dimension.rs b/tests/dimension.rs index 53f204c6..7a0c912c 100644 --- a/tests/dimension.rs +++ b/tests/dimension.rs @@ -7,8 +7,7 @@ use ndarray::{arr2, ArcArray, Array, Axis, Dim, Dimension, IxDyn, RemoveAxis}; use std::hash::{Hash, Hasher}; #[test] -fn insert_axis() -{ +fn insert_axis() { assert_eq!(Dim([]).insert_axis(Axis(0)), Dim([1])); assert_eq!(Dim([3]).insert_axis(Axis(0)), Dim([1, 3])); @@ -20,10 +19,7 @@ fn insert_axis() assert_eq!(Dim([2, 3, 4]).insert_axis(Axis(2)), Dim([2, 3, 1, 4])); - assert_eq!( - Dim([2, 3, 4, 5, 6, 7]).insert_axis(Axis(2)), - Dim(vec![2, 3, 1, 4, 5, 6, 7]) - ); + assert_eq!(Dim([2, 3, 4, 5, 6, 7]).insert_axis(Axis(2)), Dim(vec![2, 3, 1, 4, 5, 6, 7])); assert_eq!(Dim(vec![]).insert_axis(Axis(0)), Dim(vec![1])); @@ -31,19 +27,12 @@ fn insert_axis() assert_eq!(Dim(vec![2, 3]).insert_axis(Axis(1)), Dim(vec![2, 1, 3])); assert_eq!(Dim(vec![2, 3]).insert_axis(Axis(2)), Dim(vec![2, 3, 1])); - assert_eq!( - Dim(vec![2, 3, 4, 5, 6]).insert_axis(Axis(2)), - Dim(vec![2, 3, 1, 4, 5, 6]) - ); - assert_eq!( - Dim(vec![2, 3, 4, 5, 6, 7]).insert_axis(Axis(2)), - Dim(vec![2, 3, 1, 4, 5, 6, 7]) - ); + assert_eq!(Dim(vec![2, 3, 4, 5, 6]).insert_axis(Axis(2)), Dim(vec![2, 3, 1, 4, 5, 6])); + assert_eq!(Dim(vec![2, 3, 4, 5, 6, 7]).insert_axis(Axis(2)), Dim(vec![2, 3, 1, 4, 5, 6, 7])); } #[test] -fn remove_axis() -{ +fn remove_axis() { assert_eq!(Dim([3]).remove_axis(Axis(0)), Dim([])); assert_eq!(Dim([1, 2]).remove_axis(Axis(0)), Dim([2])); assert_eq!(Dim([4, 5, 6]).remove_axis(Axis(1)), Dim([4, 6])); @@ -65,8 +54,7 @@ fn remove_axis() #[test] #[allow(clippy::eq_op)] -fn dyn_dimension() -{ +fn dyn_dimension() { let a = arr2(&[[1., 2.], [3., 4.0]]) .into_shape_with_order(vec![2, 2]) .unwrap(); @@ -82,8 +70,7 @@ fn dyn_dimension() } #[test] -fn dyn_insert() -{ +fn dyn_insert() { let mut v = vec![2, 3, 4, 5]; let mut dim = Dim(v.clone()); defmac!(test_insert index => { @@ -102,8 +89,7 @@ fn dyn_insert() } #[test] -fn dyn_remove() -{ +fn dyn_remove() { let mut v = vec![1, 2, 3, 4, 5, 6, 7]; let mut dim = Dim(v.clone()); defmac!(test_remove index => { @@ -122,8 +108,7 @@ fn dyn_remove() } #[test] -fn fastest_varying_order() -{ +fn fastest_varying_order() { let strides = Dim([2, 8, 4, 1]); let order = strides._fastest_varying_stride_order(); assert_eq!(order.slice(), &[3, 0, 2, 1]); @@ -133,31 +118,18 @@ fn fastest_varying_order() assert_eq!(order.slice(), &[3, 0, 2, 1]); assert_eq!(Dim([1, 3])._fastest_varying_stride_order(), Dim([0, 1])); - assert_eq!( - Dim([1, -3isize as usize])._fastest_varying_stride_order(), - Dim([0, 1]) - ); + assert_eq!(Dim([1, -3isize as usize])._fastest_varying_stride_order(), Dim([0, 1])); assert_eq!(Dim([7, 2])._fastest_varying_stride_order(), Dim([1, 0])); - assert_eq!( - Dim([-7isize as usize, 2])._fastest_varying_stride_order(), - Dim([1, 0]) - ); - assert_eq!( - Dim([6, 1, 3])._fastest_varying_stride_order(), - Dim([1, 2, 0]) - ); - assert_eq!( - Dim([-6isize as usize, 1, -3isize as usize])._fastest_varying_stride_order(), - Dim([1, 2, 0]) - ); + 
assert_eq!(Dim([-7isize as usize, 2])._fastest_varying_stride_order(), Dim([1, 0])); + assert_eq!(Dim([6, 1, 3])._fastest_varying_stride_order(), Dim([1, 2, 0])); + assert_eq!(Dim([-6isize as usize, 1, -3isize as usize])._fastest_varying_stride_order(), Dim([1, 2, 0])); // it's important that it produces distinct indices. Prefer the stable order // where 0 is before 1 when they are equal. assert_eq!(Dim([2, 2])._fastest_varying_stride_order(), [0, 1]); assert_eq!(Dim([2, 2, 1])._fastest_varying_stride_order(), [2, 0, 1]); assert_eq!( - Dim([-2isize as usize, -2isize as usize, 3, 1, -2isize as usize]) - ._fastest_varying_stride_order(), + Dim([-2isize as usize, -2isize as usize, 3, 1, -2isize as usize])._fastest_varying_stride_order(), [3, 0, 1, 4, 2] ); } @@ -196,8 +168,7 @@ fn min_stride_axis() { */ #[test] -fn max_stride_axis() -{ +fn max_stride_axis() { let a = ArrayF32::zeros(10); assert_eq!(a.max_stride_axis(), Axis(0)); @@ -224,8 +195,7 @@ fn max_stride_axis() } #[test] -fn test_indexing() -{ +fn test_indexing() { let mut x = Dim([1, 2]); assert_eq!(x[0], 1); @@ -236,8 +206,7 @@ fn test_indexing() } #[test] -fn test_operations() -{ +fn test_operations() { let mut x = Dim([1, 2]); let mut y = Dim([1, 1]); @@ -254,10 +223,8 @@ fn test_operations() #[test] #[allow(clippy::cognitive_complexity)] -fn test_hash() -{ - fn calc_hash(value: &T) -> u64 - { +fn test_hash() { + fn calc_hash(value: &T) -> u64 { let mut hasher = std::collections::hash_map::DefaultHasher::new(); value.hash(&mut hasher); hasher.finish() @@ -292,10 +259,8 @@ fn test_hash() } #[test] -fn test_generic_operations() -{ - fn test_dim(d: &D) - { +fn test_generic_operations() { + fn test_dim(d: &D) { let mut x = d.clone(); x[0] += 1; assert_eq!(x[0], 3); @@ -309,10 +274,8 @@ fn test_generic_operations() } #[test] -fn test_array_view() -{ - fn test_dim(d: &D) - { +fn test_array_view() { + fn test_dim(d: &D) { assert_eq!(d.as_array_view().sum(), 7); assert_eq!(d.as_array_view().strides(), &[1]); } @@ -325,8 +288,7 @@ fn test_array_view() #[test] #[cfg_attr(miri, ignore)] // Very slow on CI/CD machines #[allow(clippy::cognitive_complexity)] -fn test_all_ndindex() -{ +fn test_all_ndindex() { use ndarray::IntoDimension; macro_rules! 
ndindex { ($($i:expr),*) => { diff --git a/tests/format.rs b/tests/format.rs index 35909871..4b21fe39 100644 --- a/tests/format.rs +++ b/tests/format.rs @@ -2,8 +2,7 @@ use ndarray::prelude::*; use ndarray::rcarr1; #[test] -fn formatting() -{ +fn formatting() { let a = rcarr1::(&[1., 2., 3., 4.]); assert_eq!(format!("{}", a), "[1, 2, 3, 4]"); assert_eq!(format!("{:4}", a), "[ 1, 2, 3, 4]"); @@ -56,8 +55,7 @@ fn formatting() } #[test] -fn debug_format() -{ +fn debug_format() { let a = Array2::::zeros((3, 4)); assert_eq!( format!("{:?}", a), diff --git a/tests/higher_order_f.rs b/tests/higher_order_f.rs index 72245412..c567eb3e 100644 --- a/tests/higher_order_f.rs +++ b/tests/higher_order_f.rs @@ -2,8 +2,7 @@ use ndarray::prelude::*; #[test] #[should_panic] -fn test_fold_axis_oob() -{ +fn test_fold_axis_oob() { let a = arr2(&[[1., 2.], [3., 4.]]); a.fold_axis(Axis(2), 0., |x, y| x + y); } diff --git a/tests/indices.rs b/tests/indices.rs index a9414f9a..ca6ca988 100644 --- a/tests/indices.rs +++ b/tests/indices.rs @@ -3,8 +3,7 @@ use ndarray::prelude::*; use ndarray::Order; #[test] -fn test_ixdyn_index_iterate() -{ +fn test_ixdyn_index_iterate() { for &order in &[Order::C, Order::F] { let mut a = Array::zeros((2, 3, 4).set_f(order.is_column_major())); let dim = a.shape().to_vec(); diff --git a/tests/into-ixdyn.rs b/tests/into-ixdyn.rs index 410ce92b..d53cd59f 100644 --- a/tests/into-ixdyn.rs +++ b/tests/into-ixdyn.rs @@ -5,14 +5,12 @@ use ndarray::prelude::*; #[test] -fn test_arr0_into_dyn() -{ +fn test_arr0_into_dyn() { assert!(arr0(1.234).into_dyn()[IxDyn(&[])] == 1.234); } #[test] -fn test_arr2_into_arrd_nonstandard_strides() -{ +fn test_arr2_into_arrd_nonstandard_strides() { let arr = Array2::from_shape_fn((12, 34).f(), |(i, j)| i * 34 + j).into_dyn(); let brr = ArrayD::from_shape_fn(vec![12, 34], |d| d[0] * 34 + d[1]); diff --git a/tests/iterator_chunks.rs b/tests/iterator_chunks.rs index d4648293..e34d3bc0 100644 --- a/tests/iterator_chunks.rs +++ b/tests/iterator_chunks.rs @@ -5,8 +5,7 @@ use ndarray::prelude::*; #[test] -fn chunks() -{ +fn chunks() { use ndarray::NdProducer; let a = Array1::from_iter(0..100) .into_shape_with_order((10, 10)) @@ -45,15 +44,13 @@ fn chunks() #[should_panic] #[test] -fn chunks_different_size_1() -{ +fn chunks_different_size_1() { let a = Array::::zeros(vec![2, 3]); a.exact_chunks(vec![2]); } #[test] -fn chunks_ok_size() -{ +fn chunks_ok_size() { let mut a = Array::::zeros(vec![2, 3]); a.fill(1.); let mut c = 0; @@ -67,15 +64,13 @@ fn chunks_ok_size() #[should_panic] #[test] -fn chunks_different_size_2() -{ +fn chunks_different_size_2() { let a = Array::::zeros(vec![2, 3]); a.exact_chunks(vec![2, 3, 4]); } #[test] -fn chunks_mut() -{ +fn chunks_mut() { let mut a = Array::zeros((7, 8)); for (i, mut chunk) in a.exact_chunks_mut((2, 3)).into_iter().enumerate() { chunk.fill(i); @@ -95,8 +90,7 @@ fn chunks_mut() #[should_panic] #[test] -fn chunks_different_size_3() -{ +fn chunks_different_size_3() { let mut a = Array::::zeros(vec![2, 3]); a.exact_chunks_mut(vec![2, 3, 4]); } diff --git a/tests/iterators.rs b/tests/iterators.rs index 96b0673a..729daa10 100644 --- a/tests/iterators.rs +++ b/tests/iterators.rs @@ -22,8 +22,7 @@ macro_rules! 
assert_panics { } #[test] -fn double_ended() -{ +fn double_ended() { let a = Array::from_iter(0..8); let mut it = a.iter().cloned(); assert_eq!(it.next(), Some(0)); @@ -35,8 +34,7 @@ fn double_ended() } #[test] -fn double_ended_rows() -{ +fn double_ended_rows() { let a = ArcArray::from_iter(0..8).into_shape_clone((4, 2)).unwrap(); let mut row_it = a.rows().into_iter(); assert_equal(row_it.next_back().unwrap(), &[6, 7]); @@ -57,8 +55,7 @@ fn double_ended_rows() } #[test] -fn iter_size_hint() -{ +fn iter_size_hint() { // Check that the size hint is correctly computed let a = ArcArray::from_iter(0..24) .into_shape_with_order((2, 3, 4)) @@ -79,8 +76,7 @@ fn iter_size_hint() #[test] #[cfg(feature = "std")] -fn indexed() -{ +fn indexed() { let a = Array::from_iter(0..8); for (i, elt) in a.indexed_iter() { assert_eq!(i, *elt as usize); @@ -99,8 +95,7 @@ fn indexed() } #[test] -fn as_slice() -{ +fn as_slice() { use ndarray::Data; fn assert_slice_correct(v: &ArrayBase) @@ -147,18 +142,12 @@ fn as_slice() let a = a.into_shape_with_order((8, 1)).unwrap(); assert_slice_correct(&a); let u = a.slice(s![..;2, ..]); - println!( - "u={:?}, shape={:?}, strides={:?}", - u, - u.shape(), - u.strides() - ); + println!("u={:?}, shape={:?}, strides={:?}", u, u.shape(), u.strides()); assert!(u.as_slice().is_none()); } #[test] -fn inner_iter() -{ +fn inner_iter() { let a = ArcArray::from_iter(0..12); let a = a.into_shape_with_order((2, 3, 2)).unwrap(); // [[[0, 1], @@ -167,30 +156,35 @@ fn inner_iter() // [[6, 7], // [8, 9], // ... - assert_equal(a.rows(), vec![ + assert_equal( + a.rows(), + vec![ aview1(&[0, 1]), aview1(&[2, 3]), aview1(&[4, 5]), aview1(&[6, 7]), aview1(&[8, 9]), aview1(&[10, 11]), - ]); + ], + ); let mut b = ArcArray::zeros((2, 3, 2)); b.swap_axes(0, 2); b.assign(&a); - assert_equal(b.rows(), vec![ + assert_equal( + b.rows(), + vec![ aview1(&[0, 1]), aview1(&[2, 3]), aview1(&[4, 5]), aview1(&[6, 7]), aview1(&[8, 9]), aview1(&[10, 11]), - ]); + ], + ); } #[test] -fn inner_iter_corner_cases() -{ +fn inner_iter_corner_cases() { let a0 = ArcArray::::zeros(()); assert_equal(a0.rows(), vec![aview1(&[0])]); @@ -202,8 +196,7 @@ fn inner_iter_corner_cases() } #[test] -fn inner_iter_size_hint() -{ +fn inner_iter_size_hint() { // Check that the size hint is correctly computed let a = ArcArray::from_iter(0..24) .into_shape_with_order((2, 3, 4)) @@ -220,8 +213,7 @@ fn inner_iter_size_hint() #[allow(deprecated)] // into_outer_iter #[test] -fn outer_iter() -{ +fn outer_iter() { let a = ArcArray::from_iter(0..12); let a = a.into_shape_with_order((2, 3, 2)).unwrap(); // [[[0, 1], @@ -271,8 +263,7 @@ fn outer_iter() } #[test] -fn axis_iter() -{ +fn axis_iter() { let a = ArcArray::from_iter(0..12); let a = a.into_shape_with_order((2, 3, 2)).unwrap(); // [[[0, 1], @@ -281,16 +272,14 @@ fn axis_iter() // [[6, 7], // [8, 9], // ... 
- assert_equal(a.axis_iter(Axis(1)), vec![ - a.index_axis(Axis(1), 0), - a.index_axis(Axis(1), 1), - a.index_axis(Axis(1), 2), - ]); + assert_equal( + a.axis_iter(Axis(1)), + vec![a.index_axis(Axis(1), 0), a.index_axis(Axis(1), 1), a.index_axis(Axis(1), 2)], + ); } #[test] -fn axis_iter_split_at() -{ +fn axis_iter_split_at() { let a = Array::from_iter(0..5); let iter = a.axis_iter(Axis(0)); let all: Vec<_> = iter.clone().collect(); @@ -302,8 +291,7 @@ fn axis_iter_split_at() } #[test] -fn axis_iter_split_at_partially_consumed() -{ +fn axis_iter_split_at_partially_consumed() { let a = Array::from_iter(0..5); let mut iter = a.axis_iter(Axis(0)); while iter.next().is_some() { @@ -317,8 +305,7 @@ fn axis_iter_split_at_partially_consumed() } #[test] -fn axis_iter_zip() -{ +fn axis_iter_zip() { let a = Array::from_iter(0..5); let iter = a.axis_iter(Axis(0)); let mut b = Array::zeros(5); @@ -327,8 +314,7 @@ fn axis_iter_zip() } #[test] -fn axis_iter_zip_partially_consumed() -{ +fn axis_iter_zip_partially_consumed() { let a = Array::from_iter(0..5); let mut iter = a.axis_iter(Axis(0)); let mut consumed = 0; @@ -343,8 +329,7 @@ fn axis_iter_zip_partially_consumed() } #[test] -fn axis_iter_zip_partially_consumed_discontiguous() -{ +fn axis_iter_zip_partially_consumed_discontiguous() { let a = Array::from_iter(0..5); let mut iter = a.axis_iter(Axis(0)); let mut consumed = 0; @@ -362,8 +347,7 @@ fn axis_iter_zip_partially_consumed_discontiguous() use ndarray::ArrayView1; #[test] -fn outer_iter_corner_cases() -{ +fn outer_iter_corner_cases() { let a2 = ArcArray::::zeros((0, 3)); assert_equal(a2.outer_iter(), Vec::>::new()); @@ -373,8 +357,7 @@ fn outer_iter_corner_cases() #[allow(deprecated)] #[test] -fn outer_iter_mut() -{ +fn outer_iter_mut() { let a = ArcArray::from_iter(0..12); let a = a.into_shape_with_order((2, 3, 2)).unwrap(); // [[[0, 1], @@ -398,8 +381,7 @@ fn outer_iter_mut() } #[test] -fn axis_iter_mut() -{ +fn axis_iter_mut() { let a = ArcArray::from_iter(0..12); let a = a.into_shape_with_order((2, 3, 2)).unwrap(); // [[[0, 1], @@ -419,36 +401,44 @@ fn axis_iter_mut() } #[test] -fn axis_chunks_iter() -{ +fn axis_chunks_iter() { let a = ArcArray::from_iter(0..24); let a = a.into_shape_with_order((2, 6, 2)).unwrap(); let it = a.axis_chunks_iter(Axis(1), 2); - assert_equal(it, vec![ + assert_equal( + it, + vec![ arr3(&[[[0, 1], [2, 3]], [[12, 13], [14, 15]]]), arr3(&[[[4, 5], [6, 7]], [[16, 17], [18, 19]]]), arr3(&[[[8, 9], [10, 11]], [[20, 21], [22, 23]]]), - ]); + ], + ); let a = ArcArray::from_iter(0..28); let a = a.into_shape_with_order((2, 7, 2)).unwrap(); let it = a.axis_chunks_iter(Axis(1), 2); - assert_equal(it, vec![ + assert_equal( + it, + vec![ arr3(&[[[0, 1], [2, 3]], [[14, 15], [16, 17]]]), arr3(&[[[4, 5], [6, 7]], [[18, 19], [20, 21]]]), arr3(&[[[8, 9], [10, 11]], [[22, 23], [24, 25]]]), arr3(&[[[12, 13]], [[26, 27]]]), - ]); + ], + ); let it = a.axis_chunks_iter(Axis(1), 2).rev(); - assert_equal(it, vec![ + assert_equal( + it, + vec![ arr3(&[[[12, 13]], [[26, 27]]]), arr3(&[[[8, 9], [10, 11]], [[22, 23], [24, 25]]]), arr3(&[[[4, 5], [6, 7]], [[18, 19], [20, 21]]]), arr3(&[[[0, 1], [2, 3]], [[14, 15], [16, 17]]]), - ]); + ], + ); let it = a.axis_chunks_iter(Axis(1), 7); assert_equal(it, vec![a.view()]); @@ -458,8 +448,7 @@ fn axis_chunks_iter() } #[test] -fn axis_iter_mut_split_at() -{ +fn axis_iter_mut_split_at() { let mut a = Array::from_iter(0..5); let mut a_clone = a.clone(); let all: Vec<_> = a_clone.axis_iter_mut(Axis(0)).collect(); @@ -471,8 +460,7 @@ fn 
axis_iter_mut_split_at() } #[test] -fn axis_iter_mut_split_at_partially_consumed() -{ +fn axis_iter_mut_split_at_partially_consumed() { let mut a = Array::from_iter(0..5); for consumed in 1..=a.len() { for mid in 0..=(a.len() - consumed) { @@ -498,8 +486,7 @@ fn axis_iter_mut_split_at_partially_consumed() } #[test] -fn axis_iter_mut_zip() -{ +fn axis_iter_mut_zip() { let orig = Array::from_iter(0..5); let mut cloned = orig.clone(); let iter = cloned.axis_iter_mut(Axis(0)); @@ -513,8 +500,7 @@ fn axis_iter_mut_zip() } #[test] -fn axis_iter_mut_zip_partially_consumed() -{ +fn axis_iter_mut_zip_partially_consumed() { let mut a = Array::from_iter(0..5); for consumed in 1..=a.len() { let remaining = a.len() - consumed; @@ -529,8 +515,7 @@ fn axis_iter_mut_zip_partially_consumed() } #[test] -fn axis_iter_mut_zip_partially_consumed_discontiguous() -{ +fn axis_iter_mut_zip_partially_consumed_discontiguous() { let mut a = Array::from_iter(0..5); for consumed in 1..=a.len() { let remaining = a.len() - consumed; @@ -546,8 +531,7 @@ fn axis_iter_mut_zip_partially_consumed_discontiguous() } #[test] -fn axis_chunks_iter_corner_cases() -{ +fn axis_chunks_iter_corner_cases() { // examples provided by @bluss in PR #65 // these tests highlight corner cases of the axis_chunks_iter implementation // and enable checking if no pointer offsetting is out of bounds. However @@ -562,11 +546,7 @@ fn axis_chunks_iter_corner_cases() let it = a.axis_chunks_iter(Axis(0), 8); assert_equal(it, vec![a.view()]); let it = a.axis_chunks_iter(Axis(0), 3); - assert_equal(it, vec![ - array![[7], [6], [5]], - array![[4], [3], [2]], - array![[1], [0]], - ]); + assert_equal(it, vec![array![[7], [6], [5]], array![[4], [3], [2]], array![[1], [0]]]); let b = ArcArray::::zeros((8, 2)); let a = b.slice(s![1..;2,..]); @@ -578,8 +558,7 @@ fn axis_chunks_iter_corner_cases() } #[test] -fn axis_chunks_iter_zero_stride() -{ +fn axis_chunks_iter_zero_stride() { { // stride 0 case let b = Array::from(vec![0f32; 0]) @@ -615,22 +594,19 @@ fn axis_chunks_iter_zero_stride() #[should_panic] #[test] -fn axis_chunks_iter_zero_chunk_size() -{ +fn axis_chunks_iter_zero_chunk_size() { let a = Array::from_iter(0..5); a.axis_chunks_iter(Axis(0), 0); } #[test] -fn axis_chunks_iter_zero_axis_len() -{ +fn axis_chunks_iter_zero_axis_len() { let a = Array::from_iter(0..0); assert!(a.axis_chunks_iter(Axis(0), 5).next().is_none()); } #[test] -fn axis_chunks_iter_split_at() -{ +fn axis_chunks_iter_split_at() { let mut a = Array2::::zeros((11, 3)); a.iter_mut().enumerate().for_each(|(i, elt)| *elt = i); for source in &[ @@ -657,8 +633,7 @@ fn axis_chunks_iter_split_at() } #[test] -fn axis_chunks_iter_mut() -{ +fn axis_chunks_iter_mut() { let a = ArcArray::from_iter(0..24); let mut a = a.into_shape_with_order((2, 6, 2)).unwrap(); @@ -670,22 +645,19 @@ fn axis_chunks_iter_mut() #[should_panic] #[test] -fn axis_chunks_iter_mut_zero_chunk_size() -{ +fn axis_chunks_iter_mut_zero_chunk_size() { let mut a = Array::from_iter(0..5); a.axis_chunks_iter_mut(Axis(0), 0); } #[test] -fn axis_chunks_iter_mut_zero_axis_len() -{ +fn axis_chunks_iter_mut_zero_axis_len() { let mut a = Array::from_iter(0..0); assert!(a.axis_chunks_iter_mut(Axis(0), 5).next().is_none()); } #[test] -fn outer_iter_size_hint() -{ +fn outer_iter_size_hint() { // Check that the size hint is correctly computed let a = ArcArray::from_iter(0..24) .into_shape_with_order((4, 3, 2)) @@ -720,8 +692,7 @@ fn outer_iter_size_hint() } #[test] -fn outer_iter_split_at() -{ +fn outer_iter_split_at() { let a = 
ArcArray::from_iter(0..30) .into_shape_with_order((5, 3, 2)) .unwrap(); @@ -745,8 +716,7 @@ fn outer_iter_split_at() #[test] #[should_panic] -fn outer_iter_split_at_panics() -{ +fn outer_iter_split_at_panics() { let a = ArcArray::from_iter(0..30) .into_shape_with_order((5, 3, 2)) .unwrap(); @@ -756,8 +726,7 @@ fn outer_iter_split_at_panics() } #[test] -fn outer_iter_mut_split_at() -{ +fn outer_iter_mut_split_at() { let mut a = ArcArray::from_iter(0..30) .into_shape_with_order((5, 3, 2)) .unwrap(); @@ -779,8 +748,7 @@ fn outer_iter_mut_split_at() } #[test] -fn iterators_are_send_sync() -{ +fn iterators_are_send_sync() { // When the element type is Send + Sync, then the iterators and views // are too. fn _send_sync(_: &T) {} @@ -812,8 +780,7 @@ fn iterators_are_send_sync() #[test] #[allow(clippy::unnecessary_fold)] -fn test_fold() -{ +fn test_fold() { let mut a = Array2::::default((20, 20)); a += 1; let mut iter = a.iter(); @@ -826,8 +793,7 @@ fn test_fold() } #[test] -fn nth_back_examples() -{ +fn nth_back_examples() { let mut a: Array1 = (0..256).collect(); a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); assert_eq!(a.iter().nth_back(0), Some(&a[a.len() - 1])); @@ -840,8 +806,7 @@ fn nth_back_examples() } #[test] -fn nth_back_zero_n() -{ +fn nth_back_zero_n() { let mut a: Array1 = (0..256).collect(); a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); let mut iter1 = a.iter(); @@ -853,8 +818,7 @@ fn nth_back_zero_n() } #[test] -fn nth_back_nonzero_n() -{ +fn nth_back_nonzero_n() { let mut a: Array1 = (0..256).collect(); a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); let mut iter1 = a.iter(); @@ -870,8 +834,7 @@ fn nth_back_nonzero_n() } #[test] -fn nth_back_past_end() -{ +fn nth_back_past_end() { let mut a: Array1 = (0..256).collect(); a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); let mut iter = a.iter(); @@ -880,8 +843,7 @@ fn nth_back_past_end() } #[test] -fn nth_back_partially_consumed() -{ +fn nth_back_partially_consumed() { let mut a: Array1 = (0..256).collect(); a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); let mut iter = a.iter(); @@ -899,8 +861,7 @@ fn nth_back_partially_consumed() } #[test] -fn test_rfold() -{ +fn test_rfold() { { let mut a = Array1::::default(256); a += 1; @@ -938,24 +899,19 @@ fn test_rfold() acc.push(*elt); acc }); - assert_eq!( - Array1::from(output), - Array::from_iter((1..10).rev().map(|i| i * 2)) - ); + assert_eq!(Array1::from(output), Array::from_iter((1..10).rev().map(|i| i * 2))); } } #[test] -fn test_into_iter() -{ +fn test_into_iter() { let a = Array1::from(vec![1, 2, 3, 4]); let v = a.into_iter().collect::>(); assert_eq!(v, [1, 2, 3, 4]); } #[test] -fn test_into_iter_2d() -{ +fn test_into_iter_2d() { let a = Array1::from(vec![1, 2, 3, 4]) .into_shape_with_order((2, 2)) .unwrap(); @@ -972,8 +928,7 @@ fn test_into_iter_2d() #[cfg_attr(miri, ignore)] // Very slow on CI/CD machines #[test] -fn test_into_iter_sliced() -{ +fn test_into_iter_sliced() { let (m, n) = (4, 5); let drops = Cell::new(0); @@ -1017,25 +972,20 @@ fn test_into_iter_sliced() /// /// Compares equal by its "represented value". 
#[derive(Clone, Debug)] -struct DropCount<'a> -{ +struct DropCount<'a> { value: i32, my_drops: usize, drops: &'a Cell, } -impl PartialEq for DropCount<'_> -{ - fn eq(&self, other: &Self) -> bool - { +impl PartialEq for DropCount<'_> { + fn eq(&self, other: &Self) -> bool { self.value == other.value } } -impl<'a> DropCount<'a> -{ - fn new(value: i32, drops: &'a Cell) -> Self - { +impl<'a> DropCount<'a> { + fn new(value: i32, drops: &'a Cell) -> Self { DropCount { value, my_drops: 0, @@ -1044,10 +994,8 @@ impl<'a> DropCount<'a> } } -impl Drop for DropCount<'_> -{ - fn drop(&mut self) - { +impl Drop for DropCount<'_> { + fn drop(&mut self) { assert_eq!(self.my_drops, 0); self.my_drops += 1; self.drops.set(self.drops.get() + 1); @@ -1055,13 +1003,11 @@ impl Drop for DropCount<'_> } #[test] -fn test_impl_iter_compiles() -{ +fn test_impl_iter_compiles() { // Requires that the iterators are covariant in the element type // base case: std - fn slice_iter_non_empty_indices<'a>(array: &'a Vec<&str>) -> impl Iterator + 'a - { + fn slice_iter_non_empty_indices<'a>(array: &'a Vec<&str>) -> impl Iterator + 'a { array .iter() .enumerate() @@ -1072,8 +1018,7 @@ fn test_impl_iter_compiles() let _ = slice_iter_non_empty_indices; // ndarray case - fn array_iter_non_empty_indices<'a>(array: &'a Array<&str, Ix1>) -> impl Iterator + 'a - { + fn array_iter_non_empty_indices<'a>(array: &'a Array<&str, Ix1>) -> impl Iterator + 'a { array .iter() .enumerate() diff --git a/tests/ix0.rs b/tests/ix0.rs index 319de39e..0598cb25 100644 --- a/tests/ix0.rs +++ b/tests/ix0.rs @@ -7,8 +7,7 @@ use ndarray::Ix0; use ndarray::ShapeBuilder; #[test] -fn test_ix0() -{ +fn test_ix0() { let mut a = Array::zeros(Ix0()); assert_eq!(a[()], 0.); a[()] = 1.; @@ -27,8 +26,7 @@ fn test_ix0() } #[test] -fn test_ix0_add() -{ +fn test_ix0_add() { let mut a = Array::zeros(Ix0()); a += 1.; assert_eq!(a[()], 1.); @@ -37,8 +35,7 @@ fn test_ix0_add() } #[test] -fn test_ix0_add_add() -{ +fn test_ix0_add_add() { let mut a = Array::zeros(Ix0()); a += 1.; let mut b = Array::zeros(Ix0()); @@ -48,8 +45,7 @@ fn test_ix0_add_add() } #[test] -fn test_ix0_add_broad() -{ +fn test_ix0_add_broad() { let mut b = Array::from(vec![5., 6.]); let mut a = Array::zeros(Ix0()); a += 1.; diff --git a/tests/ixdyn.rs b/tests/ixdyn.rs index 51797514..e108884a 100644 --- a/tests/ixdyn.rs +++ b/tests/ixdyn.rs @@ -9,8 +9,7 @@ use ndarray::Order; use ndarray::ShapeBuilder; #[test] -fn test_ixdyn() -{ +fn test_ixdyn() { // check that we can use fixed size arrays for indexing let mut a = Array::zeros(vec![2, 3, 4]); a[[1, 1, 1]] = 1.; @@ -19,8 +18,7 @@ fn test_ixdyn() #[should_panic] #[test] -fn test_ixdyn_wrong_dim() -{ +fn test_ixdyn_wrong_dim() { // check that we can use but it panics at runtime, if number of axes is wrong let mut a = Array::zeros(vec![2, 3, 4]); a[[1, 1, 1]] = 1.; @@ -29,8 +27,7 @@ fn test_ixdyn_wrong_dim() } #[test] -fn test_ixdyn_out_of_bounds() -{ +fn test_ixdyn_out_of_bounds() { // check that we are out of bounds let a = Array::::zeros(vec![2, 3, 4]); let res = a.get([0, 3, 0]); @@ -38,8 +35,7 @@ fn test_ixdyn_out_of_bounds() } #[test] -fn test_ixdyn_iterate() -{ +fn test_ixdyn_iterate() { for &order in &[Order::C, Order::F] { let mut a = Array::zeros((2, 3, 4).set_f(order.is_column_major())); let dim = a.shape().to_vec(); @@ -59,8 +55,7 @@ fn test_ixdyn_iterate() } #[test] -fn test_ixdyn_index_iterate() -{ +fn test_ixdyn_index_iterate() { for &order in &[Order::C, Order::F] { let mut a = Array::zeros((2, 3, 4).set_f(order.is_column_major())); let dim 
= a.shape().to_vec(); @@ -79,8 +74,7 @@ fn test_ixdyn_index_iterate() } #[test] -fn test_ixdyn_uget() -{ +fn test_ixdyn_uget() { // check that we are out of bounds let mut a = Array::::zeros(vec![2, 3, 4]); @@ -109,8 +103,7 @@ fn test_ixdyn_uget() } #[test] -fn test_0() -{ +fn test_0() { let mut a = Array::zeros(vec![]); let z = vec![].into_dimension(); assert_eq!(a[z.clone()], 0.); @@ -130,8 +123,7 @@ fn test_0() } #[test] -fn test_0_add() -{ +fn test_0_add() { let mut a = Array::zeros(vec![]); a += 1.; assert_eq!(a[[]], 1.); @@ -140,8 +132,7 @@ fn test_0_add() } #[test] -fn test_0_add_add() -{ +fn test_0_add_add() { let mut a = Array::zeros(vec![]); a += 1.; let mut b = Array::zeros(vec![]); @@ -151,8 +142,7 @@ fn test_0_add_add() } #[test] -fn test_0_add_broad() -{ +fn test_0_add_broad() { let mut b = Array::from(vec![5., 6.]); let mut a = Array::zeros(vec![]); a += 1.; @@ -162,8 +152,7 @@ fn test_0_add_broad() } #[test] -fn test_into_dimension() -{ +fn test_into_dimension() { use ndarray::{Ix0, Ix1, Ix2, IxDyn}; let a = Array::from_iter(0..42) diff --git a/tests/numeric.rs b/tests/numeric.rs index b82a3561..adf8ef99 100644 --- a/tests/numeric.rs +++ b/tests/numeric.rs @@ -7,39 +7,33 @@ use ndarray::{arr0, arr1, arr2, array, aview1, Array, Array1, Array2, Array3, Ax use std::f64; #[test] -fn test_mean_with_nan_values() -{ +fn test_mean_with_nan_values() { let a = array![f64::NAN, 1.]; assert!(a.mean().unwrap().is_nan()); } #[test] -fn test_mean_with_empty_array_of_floats() -{ +fn test_mean_with_empty_array_of_floats() { let a: Array1 = array![]; assert!(a.mean().is_none()); } #[test] -fn test_mean_with_array_of_floats() -{ +fn test_mean_with_array_of_floats() { let a: Array1 = array![ - 0.99889651, 0.0150731, 0.28492482, 0.83819218, 0.48413156, 0.80710412, 0.41762936, - 0.22879429, 0.43997224, 0.23831807, 0.02416466, 0.6269962, 0.47420614, 0.56275487, - 0.78995021, 0.16060581, 0.64635041, 0.34876609, 0.78543249, 0.19938356, 0.34429457, - 0.88072369, 0.17638164, 0.60819363, 0.250392, 0.69912532, 0.78855523, 0.79140914, - 0.85084218, 0.31839879, 0.63381769, 0.22421048, 0.70760302, 0.99216018, 0.80199153, - 0.19239188, 0.61356023, 0.31505352, 0.06120481, 0.66417377, 0.63608897, 0.84959691, - 0.43599069, 0.77867775, 0.88267754, 0.83003623, 0.67016118, 0.67547638, 0.65220036, - 0.68043427 + 0.99889651, 0.0150731, 0.28492482, 0.83819218, 0.48413156, 0.80710412, 0.41762936, 0.22879429, 0.43997224, + 0.23831807, 0.02416466, 0.6269962, 0.47420614, 0.56275487, 0.78995021, 0.16060581, 0.64635041, 0.34876609, + 0.78543249, 0.19938356, 0.34429457, 0.88072369, 0.17638164, 0.60819363, 0.250392, 0.69912532, 0.78855523, + 0.79140914, 0.85084218, 0.31839879, 0.63381769, 0.22421048, 0.70760302, 0.99216018, 0.80199153, 0.19239188, + 0.61356023, 0.31505352, 0.06120481, 0.66417377, 0.63608897, 0.84959691, 0.43599069, 0.77867775, 0.88267754, + 0.83003623, 0.67016118, 0.67547638, 0.65220036, 0.68043427 ]; let exact_mean = 0.5475494054; assert_abs_diff_eq!(a.mean().unwrap(), exact_mean); } #[test] -fn sum_mean_prod() -{ +fn sum_mean_prod() { let a: Array2 = arr2(&[[1., 2.], [3., 4.]]); assert_eq!(a.sum_axis(Axis(0)), arr1(&[4., 6.])); assert_eq!(a.sum_axis(Axis(1)), arr1(&[3., 7.])); @@ -54,20 +48,13 @@ fn sum_mean_prod() } #[test] -fn sum_mean_prod_empty() -{ +fn sum_mean_prod_empty() { assert_eq!(Array3::::ones((2, 0, 3)).sum(), 0.); assert_eq!(Array3::::ones((2, 0, 3)).product(), 1.); assert_eq!(Array1::::ones(0).sum_axis(Axis(0)), arr0(0.)); assert_eq!(Array1::::ones(0).product_axis(Axis(0)), arr0(1.)); - 
assert_eq!( - Array3::::ones((2, 0, 3)).sum_axis(Axis(1)), - Array::zeros((2, 3)), - ); - assert_eq!( - Array3::::ones((2, 0, 3)).product_axis(Axis(1)), - Array::ones((2, 3)), - ); + assert_eq!(Array3::::ones((2, 0, 3)).sum_axis(Axis(1)), Array::zeros((2, 3)),); + assert_eq!(Array3::::ones((2, 0, 3)).product_axis(Axis(1)), Array::ones((2, 3)),); let a = Array1::::ones(0).mean_axis(Axis(0)); assert_eq!(a, None); let a = Array3::::ones((2, 0, 3)).mean_axis(Axis(1)); @@ -75,16 +62,14 @@ fn sum_mean_prod_empty() } #[test] -fn test_cumprod_1d() -{ +fn test_cumprod_1d() { let a = array![1, 2, 3, 4]; let result = a.cumprod(Axis(0)); assert_eq!(result, array![1, 2, 6, 24]); } #[test] -fn test_cumprod_2d() -{ +fn test_cumprod_2d() { let a = array![[1, 2], [3, 4]]; let result_axis0 = a.cumprod(Axis(0)); @@ -95,8 +80,7 @@ fn test_cumprod_2d() } #[test] -fn test_cumprod_3d() -{ +fn test_cumprod_3d() { let a = array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]]; let result_axis0 = a.cumprod(Axis(0)); @@ -110,8 +94,7 @@ fn test_cumprod_3d() } #[test] -fn test_cumprod_empty() -{ +fn test_cumprod_empty() { // For 2D empty array let b: Array2 = Array2::zeros((0, 0)); let result_axis0 = b.cumprod(Axis(0)); @@ -121,8 +104,7 @@ fn test_cumprod_empty() } #[test] -fn test_cumprod_1_element() -{ +fn test_cumprod_1_element() { // For 1D array with one element let a = array![5]; let result = a.cumprod(Axis(0)); @@ -138,16 +120,14 @@ fn test_cumprod_1_element() #[test] #[should_panic(expected = "axis is out of bounds for array of dimension")] -fn test_cumprod_axis_out_of_bounds() -{ +fn test_cumprod_axis_out_of_bounds() { let a = array![[1, 2], [3, 4]]; let _result = a.cumprod(Axis(2)); } #[test] #[cfg(feature = "std")] -fn var() -{ +fn var() { let a = array![1., -4.32, 1.14, 0.32]; assert_abs_diff_eq!(a.var(0.), 5.049875, epsilon = 1e-8); } @@ -155,8 +135,7 @@ fn var() #[test] #[cfg(feature = "std")] #[should_panic] -fn var_negative_ddof() -{ +fn var_negative_ddof() { let a = array![1., 2., 3.]; a.var(-1.); } @@ -164,16 +143,14 @@ fn var_negative_ddof() #[test] #[cfg(feature = "std")] #[should_panic] -fn var_too_large_ddof() -{ +fn var_too_large_ddof() { let a = array![1., 2., 3.]; a.var(4.); } #[test] #[cfg(feature = "std")] -fn var_nan_ddof() -{ +fn var_nan_ddof() { let a = Array2::::zeros((2, 3)); let v = a.var(f64::NAN); assert!(v.is_nan()); @@ -181,16 +158,14 @@ fn var_nan_ddof() #[test] #[cfg(feature = "std")] -fn var_empty_arr() -{ +fn var_empty_arr() { let a: Array1 = array![]; assert!(a.var(0.0).is_nan()); } #[test] #[cfg(feature = "std")] -fn std() -{ +fn std() { let a = array![1., -4.32, 1.14, 0.32]; assert_abs_diff_eq!(a.std(0.), 2.24719, epsilon = 1e-5); } @@ -198,8 +173,7 @@ fn std() #[test] #[cfg(feature = "std")] #[should_panic] -fn std_negative_ddof() -{ +fn std_negative_ddof() { let a = array![1., 2., 3.]; a.std(-1.); } @@ -207,16 +181,14 @@ fn std_negative_ddof() #[test] #[cfg(feature = "std")] #[should_panic] -fn std_too_large_ddof() -{ +fn std_too_large_ddof() { let a = array![1., 2., 3.]; a.std(4.); } #[test] #[cfg(feature = "std")] -fn std_nan_ddof() -{ +fn std_nan_ddof() { let a = Array2::::zeros((2, 3)); let v = a.std(f64::NAN); assert!(v.is_nan()); @@ -224,8 +196,7 @@ fn std_nan_ddof() #[test] #[cfg(feature = "std")] -fn std_empty_arr() -{ +fn std_empty_arr() { let a: Array1 = array![]; assert!(a.std(0.0).is_nan()); } @@ -233,21 +204,12 @@ fn std_empty_arr() #[test] #[cfg(feature = "approx")] #[cfg(feature = "std")] -fn var_axis() -{ +fn var_axis() { use ndarray::{aview0, aview2}; let a = array![ 
- [ - [-9.76, -0.38, 1.59, 6.23], - [-8.57, -9.27, 5.76, 6.01], - [-9.54, 5.09, 3.21, 6.56], - ], - [ - [8.23, -9.63, 3.76, -3.48], - [-5.46, 5.86, -2.81, 1.35], - [-1.08, 4.66, 8.34, -0.73], - ], + [[-9.76, -0.38, 1.59, 6.23], [-8.57, -9.27, 5.76, 6.01], [-9.54, 5.09, 3.21, 6.56],], + [[8.23, -9.63, 3.76, -3.48], [-5.46, 5.86, -2.81, 1.35], [-1.08, 4.66, 8.34, -0.73],], ]; assert_abs_diff_eq!( a.var_axis(Axis(0), 1.5), @@ -268,19 +230,12 @@ fn var_axis() ); assert_abs_diff_eq!( a.var_axis(Axis(2), 2.3), - aview2(&[ - [79.64552941, 129.09663235, 95.98929412], - [109.64952941, 43.28758824, 36.27439706], - ]), + aview2(&[[79.64552941, 129.09663235, 95.98929412], [109.64952941, 43.28758824, 36.27439706],]), epsilon = 1e-8, ); let b = array![[1.1, 2.3, 4.7]]; - assert_abs_diff_eq!( - b.var_axis(Axis(0), 0.), - aview1(&[0., 0., 0.]), - epsilon = 1e-12 - ); + assert_abs_diff_eq!(b.var_axis(Axis(0), 0.), aview1(&[0., 0., 0.]), epsilon = 1e-12); assert_abs_diff_eq!(b.var_axis(Axis(1), 0.), aview1(&[2.24]), epsilon = 1e-12); let c = array![[], []]; @@ -293,8 +248,7 @@ fn var_axis() #[test] #[cfg(feature = "approx")] #[cfg(feature = "std")] -fn std_axis() -{ +fn std_axis() { use ndarray::aview2; let a = array![ @@ -320,32 +274,18 @@ fn std_axis() ); assert_abs_diff_eq!( a.std_axis(Axis(1), 1.7), - aview2(&[ - [0.42698655, 0.48139215, 0.36874991, 0.41458724], - [0.26769097, 0.18941435, 0.30555015, 0.35118674], - ]), + aview2(&[[0.42698655, 0.48139215, 0.36874991, 0.41458724], [0.26769097, 0.18941435, 0.30555015, 0.35118674],]), epsilon = 1e-8, ); assert_abs_diff_eq!( a.std_axis(Axis(2), 2.3), - aview2(&[ - [0.41117907, 0.37130425, 0.35332388], - [0.16905862, 0.25304841, 0.39978276], - ]), + aview2(&[[0.41117907, 0.37130425, 0.35332388], [0.16905862, 0.25304841, 0.39978276],]), epsilon = 1e-8, ); let b = array![[100000., 1., 0.01]]; - assert_abs_diff_eq!( - b.std_axis(Axis(0), 0.), - aview1(&[0., 0., 0.]), - epsilon = 1e-12, - ); - assert_abs_diff_eq!( - b.std_axis(Axis(1), 0.), - aview1(&[47_140.214_021_552_77]), - epsilon = 1e-6, - ); + assert_abs_diff_eq!(b.std_axis(Axis(0), 0.), aview1(&[0., 0., 0.]), epsilon = 1e-12,); + assert_abs_diff_eq!(b.std_axis(Axis(1), 0.), aview1(&[47_140.214_021_552_77]), epsilon = 1e-6,); let c = array![[], []]; assert_eq!(c.std_axis(Axis(0), 0.), aview1(&[])); @@ -354,8 +294,7 @@ fn std_axis() #[test] #[should_panic] #[cfg(feature = "std")] -fn var_axis_negative_ddof() -{ +fn var_axis_negative_ddof() { let a = array![1., 2., 3.]; a.var_axis(Axis(0), -1.); } @@ -363,16 +302,14 @@ fn var_axis_negative_ddof() #[test] #[should_panic] #[cfg(feature = "std")] -fn var_axis_too_large_ddof() -{ +fn var_axis_too_large_ddof() { let a = array![1., 2., 3.]; a.var_axis(Axis(0), 4.); } #[test] #[cfg(feature = "std")] -fn var_axis_nan_ddof() -{ +fn var_axis_nan_ddof() { let a = Array2::::zeros((2, 3)); let v = a.var_axis(Axis(1), f64::NAN); assert_eq!(v.shape(), &[2]); @@ -381,8 +318,7 @@ fn var_axis_nan_ddof() #[test] #[cfg(feature = "std")] -fn var_axis_empty_axis() -{ +fn var_axis_empty_axis() { let a = Array2::::zeros((2, 0)); let v = a.var_axis(Axis(1), 0.); assert_eq!(v.shape(), &[2]); @@ -392,16 +328,14 @@ fn var_axis_empty_axis() #[test] #[should_panic] #[cfg(feature = "std")] -fn std_axis_bad_dof() -{ +fn std_axis_bad_dof() { let a = array![1., 2., 3.]; a.std_axis(Axis(0), 4.); } #[test] #[cfg(feature = "std")] -fn std_axis_empty_axis() -{ +fn std_axis_empty_axis() { let a = Array2::::zeros((2, 0)); let v = a.std_axis(Axis(1), 0.); assert_eq!(v.shape(), &[2]); @@ -409,69 
+343,48 @@ fn std_axis_empty_axis() } #[test] -fn diff_1d_order1() -{ +fn diff_1d_order1() { let data = array![1.0, 2.0, 4.0, 7.0]; let expected = array![1.0, 2.0, 3.0]; assert_eq!(data.diff(1, Axis(0)), expected); } #[test] -fn diff_1d_order2() -{ +fn diff_1d_order2() { let data = array![1.0, 2.0, 4.0, 7.0]; - assert_eq!( - data.diff(2, Axis(0)), - data.diff(1, Axis(0)).diff(1, Axis(0)) - ); + assert_eq!(data.diff(2, Axis(0)), data.diff(1, Axis(0)).diff(1, Axis(0))); } #[test] -fn diff_1d_order3() -{ +fn diff_1d_order3() { let data = array![1.0, 2.0, 4.0, 7.0]; - assert_eq!( - data.diff(3, Axis(0)), - data.diff(1, Axis(0)).diff(1, Axis(0)).diff(1, Axis(0)) - ); + assert_eq!(data.diff(3, Axis(0)), data.diff(1, Axis(0)).diff(1, Axis(0)).diff(1, Axis(0))); } #[test] -fn diff_2d_order1_ax0() -{ - let data = array![ - [1.0, 2.0, 4.0, 7.0], - [1.0, 3.0, 6.0, 6.0], - [1.5, 3.5, 5.5, 5.5] - ]; +fn diff_2d_order1_ax0() { + let data = array![[1.0, 2.0, 4.0, 7.0], [1.0, 3.0, 6.0, 6.0], [1.5, 3.5, 5.5, 5.5]]; let expected = array![[0.0, 1.0, 2.0, -1.0], [0.5, 0.5, -0.5, -0.5]]; assert_eq!(data.diff(1, Axis(0)), expected); } #[test] -fn diff_2d_order1_ax1() -{ - let data = array![ - [1.0, 2.0, 4.0, 7.0], - [1.0, 3.0, 6.0, 6.0], - [1.5, 3.5, 5.5, 5.5] - ]; +fn diff_2d_order1_ax1() { + let data = array![[1.0, 2.0, 4.0, 7.0], [1.0, 3.0, 6.0, 6.0], [1.5, 3.5, 5.5, 5.5]]; let expected = array![[1.0, 2.0, 3.0], [2.0, 3.0, 0.0], [2.0, 2.0, 0.0]]; assert_eq!(data.diff(1, Axis(1)), expected); } #[test] #[should_panic] -fn diff_panic_n_too_big() -{ +fn diff_panic_n_too_big() { let data = array![1.0, 2.0, 4.0, 7.0]; data.diff(10, Axis(0)); } #[test] #[should_panic] -fn diff_panic_axis_out_of_bounds() -{ +fn diff_panic_axis_out_of_bounds() { let data = array![1, 2, 4, 7]; data.diff(1, Axis(2)); } diff --git a/tests/oper.rs b/tests/oper.rs index a6d7054b..c3da2d5d 100644 --- a/tests/oper.rs +++ b/tests/oper.rs @@ -14,8 +14,7 @@ use defmac::defmac; use num_traits::Num; use num_traits::Zero; -fn test_oper(op: &str, a: &[f32], b: &[f32], c: &[f32]) -{ +fn test_oper(op: &str, a: &[f32], b: &[f32], c: &[f32]) { let aa = CowArray::from(arr1(a)); let bb = CowArray::from(arr1(b)); let cc = CowArray::from(arr1(c)); @@ -33,7 +32,8 @@ fn test_oper(op: &str, a: &[f32], b: &[f32], c: &[f32]) } fn test_oper_arr(op: &str, mut aa: CowArray, bb: CowArray, cc: CowArray) -where D: Dimension +where + D: Dimension, { match op { "+" => { @@ -70,8 +70,7 @@ where D: Dimension } #[test] -fn operations() -{ +fn operations() { test_oper("+", &[1.0, 2.0, 3.0, 4.0], &[0.0, 1.0, 2.0, 3.0], &[1.0, 3.0, 5.0, 7.0]); test_oper("-", &[1.0, 2.0, 3.0, 4.0], &[0.0, 1.0, 2.0, 3.0], &[1.0, 1.0, 1.0, 1.0]); test_oper("*", &[1.0, 2.0, 3.0, 4.0], &[0.0, 1.0, 2.0, 3.0], &[0.0, 2.0, 6.0, 12.0]); @@ -81,8 +80,7 @@ fn operations() } #[test] -fn scalar_operations() -{ +fn scalar_operations() { let a = arr0::(1.); let b = rcarr1::(&[1., 1.]); let c = rcarr2(&[[1., 1.], [1., 1.]]); @@ -128,8 +126,7 @@ where } #[test] -fn dot_product() -{ +fn dot_product() { let a = Array::from_iter((0..69).map(|x| x as f32)); let b = &a * 2. 
- 7.; let dot = 197846.; @@ -165,8 +162,7 @@ fn dot_product() // test that we can dot product with a broadcast array #[test] -fn dot_product_0() -{ +fn dot_product_0() { let a = Array::from_iter((0..69).map(|x| x as f32)); let x = 1.5; let b = aview0(&x); @@ -186,8 +182,7 @@ fn dot_product_0() } #[test] -fn dot_product_neg_stride() -{ +fn dot_product_neg_stride() { // test that we can dot with negative stride let a = Array::from_iter((0..69).map(|x| x as f32)); let b = &a * 2. - 7.; @@ -206,8 +201,7 @@ fn dot_product_neg_stride() } #[test] -fn fold_and_sum() -{ +fn fold_and_sum() { let a = Array::from_iter((0..128).map(|x| x as f32)) .into_shape_with_order((8, 16)) .unwrap(); @@ -248,8 +242,7 @@ fn fold_and_sum() } #[test] -fn product() -{ +fn product() { let step = (2. - 0.5) / 127.; let a = Array::from_iter((0..128).map(|i| 0.5 + step * (i as f64))) .into_shape_with_order((8, 16)) @@ -271,19 +264,16 @@ fn product() } } -fn range_mat(m: Ix, n: Ix) -> Array2 -{ +fn range_mat(m: Ix, n: Ix) -> Array2 { ArrayBuilder::new((m, n)).build() } #[cfg(feature = "approx")] -fn range1_mat64(m: Ix) -> Array1 -{ +fn range1_mat64(m: Ix) -> Array1 { ArrayBuilder::new(m).build() } -fn range_i32(m: Ix, n: Ix) -> Array2 -{ +fn range_i32(m: Ix, n: Ix) -> Array2 { ArrayBuilder::new((m, n)).build() } @@ -318,8 +308,7 @@ where } #[test] -fn mat_mul() -{ +fn mat_mul() { let (m, n, k) = (8, 8, 8); let a = range_mat::(m, n); let b = range_mat::(n, k); @@ -381,8 +370,7 @@ fn mat_mul() // Check that matrix multiplication of contiguous matrices returns a // matrix with the same order #[test] -fn mat_mul_order() -{ +fn mat_mul_order() { let (m, n, k) = (8, 8, 8); let a = range_mat::(m, n); let b = range_mat::(n, k); @@ -401,8 +389,7 @@ fn mat_mul_order() // test matrix multiplication shape mismatch #[test] #[should_panic] -fn mat_mul_shape_mismatch() -{ +fn mat_mul_shape_mismatch() { let (m, k, k2, n) = (8, 8, 9, 8); let a = range_mat::(m, k); let b = range_mat::(k2, n); @@ -412,8 +399,7 @@ fn mat_mul_shape_mismatch() // test matrix multiplication shape mismatch #[test] #[should_panic] -fn mat_mul_shape_mismatch_2() -{ +fn mat_mul_shape_mismatch_2() { let (m, k, k2, n) = (8, 8, 8, 8); let a = range_mat::(m, k); let b = range_mat::(k2, n); @@ -424,8 +410,7 @@ fn mat_mul_shape_mismatch_2() // Check that matrix multiplication // supports broadcast arrays. 
#[test] -fn mat_mul_broadcast() -{ +fn mat_mul_broadcast() { let (m, n, k) = (16, 16, 16); let a = range_mat::(m, n); let x1 = 1.; @@ -444,8 +429,7 @@ fn mat_mul_broadcast() // Check that matrix multiplication supports reversed axes #[test] -fn mat_mul_rev() -{ +fn mat_mul_rev() { let (m, n, k) = (16, 16, 16); let a = range_mat::(m, n); let b = range_mat::(n, k); @@ -461,8 +445,7 @@ fn mat_mul_rev() // Check that matrix multiplication supports arrays with zero rows or columns #[test] -fn mat_mut_zero_len() -{ +fn mat_mut_zero_len() { defmac!(mat_mul_zero_len range_mat_fn => { for n in 0..4 { for m in 0..4 { @@ -483,8 +466,7 @@ fn mat_mut_zero_len() } #[test] -fn scaled_add() -{ +fn scaled_add() { let a = range_mat(16, 15); let mut b = range_mat(16, 15); b.mapv_inplace(f32::exp); @@ -500,8 +482,7 @@ fn scaled_add() #[cfg(feature = "approx")] #[cfg_attr(miri, ignore)] // Very slow on CI/CD machines #[test] -fn scaled_add_2() -{ +fn scaled_add_2() { let beta = -2.3; let sizes = vec![ (4, 4, 1, 4), @@ -539,8 +520,7 @@ fn scaled_add_2() #[cfg(feature = "approx")] #[cfg_attr(miri, ignore)] // Very slow on CI/CD machines #[test] -fn scaled_add_3() -{ +fn scaled_add_3() { use approx::assert_relative_eq; use ndarray::{Slice, SliceInfo, SliceInfoElem}; use std::convert::TryFrom; @@ -567,10 +547,7 @@ fn scaled_add_3() let cslice: Vec = if n == 1 { vec![Slice::from(..).step_by(s2).into()] } else { - vec![ - Slice::from(..).step_by(s1).into(), - Slice::from(..).step_by(s2).into(), - ] + vec![Slice::from(..).step_by(s1).into(), Slice::from(..).step_by(s2).into()] }; let c = range_mat::(n, q).into_shape_with_order(cdim).unwrap(); @@ -592,8 +569,7 @@ fn scaled_add_3() #[cfg(feature = "approx")] #[cfg_attr(miri, ignore)] #[test] -fn gen_mat_mul() -{ +fn gen_mat_mul() { use core::f64; let alpha = -2.3; @@ -637,8 +613,7 @@ fn gen_mat_mul() // Test y = A x where A is f-order #[cfg(feature = "approx")] #[test] -fn gemm_64_1_f() -{ +fn gemm_64_1_f() { let a = range_mat::(64, 64).reversed_axes(); let (m, n) = a.dim(); // m x n times n x 1 == m x 1 @@ -650,8 +625,7 @@ fn gemm_64_1_f() } #[test] -fn gen_mat_mul_i32() -{ +fn gen_mat_mul_i32() { let alpha = -1; let beta = 2; let sizes = if cfg!(miri) { @@ -683,8 +657,7 @@ fn gen_mat_mul_i32() #[cfg(feature = "approx")] #[test] #[cfg_attr(miri, ignore)] // Takes too long -fn gen_mat_vec_mul() -{ +fn gen_mat_vec_mul() { use core::f64; use approx::assert_relative_eq; @@ -710,17 +683,7 @@ fn gen_mat_vec_mul() let alpha = -2.3; let beta = f64::consts::PI; - let sizes = vec![ - (4, 4), - (8, 8), - (17, 15), - (4, 17), - (17, 3), - (19, 18), - (16, 17), - (15, 16), - (67, 63), - ]; + let sizes = vec![(4, 4), (8, 8), (17, 15), (4, 17), (17, 3), (19, 18), (16, 17), (15, 16), (67, 63)]; // test different strides for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { @@ -752,8 +715,7 @@ fn gen_mat_vec_mul() #[cfg(feature = "approx")] #[cfg_attr(miri, ignore)] // Very slow on CI/CD machines #[test] -fn vec_mat_mul() -{ +fn vec_mat_mul() { use approx::assert_relative_eq; // simple, slow, correct (hopefully) mat mul @@ -774,17 +736,7 @@ fn vec_mat_mul() .unwrap() } - let sizes = vec![ - (4, 4), - (8, 8), - (17, 15), - (4, 17), - (17, 3), - (19, 18), - (16, 17), - (15, 16), - (67, 63), - ]; + let sizes = vec![(4, 4), (8, 8), (17, 15), (4, 17), (17, 3), (19, 18), (16, 17), (15, 16), (67, 63)]; // test different strides for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { @@ -813,52 +765,33 @@ fn vec_mat_mul() } #[test] -fn kron_square_f64() -{ +fn kron_square_f64() { let a 
= arr2(&[[1.0, 0.0], [0.0, 1.0]]); let b = arr2(&[[0.0, 1.0], [1.0, 0.0]]); assert_eq!( kron(&a, &b), - arr2(&[ - [0.0, 1.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 0.0] - ]), + arr2(&[[0.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0]]), ); assert_eq!( kron(&b, &a), - arr2(&[ - [0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 1.0], - [1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0] - ]), + arr2(&[[0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]), ) } #[test] -fn kron_square_i64() -{ +fn kron_square_i64() { let a = arr2(&[[1, 0], [0, 1]]); let b = arr2(&[[0, 1], [1, 0]]); - assert_eq!( - kron(&a, &b), - arr2(&[[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]), - ); + assert_eq!(kron(&a, &b), arr2(&[[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]),); - assert_eq!( - kron(&b, &a), - arr2(&[[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]]), - ) + assert_eq!(kron(&b, &a), arr2(&[[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]]),) } #[test] -fn kron_i64() -{ +fn kron_i64() { let a = arr2(&[[1, 0]]); let b = arr2(&[[0, 1], [1, 0]]); let r = arr2(&[[0, 1, 0, 0], [1, 0, 0, 0]]); diff --git a/tests/par_azip.rs b/tests/par_azip.rs index 7dd233e5..f4f3aa61 100644 --- a/tests/par_azip.rs +++ b/tests/par_azip.rs @@ -7,8 +7,7 @@ use ndarray::prelude::*; use std::sync::atomic::{AtomicUsize, Ordering}; #[test] -fn test_par_azip1() -{ +fn test_par_azip1() { let mut a = Array::zeros(62); let b = Array::from_elem(62, 42); par_azip!((a in &mut a) { *a = 42 }); @@ -16,8 +15,7 @@ fn test_par_azip1() } #[test] -fn test_par_azip2() -{ +fn test_par_azip2() { let mut a = Array::zeros((5, 7)); let b = Array::from_shape_fn(a.dim(), |(i, j)| 1. / (i + 2 * j) as f32); par_azip!((a in &mut a, &b in &b, ) *a = b ); @@ -26,8 +24,7 @@ fn test_par_azip2() #[test] #[cfg(feature = "approx")] -fn test_par_azip3() -{ +fn test_par_azip3() { use approx::assert_abs_diff_eq; let mut a = [0.; 32]; @@ -47,8 +44,7 @@ fn test_par_azip3() #[should_panic] #[test] -fn test_zip_dim_mismatch_1() -{ +fn test_zip_dim_mismatch_1() { let mut a = Array::zeros((5, 7)); let mut d = a.raw_dim(); d[0] += 1; @@ -57,8 +53,7 @@ fn test_zip_dim_mismatch_1() } #[test] -fn test_indices_1() -{ +fn test_indices_1() { let mut a1 = Array::default(12); for (i, elt) in a1.indexed_iter_mut() { *elt = i; @@ -73,8 +68,7 @@ fn test_indices_1() } #[test] -fn test_par_azip9() -{ +fn test_par_azip9() { let mut a = Array::::zeros(62); let b = Array::from_shape_fn(a.dim(), |j| j as i32); let c = Array::from_shape_fn(a.dim(), |j| (j * 2) as i32); diff --git a/tests/par_rayon.rs b/tests/par_rayon.rs index 1b6b2b79..2cfb33e2 100644 --- a/tests/par_rayon.rs +++ b/tests/par_rayon.rs @@ -9,8 +9,7 @@ const CHUNK_SIZE: usize = 100; const N_CHUNKS: usize = (M + CHUNK_SIZE - 1) / CHUNK_SIZE; #[test] -fn test_axis_iter() -{ +fn test_axis_iter() { let mut a = Array2::::zeros((M, N)); for (i, mut v) in a.axis_iter_mut(Axis(0)).enumerate() { v.fill(i as _); @@ -23,8 +22,7 @@ fn test_axis_iter() #[test] #[cfg(feature = "approx")] -fn test_axis_iter_mut() -{ +fn test_axis_iter_mut() { use approx::assert_abs_diff_eq; let mut a = Array::linspace(0.0..=1.0f64, M * N) .into_shape_with_order((M, N)) @@ -38,8 +36,7 @@ fn test_axis_iter_mut() } #[test] -fn test_regular_iter() -{ +fn test_regular_iter() { let mut a = Array2::::zeros((M, N)); for (i, mut v) in a.axis_iter_mut(Axis(0)).enumerate() { v.fill(i as _); @@ -50,8 +47,7 @@ fn test_regular_iter() } #[test] -fn 
test_regular_iter_collect() -{ +fn test_regular_iter_collect() { let mut a = Array2::::zeros((M, N)); for (i, mut v) in a.axis_iter_mut(Axis(0)).enumerate() { v.fill(i as _); @@ -61,8 +57,7 @@ fn test_regular_iter_collect() } #[test] -fn test_axis_chunks_iter() -{ +fn test_axis_chunks_iter() { let mut a = Array2::::zeros((M, N)); for (i, mut v) in a.axis_chunks_iter_mut(Axis(0), CHUNK_SIZE).enumerate() { v.fill(i as _); @@ -79,8 +74,7 @@ fn test_axis_chunks_iter() #[test] #[cfg(feature = "approx")] -fn test_axis_chunks_iter_mut() -{ +fn test_axis_chunks_iter_mut() { use approx::assert_abs_diff_eq; let mut a = Array::linspace(0.0..=1.0f64, M * N) .into_shape_with_order((M, N)) diff --git a/tests/par_zip.rs b/tests/par_zip.rs index 9f10d9fd..f9780ceb 100644 --- a/tests/par_zip.rs +++ b/tests/par_zip.rs @@ -8,16 +8,14 @@ const M: usize = 1024 * 10; const N: usize = 100; #[test] -fn test_zip_1() -{ +fn test_zip_1() { let mut a = Array2::::zeros((M, N)); Zip::from(&mut a).par_for_each(|x| *x = x.exp()); } #[test] -fn test_zip_index_1() -{ +fn test_zip_index_1() { let mut a = Array2::default((10, 10)); Zip::indexed(&mut a).par_for_each(|i, x| { @@ -30,8 +28,7 @@ fn test_zip_index_1() } #[test] -fn test_zip_index_2() -{ +fn test_zip_index_2() { let mut a = Array2::default((M, N)); Zip::indexed(&mut a).par_for_each(|i, x| { @@ -44,8 +41,7 @@ fn test_zip_index_2() } #[test] -fn test_zip_index_3() -{ +fn test_zip_index_3() { let mut a = Array::default((1, 2, 1, 2, 3)); Zip::indexed(&mut a).par_for_each(|i, x| { @@ -58,8 +54,7 @@ fn test_zip_index_3() } #[test] -fn test_zip_index_4() -{ +fn test_zip_index_4() { let mut a = Array2::zeros((M, N)); let mut b = Array2::zeros((M, N)); @@ -80,8 +75,7 @@ fn test_zip_index_4() #[test] #[cfg(feature = "approx")] -fn test_zip_collect() -{ +fn test_zip_collect() { use approx::assert_abs_diff_eq; // test Zip::map_collect and that it preserves c/f layout. 
@@ -109,8 +103,7 @@ fn test_zip_collect() #[test] #[cfg(feature = "approx")] -fn test_zip_small_collect() -{ +fn test_zip_small_collect() { use approx::assert_abs_diff_eq; for m in 0..32 { @@ -125,8 +118,7 @@ fn test_zip_small_collect() assert_abs_diff_eq!(a, &b + &c, epsilon = 1e-6); if m > 1 && n > 1 { - assert_eq!(a.strides(), b.strides(), - "Failure for {}x{} c/f: {:?}", m, n, is_f); + assert_eq!(a.strides(), b.strides(), "Failure for {}x{} c/f: {:?}", m, n, is_f); } } } @@ -136,8 +128,7 @@ fn test_zip_small_collect() #[test] #[cfg(feature = "approx")] -fn test_zip_assign_into() -{ +fn test_zip_assign_into() { use approx::assert_abs_diff_eq; let mut a = Array::::zeros((M, N)); diff --git a/tests/raw_views.rs b/tests/raw_views.rs index 929e969d..bb39547e 100644 --- a/tests/raw_views.rs +++ b/tests/raw_views.rs @@ -4,8 +4,7 @@ use ndarray::Zip; use std::cell::Cell; #[test] -fn raw_view_cast_cell() -{ +fn raw_view_cast_cell() { // Test .cast() by creating an ArrayView> let mut a = Array::from_shape_fn((10, 5), |(i, j)| (i * j) as f32); @@ -21,8 +20,7 @@ fn raw_view_cast_cell() } #[test] -fn raw_view_cast_reinterpret() -{ +fn raw_view_cast_reinterpret() { // Test .cast() by reinterpreting u16 as [u8; 2] let a = Array::from_shape_fn((5, 5).f(), |(i, j)| (i as u16) << 8 | j as u16); let answer = a.mapv(u16::to_ne_bytes); @@ -33,8 +31,7 @@ fn raw_view_cast_reinterpret() } #[test] -fn raw_view_cast_zst() -{ +fn raw_view_cast_zst() { struct Zst; let a = Array::<(), _>::default((250, 250)); @@ -45,16 +42,14 @@ fn raw_view_cast_zst() #[test] #[should_panic] -fn raw_view_invalid_size_cast() -{ +fn raw_view_invalid_size_cast() { let data = [0i32; 16]; ArrayView::from(&data[..]).raw_view().cast::(); } #[test] #[should_panic] -fn raw_view_mut_invalid_size_cast() -{ +fn raw_view_mut_invalid_size_cast() { let mut data = [0i32; 16]; ArrayViewMut::from(&mut data[..]) .raw_view_mut() @@ -62,8 +57,7 @@ fn raw_view_mut_invalid_size_cast() } #[test] -fn raw_view_misaligned() -{ +fn raw_view_misaligned() { let data: [u16; 2] = [0x0011, 0x2233]; let ptr: *const u16 = data.as_ptr(); unsafe { @@ -75,10 +69,8 @@ fn raw_view_misaligned() #[test] #[cfg(debug_assertions)] #[should_panic = "The pointer must be aligned."] -fn raw_view_deref_into_view_misaligned() -{ - fn misaligned_deref(data: &[u16; 2]) -> ArrayView1<'_, u16> - { +fn raw_view_deref_into_view_misaligned() { + fn misaligned_deref(data: &[u16; 2]) -> ArrayView1<'_, u16> { let ptr: *const u16 = data.as_ptr(); unsafe { let misaligned_ptr = (ptr as *const u8).add(1) as *const u16; @@ -93,10 +85,8 @@ fn raw_view_deref_into_view_misaligned() #[test] #[cfg(debug_assertions)] #[should_panic = "Unsupported"] -fn raw_view_negative_strides() -{ - fn misaligned_deref(data: &[u16; 2]) -> ArrayView1<'_, u16> - { +fn raw_view_negative_strides() { + fn misaligned_deref(data: &[u16; 2]) -> ArrayView1<'_, u16> { let ptr: *const u16 = data.as_ptr(); unsafe { let raw_view = RawArrayView::from_shape_ptr(1.strides((-1isize) as usize), ptr); diff --git a/tests/reserve.rs b/tests/reserve.rs index 10862001..2a7efc4b 100644 --- a/tests/reserve.rs +++ b/tests/reserve.rs @@ -1,13 +1,11 @@ use ndarray::prelude::*; -fn into_raw_vec_capacity(a: Array) -> usize -{ +fn into_raw_vec_capacity(a: Array) -> usize { a.into_raw_vec_and_offset().0.capacity() } #[test] -fn reserve_1d() -{ +fn reserve_1d() { let mut a = Array1::::zeros((4,)); a.reserve(Axis(0), 1000).unwrap(); assert_eq!(a.shape(), &[4]); @@ -15,8 +13,7 @@ fn reserve_1d() } #[test] -fn reserve_3d() -{ +fn reserve_3d() { let 
mut a = Array3::::zeros((0, 4, 8)); a.reserve(Axis(0), 10).unwrap(); assert_eq!(a.shape(), &[0, 4, 8]); @@ -24,23 +21,20 @@ fn reserve_3d() } #[test] -fn reserve_empty_3d() -{ +fn reserve_empty_3d() { let mut a = Array3::::zeros((0, 0, 0)); a.reserve(Axis(0), 10).unwrap(); } #[test] -fn reserve_3d_axis1() -{ +fn reserve_3d_axis1() { let mut a = Array3::::zeros((2, 4, 8)); a.reserve(Axis(1), 10).unwrap(); assert!(into_raw_vec_capacity(a) >= 2 * 8 * 10); } #[test] -fn reserve_3d_repeat() -{ +fn reserve_3d_repeat() { let mut a = Array3::::zeros((2, 4, 8)); a.reserve(Axis(1), 10).unwrap(); a.reserve(Axis(2), 30).unwrap(); @@ -48,8 +42,7 @@ fn reserve_3d_repeat() } #[test] -fn reserve_2d_with_data() -{ +fn reserve_2d_with_data() { let mut a = array![[1, 2], [3, 4], [5, 6]]; a.reserve(Axis(1), 100).unwrap(); assert_eq!(a, array![[1, 2], [3, 4], [5, 6]]); @@ -57,8 +50,7 @@ fn reserve_2d_with_data() } #[test] -fn reserve_2d_inverted_with_data() -{ +fn reserve_2d_inverted_with_data() { let mut a = array![[1, 2], [3, 4], [5, 6]]; a.invert_axis(Axis(1)); assert_eq!(a, array![[2, 1], [4, 3], [6, 5]]); diff --git a/tests/reshape.rs b/tests/reshape.rs index a13a5c05..d44107ba 100644 --- a/tests/reshape.rs +++ b/tests/reshape.rs @@ -5,8 +5,7 @@ use itertools::enumerate; use ndarray::Order; #[test] -fn reshape() -{ +fn reshape() { let data = [1, 2, 3, 4, 5, 6, 7, 8]; let v = aview1(&data); let u = v.into_shape_with_order((3, 3)); @@ -22,8 +21,7 @@ fn reshape() #[test] #[should_panic(expected = "IncompatibleShape")] -fn reshape_error1() -{ +fn reshape_error1() { let data = [1, 2, 3, 4, 5, 6, 7, 8]; let v = aview1(&data); let _u = v.into_shape_with_order((2, 5)).unwrap(); @@ -31,8 +29,7 @@ fn reshape_error1() #[test] #[should_panic(expected = "IncompatibleLayout")] -fn reshape_error2() -{ +fn reshape_error2() { let data = [1, 2, 3, 4, 5, 6, 7, 8]; let v = aview1(&data); let mut u = v.into_shape_with_order((2, 2, 2)).unwrap(); @@ -41,8 +38,7 @@ fn reshape_error2() } #[test] -fn reshape_f() -{ +fn reshape_f() { let mut u = Array::zeros((3, 4).f()); for (i, elt) in enumerate(u.as_slice_memory_order_mut().unwrap()) { *elt = i as i32; @@ -67,8 +63,7 @@ fn reshape_f() } #[test] -fn to_shape_easy() -{ +fn to_shape_easy() { // 1D -> C -> C let data = [1, 2, 3, 4, 5, 6, 7, 8]; let v = aview1(&data); @@ -107,8 +102,7 @@ fn to_shape_easy() } #[test] -fn to_shape_copy() -{ +fn to_shape_copy() { // 1D -> C -> F let v = ArrayView::from(&[1, 2, 3, 4, 5, 6, 7, 8]); let u = v.to_shape(((4, 2), Order::RowMajor)).unwrap(); @@ -131,20 +125,21 @@ fn to_shape_copy() } #[test] -fn to_shape_add_axis() -{ +fn to_shape_add_axis() { // 1D -> C -> C let data = [1, 2, 3, 4, 5, 6, 7, 8]; let v = aview1(&data); let u = v.to_shape(((4, 2), Order::RowMajor)).unwrap(); assert!(u.to_shape(((1, 4, 2), Order::RowMajor)).unwrap().is_view()); - assert!(u.to_shape(((1, 4, 2), Order::ColumnMajor)).unwrap().is_view()); + assert!(u + .to_shape(((1, 4, 2), Order::ColumnMajor)) + .unwrap() + .is_view()); } #[test] -fn to_shape_copy_stride() -{ +fn to_shape_copy_stride() { let v = array![[1, 2, 3, 4], [5, 6, 7, 8]]; let vs = v.slice(s![.., ..3]); let lin1 = vs.to_shape(6).unwrap(); @@ -157,8 +152,7 @@ fn to_shape_copy_stride() } #[test] -fn to_shape_zero_len() -{ +fn to_shape_zero_len() { let v = array![[1, 2, 3, 4], [5, 6, 7, 8]]; let vs = v.slice(s![.., ..0]); let lin1 = vs.to_shape(0).unwrap(); @@ -168,8 +162,7 @@ fn to_shape_zero_len() #[test] #[should_panic(expected = "IncompatibleShape")] -fn to_shape_error1() -{ +fn to_shape_error1() { let data 
= [1, 2, 3, 4, 5, 6, 7, 8]; let v = aview1(&data); let _u = v.to_shape((2, 5)).unwrap(); @@ -177,8 +170,7 @@ fn to_shape_error1() #[test] #[should_panic(expected = "IncompatibleShape")] -fn to_shape_error2() -{ +fn to_shape_error2() { // overflow let data = [3, 4, 5, 6, 7, 8]; let v = aview1(&data); @@ -186,8 +178,7 @@ fn to_shape_error2() } #[test] -fn to_shape_discontig() -{ +fn to_shape_discontig() { for &create_order in &[Order::C, Order::F] { let a = Array::from_iter(0..64); let mut a1 = a.to_shape(((4, 4, 4), create_order)).unwrap(); @@ -202,11 +193,21 @@ fn to_shape_discontig() let v1 = a1.to_shape(((4, 2, 4), order)).unwrap(); assert!(v1.is_view()); let v1 = a1.to_shape(((8, 4), order)).unwrap(); - assert_eq!(v1.is_view(), order == create_order && create_order == Order::C, - "failed for {:?}, {:?}", create_order, order); + assert_eq!( + v1.is_view(), + order == create_order && create_order == Order::C, + "failed for {:?}, {:?}", + create_order, + order + ); let v1 = a1.to_shape(((4, 8), order)).unwrap(); - assert_eq!(v1.is_view(), order == create_order && create_order == Order::F, - "failed for {:?}, {:?}", create_order, order); + assert_eq!( + v1.is_view(), + order == create_order && create_order == Order::F, + "failed for {:?}, {:?}", + create_order, + order + ); let v1 = a1.to_shape((32, order)).unwrap(); assert!(!v1.is_view()); } @@ -214,8 +215,7 @@ fn to_shape_discontig() } #[test] -fn to_shape_broadcast() -{ +fn to_shape_broadcast() { for &create_order in &[Order::C, Order::F] { let a = Array::from_iter(0..64); let mut a1 = a.to_shape(((4, 4, 4), create_order)).unwrap(); @@ -225,13 +225,24 @@ fn to_shape_broadcast() for &order in &[Order::C, Order::F] { let v2 = v1.to_shape(((2, 2, 2, 2, 2, 2), order)).unwrap(); - assert_eq!(v2.strides(), match (create_order, order) { - (Order::C, Order::C) => { &[32, 16, 0, 0, 2, 1] } - (Order::C, Order::F) => { &[16, 32, 0, 0, 1, 2] } - (Order::F, Order::C) => { &[2, 1, 0, 0, 32, 16] } - (Order::F, Order::F) => { &[1, 2, 0, 0, 16, 32] } - _other => unreachable!() - }); + assert_eq!( + v2.strides(), + match (create_order, order) { + (Order::C, Order::C) => { + &[32, 16, 0, 0, 2, 1] + } + (Order::C, Order::F) => { + &[16, 32, 0, 0, 1, 2] + } + (Order::F, Order::C) => { + &[2, 1, 0, 0, 32, 16] + } + (Order::F, Order::F) => { + &[1, 2, 0, 0, 16, 32] + } + _other => unreachable!(), + } + ); let v2 = v1.to_shape(((4, 4, 4), order)).unwrap(); assert!(v2.is_view()); @@ -242,8 +253,7 @@ fn to_shape_broadcast() } #[test] -fn into_shape_with_order() -{ +fn into_shape_with_order() { // 1D -> C -> C let data = [1, 2, 3, 4, 5, 6, 7, 8]; let v = aview1(&data); @@ -282,8 +292,7 @@ fn into_shape_with_order() } #[test] -fn into_shape_clone() -{ +fn into_shape_clone() { // 1D -> C -> C { let data = [1, 2, 3, 4, 5, 6, 7, 8]; diff --git a/tests/s.rs b/tests/s.rs index 27e009eb..d523ac6f 100644 --- a/tests/s.rs +++ b/tests/s.rs @@ -3,8 +3,7 @@ use ndarray::{s, Array}; #[test] -fn test_s() -{ +fn test_s() { let a = Array::::zeros((3, 4)); let vi = a.slice(s![1.., ..;2]); assert_eq!(vi.shape(), &[2, 2]); diff --git a/tests/stacking.rs b/tests/stacking.rs index bdfe478b..1be67e5e 100644 --- a/tests/stacking.rs +++ b/tests/stacking.rs @@ -1,24 +1,19 @@ use ndarray::{arr2, arr3, aview1, aview2, concatenate, stack, Array2, Axis, ErrorKind, Ix1}; #[test] -fn concatenating() -{ +fn concatenating() { let a = arr2(&[[2., 2.], [3., 3.]]); let b = ndarray::concatenate(Axis(0), &[a.view(), a.view()]).unwrap(); assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]])); 
let c = concatenate![Axis(0), a, b]; - assert_eq!( - c, - arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.], [2., 2.], [3., 3.]]) - ); + assert_eq!(c, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.], [2., 2.], [3., 3.]])); let d = concatenate![Axis(0), a.row(0), &[9., 9.]]; assert_eq!(d, aview1(&[2., 2., 9., 9.])); let d = concatenate![Axis(1), a.row(0).insert_axis(Axis(1)), aview1(&[9., 9.]).insert_axis(Axis(1))]; - assert_eq!(d, aview2(&[[2., 9.], - [2., 9.]])); + assert_eq!(d, aview2(&[[2., 9.], [2., 9.]])); let d = concatenate![Axis(0), a.row(0).insert_axis(Axis(1)), aview1(&[9., 9.]).insert_axis(Axis(1))]; assert_eq!(d, aview2(&[[2.], [2.], [9.], [9.]])); @@ -34,8 +29,7 @@ fn concatenating() } #[test] -fn stacking() -{ +fn stacking() { let a = arr2(&[[2., 2.], [3., 3.]]); let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap(); assert_eq!(b, arr3(&[[[2., 2.], [3., 3.]], [[2., 2.], [3., 3.]]])); diff --git a/tests/variance.rs b/tests/variance.rs index e72805ff..ed117b6a 100644 --- a/tests/variance.rs +++ b/tests/variance.rs @@ -1,13 +1,11 @@ use ndarray::{Array1, ArrayView1}; -fn arrayview_covariant<'a: 'b, 'b>(x: ArrayView1<'a, f64>) -> ArrayView1<'b, f64> -{ +fn arrayview_covariant<'a: 'b, 'b>(x: ArrayView1<'a, f64>) -> ArrayView1<'b, f64> { x } #[test] -fn test_covariance() -{ +fn test_covariance() { let x = Array1::zeros(2); let shorter_view = arrayview_covariant(x.view()); assert_eq!(shorter_view[0], 0.0); diff --git a/tests/views.rs b/tests/views.rs index 02970b1b..ecef72fe 100644 --- a/tests/views.rs +++ b/tests/views.rs @@ -2,8 +2,7 @@ use ndarray::prelude::*; use ndarray::Zip; #[test] -fn cell_view() -{ +fn cell_view() { let mut a = Array::from_shape_fn((10, 5), |(i, j)| (i * j) as f32); let answer = &a + 1.; diff --git a/tests/windows.rs b/tests/windows.rs index 7d0f3699..58357188 100644 --- a/tests/windows.rs +++ b/tests/windows.rs @@ -20,8 +20,7 @@ use ndarray::{arr3, Zip}; /// Test that verifies the `Windows` iterator panics on window sizes equal to zero. #[test] #[should_panic] -fn windows_iterator_zero_size() -{ +fn windows_iterator_zero_size() { let a = Array::from_iter(10..37) .into_shape_with_order((3, 3, 3)) .unwrap(); @@ -30,8 +29,7 @@ fn windows_iterator_zero_size() /// Test that verifies that no windows are yielded on oversized window sizes. #[test] -fn windows_iterator_oversized() -{ +fn windows_iterator_oversized() { let a = Array::from_iter(10..37) .into_shape_with_order((3, 3, 3)) .unwrap(); @@ -41,10 +39,11 @@ fn windows_iterator_oversized() /// Simple test for iterating 1d-arrays via `Windows`. #[test] -fn windows_iterator_1d() -{ +fn windows_iterator_1d() { let a = Array::from_iter(10..20).into_shape_with_order(10).unwrap(); - itertools::assert_equal(a.windows(Dim(4)), vec![ + itertools::assert_equal( + a.windows(Dim(4)), + vec![ arr1(&[10, 11, 12, 13]), arr1(&[11, 12, 13, 14]), arr1(&[12, 13, 14, 15]), @@ -52,17 +51,19 @@ fn windows_iterator_1d() arr1(&[14, 15, 16, 17]), arr1(&[15, 16, 17, 18]), arr1(&[16, 17, 18, 19]), - ]); + ], + ); } /// Simple test for iterating 2d-arrays via `Windows`. 
 #[test]
-fn windows_iterator_2d()
-{
+fn windows_iterator_2d() {
     let a = Array::from_iter(10..30)
         .into_shape_with_order((5, 4))
         .unwrap();
-    itertools::assert_equal(a.windows(Dim((3, 2))), vec![
+    itertools::assert_equal(
+        a.windows(Dim((3, 2))),
+        vec![
             arr2(&[[10, 11], [14, 15], [18, 19]]),
             arr2(&[[11, 12], [15, 16], [19, 20]]),
             arr2(&[[12, 13], [16, 17], [20, 21]]),
@@ -72,17 +73,19 @@ fn windows_iterator_2d()
             arr2(&[[18, 19], [22, 23], [26, 27]]),
             arr2(&[[19, 20], [23, 24], [27, 28]]),
             arr2(&[[20, 21], [24, 25], [28, 29]]),
-    ]);
+        ],
+    );
 }
 
 /// Simple test for iterating 3d-arrays via `Windows`.
 #[test]
-fn windows_iterator_3d()
-{
+fn windows_iterator_3d() {
     let a = Array::from_iter(10..37)
         .into_shape_with_order((3, 3, 3))
         .unwrap();
-    itertools::assert_equal(a.windows(Dim((2, 2, 2))), vec![
+    itertools::assert_equal(
+        a.windows(Dim((2, 2, 2))),
+        vec![
             arr3(&[[[10, 11], [13, 14]], [[19, 20], [22, 23]]]),
             arr3(&[[[11, 12], [14, 15]], [[20, 21], [23, 24]]]),
             arr3(&[[[13, 14], [16, 17]], [[22, 23], [25, 26]]]),
@@ -91,14 +94,14 @@ fn windows_iterator_3d()
             arr3(&[[[20, 21], [23, 24]], [[29, 30], [32, 33]]]),
             arr3(&[[[22, 23], [25, 26]], [[31, 32], [34, 35]]]),
             arr3(&[[[23, 24], [26, 27]], [[32, 33], [35, 36]]]),
-    ]);
+        ],
+    );
 }
 
 /// Test that verifies the `Windows` iterator panics when stride has an axis equal to zero.
 #[test]
 #[should_panic]
-fn windows_iterator_stride_axis_zero()
-{
+fn windows_iterator_stride_axis_zero() {
     let a = Array::from_iter(10..37)
         .into_shape_with_order((3, 3, 3))
         .unwrap();
@@ -107,8 +110,7 @@ fn windows_iterator_stride_axis_zero()
 
 /// Test that verifies that only first window is yielded when stride is oversized on every axis.
 #[test]
-fn windows_iterator_only_one_valid_window_for_oversized_stride()
-{
+fn windows_iterator_only_one_valid_window_for_oversized_stride() {
     let a = Array::from_iter(10..135)
         .into_shape_with_order((5, 5, 5))
         .unwrap();
@@ -118,42 +120,42 @@ fn windows_iterator_only_one_valid_window_for_oversized_stride()
 
 /// Simple test for iterating 1d-arrays via `Windows` with stride.
 #[test]
-fn windows_iterator_1d_with_stride()
-{
+fn windows_iterator_1d_with_stride() {
     let a = Array::from_iter(10..20).into_shape_with_order(10).unwrap();
-    itertools::assert_equal(a.windows_with_stride(4, 2), vec![
-        arr1(&[10, 11, 12, 13]),
-        arr1(&[12, 13, 14, 15]),
-        arr1(&[14, 15, 16, 17]),
-        arr1(&[16, 17, 18, 19]),
-    ]);
+    itertools::assert_equal(
+        a.windows_with_stride(4, 2),
+        vec![arr1(&[10, 11, 12, 13]), arr1(&[12, 13, 14, 15]), arr1(&[14, 15, 16, 17]), arr1(&[16, 17, 18, 19])],
+    );
 }
 
 /// Simple test for iterating 2d-arrays via `Windows` with stride.
 #[test]
-fn windows_iterator_2d_with_stride()
-{
+fn windows_iterator_2d_with_stride() {
     let a = Array::from_iter(10..30)
         .into_shape_with_order((5, 4))
         .unwrap();
-    itertools::assert_equal(a.windows_with_stride((3, 2), (2, 1)), vec![
+    itertools::assert_equal(
+        a.windows_with_stride((3, 2), (2, 1)),
+        vec![
             arr2(&[[10, 11], [14, 15], [18, 19]]),
             arr2(&[[11, 12], [15, 16], [19, 20]]),
             arr2(&[[12, 13], [16, 17], [20, 21]]),
             arr2(&[[18, 19], [22, 23], [26, 27]]),
             arr2(&[[19, 20], [23, 24], [27, 28]]),
             arr2(&[[20, 21], [24, 25], [28, 29]]),
-    ]);
+        ],
+    );
 }
 
 /// Simple test for iterating 3d-arrays via `Windows` with stride.
 #[test]
-fn windows_iterator_3d_with_stride()
-{
+fn windows_iterator_3d_with_stride() {
     let a = Array::from_iter(10..74)
         .into_shape_with_order((4, 4, 4))
         .unwrap();
-    itertools::assert_equal(a.windows_with_stride((2, 2, 2), (2, 2, 2)), vec![
+    itertools::assert_equal(
+        a.windows_with_stride((2, 2, 2), (2, 2, 2)),
+        vec![
             arr3(&[[[10, 11], [14, 15]], [[26, 27], [30, 31]]]),
             arr3(&[[[12, 13], [16, 17]], [[28, 29], [32, 33]]]),
             arr3(&[[[18, 19], [22, 23]], [[34, 35], [38, 39]]]),
@@ -162,12 +164,12 @@ fn windows_iterator_3d_with_stride()
             arr3(&[[[44, 45], [48, 49]], [[60, 61], [64, 65]]]),
             arr3(&[[[50, 51], [54, 55]], [[66, 67], [70, 71]]]),
             arr3(&[[[52, 53], [56, 57]], [[68, 69], [72, 73]]]),
-    ]);
+        ],
+    );
 }
 
 #[test]
-fn test_window_zip()
-{
+fn test_window_zip() {
     let a = Array::from_iter(0..64)
         .into_shape_with_order((4, 4, 4))
         .unwrap();
@@ -192,8 +194,7 @@ fn test_window_zip()
 /// Test verifies that non existent Axis results in panic
 #[test]
 #[should_panic]
-fn axis_windows_outofbound()
-{
+fn axis_windows_outofbound() {
     let a = Array::from_iter(10..37)
         .into_shape_with_order((3, 3, 3))
         .unwrap();
@@ -203,8 +204,7 @@ fn axis_windows_outofbound()
 /// Test verifies that zero sizes results in panic
 #[test]
 #[should_panic]
-fn axis_windows_zero_size()
-{
+fn axis_windows_zero_size() {
     let a = Array::from_iter(10..37)
         .into_shape_with_order((3, 3, 3))
         .unwrap();
@@ -213,8 +213,7 @@ fn axis_windows_zero_size()
 
 /// Test verifies that over sized windows yield nothing
 #[test]
-fn axis_windows_oversized()
-{
+fn axis_windows_oversized() {
     let a = Array::from_iter(10..37)
         .into_shape_with_order((3, 3, 3))
         .unwrap();
@@ -224,61 +223,58 @@ fn axis_windows_oversized()
 
 /// Simple test for iterating 1d-arrays via `Axis Windows`.
 #[test]
-fn test_axis_windows_1d()
-{
+fn test_axis_windows_1d() {
     let a = Array::from_iter(10..20).into_shape_with_order(10).unwrap();
-    itertools::assert_equal(a.axis_windows(Axis(0), 5), vec![
+    itertools::assert_equal(
+        a.axis_windows(Axis(0), 5),
+        vec![
             arr1(&[10, 11, 12, 13, 14]),
             arr1(&[11, 12, 13, 14, 15]),
             arr1(&[12, 13, 14, 15, 16]),
             arr1(&[13, 14, 15, 16, 17]),
             arr1(&[14, 15, 16, 17, 18]),
             arr1(&[15, 16, 17, 18, 19]),
-    ]);
+        ],
+    );
 }
 
 /// Simple test for iterating 2d-arrays via `Axis Windows`.
 #[test]
-fn test_axis_windows_2d()
-{
+fn test_axis_windows_2d() {
     let a = Array::from_iter(10..30)
         .into_shape_with_order((5, 4))
         .unwrap();
-    itertools::assert_equal(a.axis_windows(Axis(0), 2), vec![
+    itertools::assert_equal(
+        a.axis_windows(Axis(0), 2),
+        vec![
             arr2(&[[10, 11, 12, 13], [14, 15, 16, 17]]),
             arr2(&[[14, 15, 16, 17], [18, 19, 20, 21]]),
             arr2(&[[18, 19, 20, 21], [22, 23, 24, 25]]),
             arr2(&[[22, 23, 24, 25], [26, 27, 28, 29]]),
-    ]);
+        ],
+    );
 }
 
 /// Simple test for iterating 3d-arrays via `Axis Windows`.
 #[test]
-fn test_axis_windows_3d()
-{
+fn test_axis_windows_3d() {
     let a = Array::from_iter(0..27)
         .into_shape_with_order((3, 3, 3))
         .unwrap();
-    itertools::assert_equal(a.axis_windows(Axis(1), 2), vec![
-        arr3(&[
-            [[0, 1, 2], [3, 4, 5]],
-            [[9, 10, 11], [12, 13, 14]],
-            [[18, 19, 20], [21, 22, 23]],
-        ]),
-        arr3(&[
-            [[3, 4, 5], [6, 7, 8]],
-            [[12, 13, 14], [15, 16, 17]],
-            [[21, 22, 23], [24, 25, 26]],
-        ]),
-    ]);
+    itertools::assert_equal(
+        a.axis_windows(Axis(1), 2),
+        vec![
+            arr3(&[[[0, 1, 2], [3, 4, 5]], [[9, 10, 11], [12, 13, 14]], [[18, 19, 20], [21, 22, 23]]]),
+            arr3(&[[[3, 4, 5], [6, 7, 8]], [[12, 13, 14], [15, 16, 17]], [[21, 22, 23], [24, 25, 26]]]),
+        ],
+    );
 }
 
 #[test]
-fn tests_axis_windows_3d_zips_with_1d()
-{
+fn tests_axis_windows_3d_zips_with_1d() {
     let a = Array::from_iter(0..27)
         .into_shape_with_order((3, 3, 3))
         .unwrap();
@@ -289,14 +285,13 @@ fn tests_axis_windows_3d_zips_with_1d()
         .for_each(|b, a| {
             *b = a.sum();
         });
-    assert_eq!(b,arr1(&[207, 261]));
+    assert_eq!(b, arr1(&[207, 261]));
 }
 
 /// Test verifies that non existent Axis results in panic
 #[test]
 #[should_panic]
-fn axis_windows_with_stride_outofbound()
-{
+fn axis_windows_with_stride_outofbound() {
     let a = Array::from_iter(10..37)
         .into_shape_with_order((3, 3, 3))
         .unwrap();
@@ -306,8 +301,7 @@ fn axis_windows_with_stride_outofbound()
 /// Test verifies that zero sizes results in panic
 #[test]
 #[should_panic]
-fn axis_windows_with_stride_zero_size()
-{
+fn axis_windows_with_stride_zero_size() {
     let a = Array::from_iter(10..37)
         .into_shape_with_order((3, 3, 3))
         .unwrap();
@@ -317,8 +311,7 @@ fn axis_windows_with_stride_zero_size()
 /// Test verifies that zero stride results in panic
 #[test]
 #[should_panic]
-fn axis_windows_with_stride_zero_stride()
-{
+fn axis_windows_with_stride_zero_stride() {
     let a = Array::from_iter(10..37)
         .into_shape_with_order((3, 3, 3))
         .unwrap();
@@ -327,8 +320,7 @@ fn axis_windows_with_stride_zero_stride()
 
 /// Test verifies that over sized windows yield nothing
 #[test]
-fn axis_windows_with_stride_oversized()
-{
+fn axis_windows_with_stride_oversized() {
     let a = Array::from_iter(10..37)
         .into_shape_with_order((3, 3, 3))
         .unwrap();
@@ -338,81 +330,71 @@ fn axis_windows_with_stride_oversized()
 
 /// Simple test for iterating 1d-arrays via `Axis Windows`.
 #[test]
-fn test_axis_windows_with_stride_1d()
-{
+fn test_axis_windows_with_stride_1d() {
     let a = Array::from_iter(10..20).into_shape_with_order(10).unwrap();
-    itertools::assert_equal(a.axis_windows_with_stride(Axis(0), 5, 2), vec![
-        arr1(&[10, 11, 12, 13, 14]),
-        arr1(&[12, 13, 14, 15, 16]),
-        arr1(&[14, 15, 16, 17, 18]),
-    ]);
+    itertools::assert_equal(
+        a.axis_windows_with_stride(Axis(0), 5, 2),
+        vec![arr1(&[10, 11, 12, 13, 14]), arr1(&[12, 13, 14, 15, 16]), arr1(&[14, 15, 16, 17, 18])],
+    );
 
-    itertools::assert_equal(a.axis_windows_with_stride(Axis(0), 5, 3), vec![
-        arr1(&[10, 11, 12, 13, 14]),
-        arr1(&[13, 14, 15, 16, 17]),
-    ]);
+    itertools::assert_equal(
+        a.axis_windows_with_stride(Axis(0), 5, 3),
+        vec![arr1(&[10, 11, 12, 13, 14]), arr1(&[13, 14, 15, 16, 17])],
+    );
 }
 
 /// Simple test for iterating 2d-arrays via `Axis Windows`.
 #[test]
-fn test_axis_windows_with_stride_2d()
-{
+fn test_axis_windows_with_stride_2d() {
     let a = Array::from_iter(10..30)
         .into_shape_with_order((5, 4))
         .unwrap();
-    itertools::assert_equal(a.axis_windows_with_stride(Axis(0), 2, 1), vec![
-        arr2(&[[10, 11, 12, 13], [14, 15, 16, 17]]),
-        arr2(&[[14, 15, 16, 17], [18, 19, 20, 21]]),
-        arr2(&[[18, 19, 20, 21], [22, 23, 24, 25]]),
-        arr2(&[[22, 23, 24, 25], [26, 27, 28, 29]]),
-    ]);
-
-    itertools::assert_equal(a.axis_windows_with_stride(Axis(0), 2, 2), vec![
-        arr2(&[[10, 11, 12, 13], [14, 15, 16, 17]]),
-        arr2(&[[18, 19, 20, 21], [22, 23, 24, 25]]),
-    ]);
-
-    itertools::assert_equal(a.axis_windows_with_stride(Axis(0), 2, 3), vec![
-        arr2(&[[10, 11, 12, 13], [14, 15, 16, 17]]),
-        arr2(&[[22, 23, 24, 25], [26, 27, 28, 29]]),
-    ]);
+    itertools::assert_equal(
+        a.axis_windows_with_stride(Axis(0), 2, 1),
+        vec![
+            arr2(&[[10, 11, 12, 13], [14, 15, 16, 17]]),
+            arr2(&[[14, 15, 16, 17], [18, 19, 20, 21]]),
+            arr2(&[[18, 19, 20, 21], [22, 23, 24, 25]]),
+            arr2(&[[22, 23, 24, 25], [26, 27, 28, 29]]),
+        ],
+    );
+
+    itertools::assert_equal(
+        a.axis_windows_with_stride(Axis(0), 2, 2),
+        vec![arr2(&[[10, 11, 12, 13], [14, 15, 16, 17]]), arr2(&[[18, 19, 20, 21], [22, 23, 24, 25]])],
+    );
+
+    itertools::assert_equal(
+        a.axis_windows_with_stride(Axis(0), 2, 3),
+        vec![arr2(&[[10, 11, 12, 13], [14, 15, 16, 17]]), arr2(&[[22, 23, 24, 25], [26, 27, 28, 29]])],
+    );
 }
 
 /// Simple test for iterating 3d-arrays via `Axis Windows`.
 #[test]
-fn test_axis_windows_with_stride_3d()
-{
+fn test_axis_windows_with_stride_3d() {
     let a = Array::from_iter(0..27)
         .into_shape_with_order((3, 3, 3))
         .unwrap();
-    itertools::assert_equal(a.axis_windows_with_stride(Axis(1), 2, 1), vec![
-        arr3(&[
-            [[0, 1, 2], [3, 4, 5]],
-            [[9, 10, 11], [12, 13, 14]],
-            [[18, 19, 20], [21, 22, 23]],
-        ]),
-        arr3(&[
-            [[3, 4, 5], [6, 7, 8]],
-            [[12, 13, 14], [15, 16, 17]],
-            [[21, 22, 23], [24, 25, 26]],
-        ]),
-    ]);
-
-    itertools::assert_equal(a.axis_windows_with_stride(Axis(1), 2, 2), vec![
-        arr3(&[
-            [[0, 1, 2], [3, 4, 5]],
-            [[9, 10, 11], [12, 13, 14]],
-            [[18, 19, 20], [21, 22, 23]],
-        ]),
-    ]);
+    itertools::assert_equal(
+        a.axis_windows_with_stride(Axis(1), 2, 1),
+        vec![
+            arr3(&[[[0, 1, 2], [3, 4, 5]], [[9, 10, 11], [12, 13, 14]], [[18, 19, 20], [21, 22, 23]]]),
+            arr3(&[[[3, 4, 5], [6, 7, 8]], [[12, 13, 14], [15, 16, 17]], [[21, 22, 23], [24, 25, 26]]]),
+        ],
+    );
+
+    itertools::assert_equal(
+        a.axis_windows_with_stride(Axis(1), 2, 2),
+        vec![arr3(&[[[0, 1, 2], [3, 4, 5]], [[9, 10, 11], [12, 13, 14]], [[18, 19, 20], [21, 22, 23]]])],
+    );
 }
 
 #[test]
-fn tests_axis_windows_with_stride_3d_zips_with_1d()
-{
+fn tests_axis_windows_with_stride_3d_zips_with_1d() {
     let a = Array::from_iter(0..27)
         .into_shape_with_order((3, 3, 3))
         .unwrap();
@@ -424,19 +406,18 @@ fn tests_axis_windows_with_stride_3d_zips_with_1d()
         .for_each(|b, a| {
             *b = a.sum();
         });
-    assert_eq!(b1,arr1(&[207, 261]));
+    assert_eq!(b1, arr1(&[207, 261]));
 
     Zip::from(b2.view_mut())
         .and(a.axis_windows_with_stride(Axis(1), 2, 2))
         .for_each(|b, a| {
             *b = a.sum();
         });
-    assert_eq!(b2,arr1(&[207]));
+    assert_eq!(b2, arr1(&[207]));
 }
 
 #[test]
-fn test_window_neg_stride()
-{
+fn test_window_neg_stride() {
     let array = Array::from_iter(1..10)
         .into_shape_with_order((3, 3))
         .unwrap();
@@ -466,26 +447,31 @@
 }
 
 #[test]
-fn test_windows_with_stride_on_inverted_axis()
-{
+fn test_windows_with_stride_on_inverted_axis() {
     let mut array = Array::from_iter(1..17)
         .into_shape_with_order((4, 4))
         .unwrap();
 
     // inverting axis results in negative stride
     array.invert_axis(Axis(0));
-    itertools::assert_equal(array.windows_with_stride((2, 2), (2, 2)), vec![
+    itertools::assert_equal(
+        array.windows_with_stride((2, 2), (2, 2)),
+        vec![
             arr2(&[[13, 14], [9, 10]]),
             arr2(&[[15, 16], [11, 12]]),
             arr2(&[[5, 6], [1, 2]]),
             arr2(&[[7, 8], [3, 4]]),
-    ]);
+        ],
+    );
 
     array.invert_axis(Axis(1));
-    itertools::assert_equal(array.windows_with_stride((2, 2), (2, 2)), vec![
+    itertools::assert_equal(
+        array.windows_with_stride((2, 2), (2, 2)),
+        vec![
            arr2(&[[16, 15], [12, 11]]),
            arr2(&[[14, 13], [10, 9]]),
            arr2(&[[8, 7], [4, 3]]),
            arr2(&[[6, 5], [2, 1]]),
-    ]);
+        ],
+    );
 }
diff --git a/tests/zst.rs b/tests/zst.rs
index f5f2c8e3..c3c779d2 100644
--- a/tests/zst.rs
+++ b/tests/zst.rs
@@ -2,8 +2,7 @@ use ndarray::arr2;
 use ndarray::ArcArray;
 
 #[test]
-fn test_swap()
-{
+fn test_swap() {
     let mut a = arr2(&[[(); 3]; 3]);
     let b = a.clone();
 
@@ -17,8 +16,7 @@ fn test_swap()
 }
 
 #[test]
-fn test()
-{
+fn test() {
     let c = ArcArray::<(), _>::default((3, 4));
     let mut d = c.clone();
     for _ in d.iter_mut() {}
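Note, not part of the patch: the hunks above only re-wrap existing test assertions (brace placement and argument layout), so behaviour is unchanged. For reference, a minimal standalone sketch of the strided-window iteration those tests assert on, assuming only the public ndarray API already exercised above (Array::from_iter, windows_with_stride, arr1):

    use ndarray::{arr1, Array};

    fn main() {
        // 1-D array [10, 11, ..., 19], as in windows_iterator_1d_with_stride.
        let a = Array::from_iter(10..20);
        // Windows of length 4 that advance by 2 elements along the axis.
        let mut windows = a.windows_with_stride(4, 2).into_iter();
        assert_eq!(windows.next().unwrap(), arr1(&[10, 11, 12, 13]));
        assert_eq!(windows.next().unwrap(), arr1(&[12, 13, 14, 15]));
        // ...continuing to [14, 15, 16, 17] and [16, 17, 18, 19]; a stride larger
        // than the array yields only the first window, as
        // windows_iterator_only_one_valid_window_for_oversized_stride checks.
    }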