From 4bb45f82a27c4fc0aafd08f14468b73a8b9fa4a2 Mon Sep 17 00:00:00 2001 From: Dikshant <121669947+pingu-73@users.noreply.github.com> Date: Sun, 1 Sep 2024 22:05:42 +0530 Subject: [PATCH] Bindings for hpx/algorithm.hpp (#7) * added hpx::copy * added hpx::copy_n * added hpx::copy_if * added hpx::count, hpx::count_if * rerun tests * added hpx::ends_with * added hpx::equal * added hpx::fill but it only works for 1D vector * added hpx::find * added hpx::sort * added hpx::sort along with comparator closure as an argument * added hpx::merge * added hpx::partial_sort * moved wrappers from tests to main * removed redundant copy from copy, sort * removed redundant copy from copy_n * removed redundant copy from find, fill, sort_comp * removed redundant copy from merge, partial sort * removed redundant copy from copy_if * rerun tests --------- Signed-off-by: Dikshant --- Cargo.toml | 2 +- hpx-sys/include/wrapper.h | 135 ++++++++++- hpx-sys/src/lib.rs | 466 +++++++++++++++++++++++++++++++++++++- 3 files changed, 580 insertions(+), 23 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8cf8c32..2ecd021 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "hpx-rs" version = "0.1.0" -authors = ["Shreyas Atre ", "Dikshant ", "Dikshant "] edition = "2021" readme = "README.md" repository = "https://github.com/STEllAR-GROUP/hpx-rs" diff --git a/hpx-sys/include/wrapper.h b/hpx-sys/include/wrapper.h index 91cf395..6889219 100644 --- a/hpx-sys/include/wrapper.h +++ b/hpx-sys/include/wrapper.h @@ -1,23 +1,13 @@ #pragma once #include +#include #include #include #include #include "rust/cxx.h" - -/*inline std::int32_t start() { return hpx::start(nullptr, 0, nullptr); }*/ - -/*inline std::int32_t start(rust::Fn rust_fn, int argc, char **argv) {*/ -/* return hpx::start(*/ -/* [&](int argc, char **argv) {*/ -/* return rust_fn(argc, argv);*/ -/* },*/ -/* argc, argv);*/ -/*}*/ - inline std::int32_t init(rust::Fn rust_fn, int argc, char **argv) { return hpx::init( [&](int argc, char **argv) { @@ -44,4 +34,125 @@ inline std::int32_t disconnect_with_timeout(double shutdown_timeout, double loca inline std::int32_t finalize() { return hpx::finalize(); } -/*inline std::int32_t stop() { return hpx::stop(); }*/ +inline void hpx_copy(rust::Slice src, rust::Slice dest) { + hpx::copy(hpx::execution::par, src.begin(), src.end(), dest.begin()); +} + +inline void hpx_copy_n(rust::Slice src, size_t count, rust::Slice dest) { + hpx::copy_n(hpx::execution::par, src.begin(), count, dest.begin()); +} + +inline void hpx_copy_if(const rust::Vec& src, rust::Vec& dest, + rust::Fn pred) { + std::vector cpp_dest(src.size()); + + auto result = hpx::copy_if(hpx::execution::par, + src.begin(), src.end(), + cpp_dest.begin(), + [&](int32_t value) { return pred(value); }); + + cpp_dest.resize(std::distance(cpp_dest.begin(), result)); + + dest.clear(); + dest.reserve(cpp_dest.size()); + for (const auto& item : cpp_dest) { + dest.push_back(item); + } +} + +inline std::int64_t hpx_count(const rust::Vec& vec, int32_t value) { + return hpx::count(hpx::execution::par, vec.begin(), vec.end(), value); +} + + +inline int64_t hpx_count_if(const rust::Vec& vec, rust::Fn pred) { + std::vector cpp_vec(vec.begin(), vec.end()); + + auto result = hpx::count_if(hpx::execution::par, + cpp_vec.begin(), + cpp_vec.end(), + [&](int32_t value) { return pred(value); }); + + return static_cast(result); +} + +inline bool hpx_ends_with(rust::Slice src, + rust::Slice dest) { + return hpx::ends_with(hpx::execution::par, + src.begin(), src.end(), + dest.begin(), dest.end(), + std::equal_to()); +} + +inline bool hpx_equal(rust::Slice src, rust::Slice dest) { + return hpx::equal( + hpx::execution::par, + src.begin(), src.end(), + dest.begin(), dest.end() + ); +} + +inline void hpx_fill(rust::Slice src, int32_t value) { + hpx::fill(hpx::execution::par, src.begin(), src.end(), value); +} + +inline int64_t hpx_find(rust::Slice src, int32_t value) { + auto result = hpx::find(hpx::execution::par, + src.begin(), + src.end(), + value); + + if (result != src.end()) { + return static_cast(std::distance(src.begin(), result)); + } + return -1; +} + +inline void hpx_sort(rust::Slice src) { + hpx::sort(hpx::execution::par, src.begin(), src.end()); +} + +inline void hpx_sort_comp(rust::Vec& src, rust::Fn comp) { + hpx::sort(hpx::execution::par, src.begin(), src.end(), + [&](int32_t a, int32_t b) { return comp(a, b); }); +} + +inline void hpx_merge(rust::Slice src1, + rust::Slice src2, + rust::Vec& dest) { + dest.clear(); + dest.reserve(src1.size() + src2.size()); + + for (size_t i = 0; i < src1.size() + src2.size(); ++i) { + dest.push_back(0); + } + + hpx::merge(hpx::execution::par, + src1.begin(), src1.end(), + src2.begin(), src2.end(), + dest.begin()); +} + +inline void hpx_partial_sort(rust::Vec& src, size_t last) { + if (last > src.size()) { + last = src.size(); + } + + hpx::partial_sort(hpx::execution::par, + src.begin(), + src.begin() + last, + src.end()); +} + +inline void hpx_partial_sort_comp(rust::Vec& src, size_t last, + rust::Fn comp) { + if (last > src.size()) { + last = src.size(); + } + + hpx::partial_sort(hpx::execution::par, + src.begin(), + src.begin() + last, + src.end(), + [&](int32_t a, int32_t b) { return comp(a, b); }); +} diff --git a/hpx-sys/src/lib.rs b/hpx-sys/src/lib.rs index 0d2e98a..c6ecc48 100644 --- a/hpx-sys/src/lib.rs +++ b/hpx-sys/src/lib.rs @@ -7,7 +7,6 @@ pub mod ffi { unsafe extern "C++" { include!("hpx-sys/include/wrapper.h"); - //fn start() -> i32; unsafe fn init( func: unsafe fn(i32, *mut *mut c_char) -> i32, argc: i32, @@ -19,29 +18,87 @@ pub mod ffi { fn terminate(); fn disconnect() -> i32; fn disconnect_with_timeout(shutdown_timeout: f64, localwait: f64) -> i32; - //fn stop() -> i32; + fn hpx_copy(src: &[i32], dest: &mut [i32]); + fn hpx_copy_n(src: &[i32], count: usize, dest: &mut [i32]); + fn hpx_copy_if(src: &Vec, dest: &mut Vec, pred: fn(i32) -> bool); + fn hpx_count(src: &Vec, value: i32) -> i64; + fn hpx_count_if(src: &Vec, pred: fn(i32) -> bool) -> i64; + fn hpx_ends_with(src: &[i32], dest: &[i32]) -> bool; + fn hpx_equal(slice1: &[i32], slice2: &[i32]) -> bool; + fn hpx_fill(src: &mut [i32], value: i32); // will only work for linear vectors + fn hpx_find(src: &[i32], value: i32) -> i64; + fn hpx_sort(src: &mut [i32]); + fn hpx_sort_comp(src: &mut Vec, comp: fn(i32, i32) -> bool); + fn hpx_merge(src1: &[i32], src2: &[i32], dest: &mut Vec); + fn hpx_partial_sort(src: &mut Vec, last: usize); + fn hpx_partial_sort_comp(src: &mut Vec, last: usize, comp: fn(i32, i32) -> bool); } } +// ================================================================================================ +// Wrapper for the above Bindings. +// reffer to tests to understand how to use them. [NOTE: Not all bindings have wrapper.] +// ================================================================================================ +use std::ffi::CString; +use std::os::raw::c_char; + +pub fn create_c_args(args: &[&str]) -> (i32, Vec<*mut c_char>) { + let c_args: Vec = args.iter().map(|s| CString::new(*s).unwrap()).collect(); + let ptrs: Vec<*mut c_char> = c_args.iter().map(|s| s.as_ptr() as *mut c_char).collect(); + (ptrs.len() as i32, ptrs) +} + +pub fn copy_vector(src: &[i32]) -> Vec { + let mut dest = vec![0; src.len()]; + ffi::hpx_copy(src, &mut dest); + dest +} + +pub fn copy_n(src: &[i32], count: usize) -> Result, &'static str> { + if count > src.len() { + return Err("count larger than source slice length"); + } + let mut dest = vec![0; count]; + ffi::hpx_copy_n(src, count, &mut dest); + Ok(dest) +} + +pub fn copy_if_divisiblileityby3(src: &Vec) -> Vec { + let mut dest = Vec::new(); + ffi::hpx_copy_if(src, &mut dest, |x| x % 3 == 0); + dest +} + +pub fn count(vec: &Vec, value: i32) -> i64 { + ffi::hpx_count(vec, value) +} + +pub fn find(slice: &[i32], value: i32) -> Option { + match ffi::hpx_find(slice, value) { + -1 => None, + index => Some(index as usize), + } +} + +pub fn merge(src1: &[i32], src2: &[i32]) -> Vec { + let mut dest = Vec::new(); + ffi::hpx_merge(src1, src2, &mut dest); + dest +} + // ================================================================================================ // Tests (to be shifted to systests crate within hpx-rs workspace) // ================================================================================================ #[cfg(test)] mod tests { use super::ffi; + use crate::{copy_if_divisiblileityby3, copy_n, copy_vector, count, create_c_args, find}; use serial_test::serial; use std::ffi::CString; use std::os::raw::c_char; - use std::thread; - use std::time::Duration; - - fn create_c_args(args: &[&str]) -> (i32, Vec<*mut c_char>) { - let c_args: Vec = args.iter().map(|s| CString::new(*s).unwrap()).collect(); - let ptrs: Vec<*mut c_char> = c_args.iter().map(|s| s.as_ptr() as *mut c_char).collect(); - (ptrs.len() as i32, ptrs) - } #[test] + #[serial] fn test_init_finalize() { let (argc, mut argv) = create_c_args(&["testing", "arg1", "arg2"]); @@ -57,4 +114,393 @@ mod tests { assert_eq!(result, 0); } } + + #[test] + #[serial] + fn test_hpx_copy() { + let (argc, mut argv) = create_c_args(&["test_hpx_copy"]); + + let hpx_main = |_argc: i32, _argv: *mut *mut c_char| -> i32 { + let src = vec![1, 2, 3, 4, 5]; + let result = copy_vector(&src); + assert_eq!(src, result); + ffi::finalize() + }; + + unsafe { + let result = ffi::init(hpx_main, argc, argv.as_mut_ptr()); + assert_eq!(result, 0); + } + } + + #[test] + #[serial] + fn test_hpx_copy_range() { + let (argc, mut argv) = create_c_args(&["test_hpx_copy"]); + + let hpx_main = |_argc: i32, _argv: *mut *mut c_char| -> i32 { + let src = vec![1, 2, 3, 4, 5]; + let result = copy_vector(&src[0..3]); + assert_eq!(&src[0..3], &result); + ffi::finalize() + }; + + unsafe { + let result = ffi::init(hpx_main, argc, argv.as_mut_ptr()); + assert_eq!(result, 0); + } + } + + #[test] + #[serial] + fn test_copy_n() { + let (argc, mut argv) = create_c_args(&["test_copy_n"]); + + let test_func = |_argc: i32, _argv: *mut *mut c_char| -> i32 { + let src = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + match copy_n(&src, 5) { + Ok(result) => assert_eq!(result, vec![1, 2, 3, 4, 5]), + Err(e) => panic!("Unexpected error: {}", e), + } + + match copy_n(&src, 15) { + // expecting error + Ok(_) => panic!("Expected error, but got Ok"), + Err(e) => assert_eq!(e, "count larger than source slice length"), + } + ffi::finalize() + }; + + unsafe { + let result = ffi::init(test_func, argc, argv.as_mut_ptr()); + assert_eq!(result, 0); + } + } + + #[test] + #[serial] + fn test_hpx_copy_if() { + let (argc, mut argv) = create_c_args(&["test_hpx_copy_if"]); + + let hpx_main = |_argc: i32, _argv: *mut *mut c_char| -> i32 { + let src = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + let result = copy_if_divisiblileityby3(&src); + assert_eq!(result, vec![0, 3, 6, 9, 12]); + ffi::finalize() + }; + + unsafe { + let result = ffi::init(hpx_main, argc, argv.as_mut_ptr()); + assert_eq!(result, 0); + } + } + + #[test] + #[serial] + fn test_hpx_count() { + let (argc, mut argv) = create_c_args(&["test_hpx_count"]); + + let hpx_main = |_argc: i32, _argv: *mut *mut c_char| -> i32 { + let vec = vec![1, 2, 3, 2, 4, 2, 5, 2]; + let result = count(&vec, 2); + assert_eq!(result, 4); + ffi::finalize() + }; + + unsafe { + let result = ffi::init(hpx_main, argc, argv.as_mut_ptr()); + assert_eq!(result, 0); + } + } + + #[test] + #[serial] + fn test_hpx_count_if() { + let (argc, mut argv) = create_c_args(&["test_hpx_count_if"]); + let hpx_main = |_argc: i32, _argv: *mut *mut c_char| -> i32 { + let vec = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + let result_even = ffi::hpx_count_if(&vec, |x| x % 2 == 0); + assert_eq!(result_even, 5); + let result_greater_than_5 = ffi::hpx_count_if(&vec, |x| x > 5); + assert_eq!(result_greater_than_5, 5); + let is_prime = |n: i32| { + if n <= 1 { + return false; + } + for i in 2..=(n as f64).sqrt() as i32 { + if n % i == 0 { + return false; + } + } + true + }; + let result_prime = ffi::hpx_count_if(&vec, is_prime); + assert_eq!(result_prime, 4); + ffi::finalize() + }; + + unsafe { + let result = ffi::init(hpx_main, argc, argv.as_mut_ptr()); + assert_eq!(result, 0); + } + } + + #[test] + #[serial] + fn test_hpx_ends_with() { + let (argc, mut argv) = create_c_args(&["test_hpx_ends_with"]); + + let hpx_main = |_argc: i32, _argv: *mut *mut c_char| -> i32 { + let v1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + let v2 = vec![8, 9, 10]; + let v3 = vec![7, 8, 9]; + let v4 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + let v5: Vec = vec![]; + + // passing vectors + assert!(ffi::hpx_ends_with(&v1, &v2)); + assert!(!ffi::hpx_ends_with(&v1, &v3)); + assert!(ffi::hpx_ends_with(&v1, &v4)); + assert!(ffi::hpx_ends_with(&v1, &v5)); + assert!(ffi::hpx_ends_with(&v5, &v5)); + + // passing slices + assert!(ffi::hpx_ends_with(&v1[5..], &v2)); + assert!(ffi::hpx_ends_with(&v1[..], &v1[8..])); + assert!(!ffi::hpx_ends_with(&v1[..5], &v2)); + assert!(ffi::hpx_ends_with(&v1[..5], &v1[3..5])); + + assert!(ffi::hpx_ends_with(&v1, &[])); + assert!(ffi::hpx_ends_with(&[], &[])); + + ffi::finalize() + }; + + unsafe { + let result = ffi::init(hpx_main, argc, argv.as_mut_ptr()); + assert_eq!(result, 0); + } + } + + #[test] + #[serial] + fn test_hpx_equal() { + let (argc, mut argv) = create_c_args(&["test_hpx_equal"]); + + let hpx_main = |_argc: i32, _argv: *mut *mut c_char| -> i32 { + let v1 = vec![1, 2, 3, 4, 5]; + let v2 = vec![1, 2, 3, 4, 5]; + let v3 = vec![1, 2, 3, 4, 6]; + let v4 = vec![1, 2, 3, 4]; + let v5 = vec![0, 1, 2, 3, 4, 5, 6]; + + // passing vectors + assert!(ffi::hpx_equal(&v1, &v2)); + assert!(!ffi::hpx_equal(&v1, &v3)); + assert!(!ffi::hpx_equal(&v1, &v4)); + + // passing slices + assert!(ffi::hpx_equal(&v1[..], &v2[..])); + assert!(ffi::hpx_equal(&v1[1..4], &v2[1..4])); + assert!(ffi::hpx_equal(&v1[..3], &v4[..3])); + assert!(ffi::hpx_equal(&v1[..], &v5[1..6])); + + assert!(ffi::hpx_equal(&v1, &v5[1..6])); + assert!(!ffi::hpx_equal(&v1[..4], &v3)); + + ffi::finalize() + }; + + unsafe { + let result = ffi::init(hpx_main, argc, argv.as_mut_ptr()); + assert_eq!(result, 0); + } + } + + #[test] + #[serial] + fn test_hpx_fill() { + let (argc, mut argv) = create_c_args(&["test_hpx_fill"]); + + let hpx_main = |_argc: i32, _argv: *mut *mut c_char| -> i32 { + let mut v = vec![0; 10]; + ffi::hpx_fill(&mut v, 42); + assert!(v.iter().all(|&x| x == 42)); + + let mut v2 = vec![0; 1_000_000]; // testing on a long vector + ffi::hpx_fill(&mut v2, 7); + assert!(v2.iter().all(|&x| x == 7)); + + let mut v3: Vec = Vec::new(); + ffi::hpx_fill(&mut v3, 100); + assert!(v3.is_empty()); + + ffi::finalize() + }; + + unsafe { + let result = ffi::init(hpx_main, argc, argv.as_mut_ptr()); + assert_eq!(result, 0); + } + } + + #[test] + #[serial] + fn test_hpx_find() { + let (argc, mut argv) = create_c_args(&["test_hpx_find"]); + + let hpx_main = |_argc: i32, _argv: *mut *mut c_char| -> i32 { + let v = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + + let result = find(&v, 5); // finding existing value + assert_eq!(result, Some(4)); + + let result = find(&v, 11); // finding non-existing value + assert_eq!(result, None); + + let result = find(&v, 1); + assert_eq!(result, Some(0)); + + let result = find(&v, 10); + assert_eq!(result, Some(9)); + + ffi::finalize() + }; + + unsafe { + let result = ffi::init(hpx_main, argc, argv.as_mut_ptr()); + assert_eq!(result, 0); + } + } + + #[test] + #[serial] + fn test_hpx_sort() { + let (argc, mut argv) = create_c_args(&["test_hpx_sort"]); + + let hpx_main = |_argc: i32, _argv: *mut *mut c_char| -> i32 { + let mut src = vec![5, 2, 8, 1, 9, 3, 7, 6, 4]; + ffi::hpx_sort(&mut src); + assert_eq!(src, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]); + ffi::finalize() + }; + + unsafe { + let result = ffi::init(hpx_main, argc, argv.as_mut_ptr()); + assert_eq!(result, 0); + } + } + + #[test] + #[serial] + fn test_hpx_sort_comp() { + let (argc, mut argv) = create_c_args(&["test_hpx_sort_comp"]); + + let hpx_main = |_argc: i32, _argv: *mut *mut c_char| -> i32 { + let mut v = vec![5, 2, 8, 1, 9, 3, 7, 6, 4]; + ffi::hpx_sort_comp(&mut v, |a, b| a > b); // sorting in descending order + assert_eq!(v, vec![9, 8, 7, 6, 5, 4, 3, 2, 1]); + + // sorting even numbers before odd numbers + let mut v2 = vec![5, 2, 8, 1, 9, 3, 7, 6, 4]; + ffi::hpx_sort_comp( + &mut v2, + |a, b| { + if a % 2 == b % 2 { + a < b + } else { + a % 2 == 0 + } + }, + ); + assert_eq!(v2, vec![2, 4, 6, 8, 1, 3, 5, 7, 9]); + + ffi::finalize() + }; + + unsafe { + let result = ffi::init(hpx_main, argc, argv.as_mut_ptr()); + assert_eq!(result, 0); + } + } + + #[test] + #[serial] + fn test_hpx_merge() { + let (argc, mut argv) = create_c_args(&["test_hpx_merge"]); + + let hpx_main = |_argc: i32, _argv: *mut *mut c_char| -> i32 { + let v1 = vec![1, 3, 5, 7, 9, 20, 100]; + let v2 = vec![2, 4, 6, 8, 10, 97]; + let mut dest = Vec::new(); + ffi::hpx_merge(&v1, &v2, &mut dest); + assert_eq!(dest, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 97, 100]); + ffi::finalize() + }; + + unsafe { + let result = ffi::init(hpx_main, argc, argv.as_mut_ptr()); + assert_eq!(result, 0); + } + } + + #[test] + #[serial] + fn test_hpx_partial_sort() { + let (argc, mut argv) = create_c_args(&["test_hpx_partial_sort"]); + + let hpx_main = |_argc: i32, _argv: *mut *mut c_char| -> i32 { + let mut vec = vec![5, 2, 8, 1, 9, 3, 7, 6, 4]; + let last = 4; + println!("Before partial sort: {:?}", vec); + + ffi::hpx_partial_sort(&mut vec, last); + println!("After partial sort: {:?}", vec); + + // If first -> last elements are sorted + assert!(vec[..last].windows(2).all(|w| w[0] <= w[1])); + + // If ele of sorted part <= ele of unsorted part + assert!(vec[..last] + .iter() + .all(|&x| vec[last..].iter().all(|&y| x <= y))); + + ffi::finalize() + }; + + unsafe { + let result = ffi::init(hpx_main, argc, argv.as_mut_ptr()); + assert_eq!(result, 0); + } + } + + #[test] + #[serial] + fn test_hpx_partial_sort_comp() { + let (argc, mut argv) = create_c_args(&["test_hpx_partial_sort_comp"]); + + let hpx_main = |_argc: i32, _argv: *mut *mut c_char| -> i32 { + let mut vec = vec![5, 2, 8, 1, 9, 3, 7, 6, 4]; + let last = 4; + println!("Before partial sort: {:?}", vec); + + ffi::hpx_partial_sort_comp(&mut vec, last, |a, b| b < a); + println!("After partial sort: {:?}", vec); + + // If first -> last elements are sorted dec + assert!(vec[..last].windows(2).all(|w| w[0] >= w[1])); + + // If ele of sorted part >= ele of unsorted part + assert!(vec[..last] + .iter() + .all(|&x| vec[last..].iter().all(|&y| x >= y))); + + ffi::finalize() + }; + + unsafe { + let result = ffi::init(hpx_main, argc, argv.as_mut_ptr()); + assert_eq!(result, 0); + } + } }