Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,11 @@ endif()
if(TRX_LIBZIP_TARGET AND TARGET ${TRX_LIBZIP_TARGET})
message(STATUS "trx-cpp: using provided libzip target: ${TRX_LIBZIP_TARGET}")
else()
find_package(libzip QUIET)
if(NOT libzip_FOUND)
message(STATUS "libzip not found; fetching v1.11.4")
# CONFIG-only: only accept a system libzip that provides a proper cmake
# config (creating libzip::zip). If not found, always fetch and build.
find_package(libzip CONFIG QUIET)
if(NOT TARGET libzip::zip)
message(STATUS "libzip not found via config-mode find_package; fetching v1.11.4")
set(LIBZIP_DO_INSTALL OFF)
set(BUILD_TOOLS OFF)
set(BUILD_REGRESS OFF)
Expand Down Expand Up @@ -103,7 +105,7 @@ else()
set_target_properties(zip PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
endif()
# Detect the target name: find_package creates libzip::zip;
# Detect the target name: CONFIG find_package creates libzip::zip;
# FetchContent creates the bare 'zip' target.
if(TARGET libzip::zip)
set(TRX_LIBZIP_TARGET libzip::zip)
Expand Down
49 changes: 48 additions & 1 deletion include/trx/trx.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <json11.hpp>
#include <limits>
#include <memory>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string_view>
Expand Down Expand Up @@ -46,6 +47,8 @@ using json = json11::Json;
namespace trx {
enum class TrxSaveMode { Auto, Archive, Directory };

enum class TrxScalarType { Float16, Float32, Float64 };

enum class ConnectivityMeasure {
StreamlineCount,
DpsSum
Expand All @@ -62,6 +65,9 @@ struct TrxSaveOptions {
TrxSaveMode mode = TrxSaveMode::Auto;
size_t memory_limit_bytes = 0; // Reserved for future save-path tuning.
bool overwrite_existing = true;
/// When set, the positions array is converted to this dtype on output.
/// All other data (dps, dpv, groups, offsets) is preserved unchanged.
std::optional<TrxScalarType> positions_dtype;
};

inline json::object _json_object(const json &value) {
Expand Down Expand Up @@ -553,6 +559,32 @@ template <typename DT> class TrxFile {

namespace detail {
int _sizeof_dtype(const std::string &dtype);

/// RAII guard that deletes a temporary file on scope exit.
struct TempFileGuard {
std::string path;
TempFileGuard() = default;
TempFileGuard(const TempFileGuard &) = delete;
TempFileGuard &operator=(const TempFileGuard &) = delete;
~TempFileGuard() {
if (!path.empty()) {
std::error_code ec;
trx::fs::remove(path, ec);
}
}
};

/// Return a unique path in the OS temp directory suitable for a scratch file.
inline std::string make_unique_temp_path(const std::string &prefix) {
std::error_code ec;
trx::fs::path tmp = trx::fs::temp_directory_path(ec);
if (ec)
tmp = trx::fs::path(".");
thread_local std::mt19937_64 rng(std::random_device{}());
std::uniform_int_distribution<uint64_t> dist;
return (tmp / (prefix + "_" + std::to_string(dist(rng)) + ".bin")).string();
}

} // namespace detail

struct TypedArray {
Expand Down Expand Up @@ -959,7 +991,22 @@ template <typename DT> std::unique_ptr<TrxFile<DT>> load_from_directory(const st
*/
std::string detect_positions_dtype(const std::string &path);

enum class TrxScalarType { Float16, Float32, Float64 };
/**
* @brief Convert and write positions from @p source to @p out_path in @p target_dtype.
*
* Reads positions chunk-by-chunk (each chunk at most @p chunk_bytes bytes of source data)
* and writes converted raw bytes to @p out_path. Peak transient memory is bounded to
* roughly one source chunk plus one destination chunk — independent of file size.
*
* @param source TRX file whose positions will be converted.
* @param target_dtype Desired output scalar type.
* @param out_path Path to write the raw converted positions binary.
* @param chunk_bytes Bytes of source data per conversion pass (default 64 MiB).
*/
void write_positions_as_dtype(const AnyTrxFile &source,
TrxScalarType target_dtype,
const std::string &out_path,
size_t chunk_bytes = 64 * 1024 * 1024);

/**
* @brief Return the canonical string name for a TrxScalarType.
Expand Down
49 changes: 48 additions & 1 deletion include/trx/trx.tpp
Original file line number Diff line number Diff line change
Expand Up @@ -1285,8 +1285,38 @@ template <typename DT> void TrxFile<DT>::save(const std::string &filename, const
if (!zf) {
throw TrxIOError("Could not open archive " + filename + ": " + strerror(errorp));
}
zip_from_folder(zf.get(), tmp_dir_name, tmp_dir_name, options.compression_standard, nullptr);

std::unordered_set<std::string> skip;
detail::TempFileGuard tmp_pos_guard;
if (options.positions_dtype.has_value() && save_trx->streamlines) {
const TrxScalarType target = *options.positions_dtype;
const std::string cur_dtype = detect_positions_dtype(tmp_dir_name);
const std::string new_dtype_str = scalar_type_name(target);
if (!cur_dtype.empty() && cur_dtype != new_dtype_str) {
skip.insert("positions.3." + cur_dtype);
tmp_pos_guard.path = detail::make_unique_temp_path("trx_pos_convert");
{
auto src_any = load_any(tmp_dir_name);
write_positions_as_dtype(src_any, target, tmp_pos_guard.path);
}
const std::string new_pos_name = "positions.3." + new_dtype_str;
zip_source_t *pos_src = zip_source_file(zf.get(), tmp_pos_guard.path.c_str(), 0, -1);
if (!pos_src)
throw TrxIOError("Failed to create zip source for converted positions");
const zip_int64_t pos_idx = zip_file_add(
zf.get(), new_pos_name.c_str(), pos_src, ZIP_FL_ENC_UTF_8 | ZIP_FL_OVERWRITE);
if (pos_idx < 0)
throw TrxIOError("Failed to add converted positions to archive");
if (zip_set_file_compression(
zf.get(), pos_idx,
static_cast<zip_int32_t>(options.compression_standard), 0) < 0)
throw TrxIOError("Failed to set compression for converted positions");
}
}
zip_from_folder(zf.get(), tmp_dir_name, tmp_dir_name, options.compression_standard,
skip.empty() ? nullptr : &skip);
zf.commit(filename);
// tmp_pos_guard destructor removes the temp file after commit.
} else {
std::error_code ec;
if (!trx::fs::exists(tmp_dir_name, ec) || !trx::fs::is_directory(tmp_dir_name, ec)) {
Expand All @@ -1309,6 +1339,23 @@ template <typename DT> void TrxFile<DT>::save(const std::string &filename, const
if (!trx::fs::exists(filename, ec) || !trx::fs::is_directory(filename, ec)) {
throw TrxIOError("Failed to create output directory: " + filename);
}

if (options.positions_dtype.has_value() && save_trx->streamlines) {
const TrxScalarType target = *options.positions_dtype;
const std::string cur_dtype = detect_positions_dtype(filename);
const std::string new_dtype_str = scalar_type_name(target);
if (!cur_dtype.empty() && cur_dtype != new_dtype_str) {
const std::string old_pos = filename + SEPARATOR + "positions.3." + cur_dtype;
const std::string new_pos = filename + SEPARATOR + "positions.3." + new_dtype_str;
{
auto src_any = load_any(filename);
write_positions_as_dtype(src_any, target, new_pos);
}
std::error_code rm_ec;
trx::fs::remove(old_pos, rm_ec);
}
}

const trx::fs::path header_path = dest_path / "header.json";
if (!trx::fs::exists(header_path)) {
throw TrxFormatError("Missing header.json in output directory: " + header_path.string());
Expand Down
100 changes: 99 additions & 1 deletion src/trx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -517,12 +517,75 @@ AnyTrxFile::_create_from_pointer(json header,
return trx;
}

void write_positions_as_dtype(const AnyTrxFile &source,
TrxScalarType target_dtype,
const std::string &out_path,
size_t chunk_bytes) {
std::ofstream out(out_path, std::ios::binary | std::ios::trunc);
if (!out)
throw TrxIOError("Failed to create positions output: " + out_path);

source.for_each_positions_chunk(
chunk_bytes,
[&](TrxScalarType src_dtype, const void *data, size_t /*point_offset*/, size_t point_count) {
const size_t n = point_count * 3;

// Inner lambda: read from typed source pointer, cast to DstT, write to stream.
auto write_as = [&](auto typed_src) {
switch (target_dtype) {
case TrxScalarType::Float16: {
std::vector<Eigen::half> buf(n);
for (size_t i = 0; i < n; ++i)
buf[i] = static_cast<Eigen::half>(static_cast<float>(typed_src[i]));
out.write(reinterpret_cast<const char *>(buf.data()),
static_cast<std::streamsize>(n * sizeof(Eigen::half)));
break;
}
case TrxScalarType::Float64: {
std::vector<double> buf(n);
for (size_t i = 0; i < n; ++i)
buf[i] = static_cast<double>(typed_src[i]);
out.write(reinterpret_cast<const char *>(buf.data()),
static_cast<std::streamsize>(n * sizeof(double)));
break;
}
default: {
std::vector<float> buf(n);
for (size_t i = 0; i < n; ++i)
buf[i] = static_cast<float>(typed_src[i]);
out.write(reinterpret_cast<const char *>(buf.data()),
static_cast<std::streamsize>(n * sizeof(float)));
break;
}
}
};

switch (src_dtype) {
case TrxScalarType::Float16:
write_as(reinterpret_cast<const Eigen::half *>(data));
break;
case TrxScalarType::Float64:
write_as(reinterpret_cast<const double *>(data));
break;
default:
write_as(reinterpret_cast<const float *>(data));
break;
}
});

if (out.bad())
throw TrxIOError("I/O error writing converted positions to: " + out_path);
}

void AnyTrxFile::save(const std::string &filename, zip_uint32_t compression_standard) {
TrxSaveOptions options;
options.compression_standard = compression_standard;
save(filename, options);
}

using trx::detail::TempFileGuard;
using trx::detail::make_unique_temp_path;

void AnyTrxFile::save(const std::string &filename, const TrxSaveOptions &options) {
const std::string ext = get_ext(filename);
const TrxSaveMode save_mode = resolve_save_mode(filename, options.mode);
Expand Down Expand Up @@ -592,7 +655,30 @@ void AnyTrxFile::save(const std::string &filename, const TrxSaveOptions &options
std::string(zip_strerror(zf.get())));
}

const std::unordered_set<std::string> skip = {"header.json"};
std::unordered_set<std::string> skip = {"header.json"};
// Guard deletes the temp positions file after commit (or on exception).
TempFileGuard tmp_pos_guard;
if (options.positions_dtype.has_value() && !positions.empty()) {
const TrxScalarType target = *options.positions_dtype;
const std::string new_dtype_str = scalar_type_name(target);
if (new_dtype_str != positions.dtype) {
skip.insert("positions.3." + positions.dtype);
tmp_pos_guard.path = make_unique_temp_path("trx_pos_convert");
write_positions_as_dtype(*this, target, tmp_pos_guard.path);
const std::string new_pos_name = "positions.3." + new_dtype_str;
zip_source_t *pos_src =
zip_source_file(zf.get(), tmp_pos_guard.path.c_str(), 0, -1);
if (!pos_src)
throw TrxIOError("Failed to create zip source for converted positions");
const zip_int64_t pos_idx =
zip_file_add(zf.get(), new_pos_name.c_str(), pos_src, ZIP_FL_ENC_UTF_8 | ZIP_FL_OVERWRITE);
if (pos_idx < 0)
throw TrxIOError("Failed to add converted positions to archive: " +
std::string(zip_strerror(zf.get())));
if (zip_set_file_compression(zf.get(), pos_idx, compression, 0) < 0)
throw TrxIOError("Failed to set compression for converted positions");
}
}
zip_from_folder(zf.get(), source_dir, source_dir, options.compression_standard, &skip);
zf.commit(filename);
} else {
Expand All @@ -619,6 +705,18 @@ void AnyTrxFile::save(const std::string &filename, const TrxSaveOptions &options
copy_dir(source_dir, filename);
}

if (options.positions_dtype.has_value() && !positions.empty()) {
const TrxScalarType target = *options.positions_dtype;
const std::string new_dtype_str = scalar_type_name(target);
if (new_dtype_str != positions.dtype) {
const std::string old_pos = filename + SEPARATOR + "positions.3." + positions.dtype;
const std::string new_pos = filename + SEPARATOR + "positions.3." + new_dtype_str;
write_positions_as_dtype(*this, target, new_pos);
std::error_code rm_ec;
trx::fs::remove(old_pos, rm_ec);
}
}

const trx::fs::path final_header_path = dest_path / "header.json";
std::ofstream out_json(final_header_path, std::ios::out | std::ios::trunc);
if (!out_json.is_open()) {
Expand Down
Loading
Loading