Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/file_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,37 @@ void FileManerger::setFileName(const std::string& value) {
std::unique_lock<std::shared_mutex> lock(mutex_);
file_name_ = value;
}

void FileManerger::captureStdout(std::function<void()> func) {
std::unique_lock<std::shared_mutex> lock(mutex_);

if (!file_stream_.is_open()) {
throw std::runtime_error(
"File stream is not open. Call createFile() first.");
}

// 保存原来的 cout buffer
std::streambuf* old_cout_buf = std::cout.rdbuf();

// 创建一个 stringstream 来捕获输出
std::stringstream captured_output;

// 重定向 cout 到 stringstream
std::cout.rdbuf(captured_output.rdbuf());

try {
// 执行函数
func();

// 恢复 cout
std::cout.rdbuf(old_cout_buf);

// 将捕获的输出写入文件
file_stream_ << captured_output.str();
} catch (...) {
// 确保恢复 cout
std::cout.rdbuf(old_cout_buf);
throw;
}
}
} // namespace paddle_api_test
4 changes: 4 additions & 0 deletions src/file_manager.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include <fstream>
#include <functional>
#include <mutex>
#include <shared_mutex>
#include <string>
Expand All @@ -16,6 +17,9 @@ class FileManerger {
FileManerger& operator<<(const std::string& str);
void saveFile();

// 捕获标准输出到文件
void captureStdout(std::function<void()> func);

private:
mutable std::shared_mutex mutex_;
std::string basic_path_ = "/tmp/paddle_cpp_api_test/";
Expand Down
132 changes: 132 additions & 0 deletions test/TensorUtilTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/ops/ones.h>
#include <gtest/gtest.h>
#include <torch/all.h>

#include <string>
#include <vector>

#include "../src/file_manager.h"

extern paddle_api_test::ThreadSafeParam g_custom_param;

namespace at {
namespace test {

using paddle_api_test::FileManerger;
using paddle_api_test::ThreadSafeParam;

class TensorUtilTest : public ::testing::Test {
protected:
void SetUp() override {
std::vector<int64_t> shape = {2, 3, 4};
tensor = at::ones(shape, at::kFloat);
}

at::Tensor tensor;
};

// 测试 toString
TEST_F(TensorUtilTest, ToString) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();
std::string tensor_str = tensor.toString();
file << tensor_str << " ";
file.saveFile();
}

// 测试 is_contiguous_or_false
TEST_F(TensorUtilTest, IsContiguousOrFalse) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();
file << std::to_string(tensor.is_contiguous_or_false()) << " ";

// 测试非连续的tensor
at::Tensor transposed = tensor.transpose(0, 2);
file << std::to_string(transposed.is_contiguous_or_false()) << " ";
file.saveFile();
}

// 测试 is_same
TEST_F(TensorUtilTest, IsSame) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();

// Test that tensor is same as itself
file << std::to_string(tensor.is_same(tensor)) << " ";

// Test that two different tensors are not the same
at::Tensor other_tensor = at::ones({2, 3, 4}, at::kFloat);
file << std::to_string(tensor.is_same(other_tensor)) << " ";

// Test that a shallow copy points to the same tensor
at::Tensor shallow_copy = tensor;
file << std::to_string(tensor.is_same(shallow_copy)) << " ";

// Test that a view of the tensor
at::Tensor view = tensor.view({24});
file << std::to_string(tensor.is_same(view)) << " ";
file.saveFile();
}

// 测试 use_count
TEST_F(TensorUtilTest, UseCount) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();

// Get initial use count
size_t initial_count = tensor.use_count();
file << std::to_string(initial_count) << " ";

// Create a copy, should increase use count
{
at::Tensor copy = tensor;
size_t new_count = tensor.use_count();
file << std::to_string(new_count) << " ";
file << std::to_string(new_count - initial_count) << " "; // 差值
}

// After copy goes out of scope, use count should decrease
size_t final_count = tensor.use_count();
file << std::to_string(final_count) << " ";
file.saveFile();
}

// 测试 weak_use_count
TEST_F(TensorUtilTest, WeakUseCount) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();

// Get initial weak use count
size_t initial_weak_count = tensor.weak_use_count();
file << std::to_string(initial_weak_count) << " ";
file.saveFile();
}

// 测试 print
TEST_F(TensorUtilTest, Print) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();

// 创建一个小的tensor用于print测试
at::Tensor small_tensor = at::ones({2, 2}, at::kFloat);

// 使用 captureStdout 捕获 print() 的输出
file.captureStdout([&]() {
tensor.print();
small_tensor.print();
});

file << std::to_string(1) << " "; // 如果执行到这里说明print()没有崩溃
file.saveFile();
}

} // namespace test
} // namespace at