diff --git a/common/cNet.cpp b/common/cNet.cpp index 9c60770..88718a6 100644 --- a/common/cNet.cpp +++ b/common/cNet.cpp @@ -596,7 +596,7 @@ int cNet::GetOutputMemorySize(const int crop_w, const int crop_h, const int oute } // ネットワークを使って画像を再構築する -Waifu2x::eWaifu2xError cNet::ReconstructImage(const bool UseTTA, const int crop_w, const int crop_h, const int outer_padding, const int batch_size, float *inputBlockBuf, float *outputBlockBuf, const cv::Mat &inMat, cv::Mat &outMat) +Waifu2x::eWaifu2xError cNet::ReconstructImage(const bool UseTTA, const int crop_w, const int crop_h, const int outer_padding, const int batch_size, float *outputBlockBuf, const cv::Mat &inMat, cv::Mat &outMat) { const auto InputHeight = inMat.size().height; const auto InputWidth = inMat.size().width; @@ -672,7 +672,7 @@ Waifu2x::eWaifu2xError cNet::ReconstructImage(const bool UseTTA, const int crop_ // 画像を直列に変換 { - float *fptr = inputBlockBuf + (input_block_plane_size * n); + float *fptr = input_blob->mutable_cpu_data() + (input_block_plane_size * n); const float *uptr = (const float *)someimg.data; const auto Line = someimg.step1(); @@ -712,9 +712,6 @@ Waifu2x::eWaifu2xError cNet::ReconstructImage(const bool UseTTA, const int crop_ assert(input_blob->count() == input_block_plane_size * processNum); - // ネットワークに画像を入力 - input_blob->set_cpu_data(inputBlockBuf); - // 計算 auto out = mNet->Forward(); diff --git a/common/cNet.h b/common/cNet.h index 1d6522f..cbd8ff1 100644 --- a/common/cNet.h +++ b/common/cNet.h @@ -57,7 +57,7 @@ public: int GetInputMemorySize(const int crop_w, const int crop_h, const int outer_padding, const int batch_size) const; int GetOutputMemorySize(const int crop_w, const int crop_h, const int outer_padding, const int batch_size) const; - Waifu2x::eWaifu2xError ReconstructImage(const bool UseTTA, const int crop_w, const int crop_h, const int outer_padding, const int batch_size, float *inputBlockBuf, float *outputBlockBuf, const cv::Mat &inMat, cv::Mat &outMat); + Waifu2x::eWaifu2xError ReconstructImage(const bool UseTTA, const int crop_w, const int crop_h, const int outer_padding, const int batch_size, float *outputBlockBuf, const cv::Mat &inMat, cv::Mat &outMat); static std::string GetModelName(const boost::filesystem::path &info_path); }; diff --git a/common/waifu2x.cpp b/common/waifu2x.cpp index eaaab21..7742ab6 100644 --- a/common/waifu2x.cpp +++ b/common/waifu2x.cpp @@ -249,7 +249,7 @@ void Waifu2x::quit_liblary() {} -Waifu2x::Waifu2x() : mIsInited(false), mNoiseLevel(0), mIsCuda(false), mInputBlock(nullptr), mInputBlockSize(0), mOutputBlock(nullptr), mOutputBlockSize(0) +Waifu2x::Waifu2x() : mIsInited(false), mNoiseLevel(0), mIsCuda(false), mOutputBlock(nullptr), mOutputBlockSize(0) {} Waifu2x::~Waifu2x() @@ -669,19 +669,6 @@ Waifu2x::eWaifu2xError Waifu2x::ProcessNet(std::shared_ptr net, const int { Waifu2x::eWaifu2xError ret; - const auto InputMemorySize = net->GetInputMemorySize(crop_w, crop_h, OuterPadding, batch_size); - if (InputMemorySize > mInputBlockSize) - { - if (mIsCuda) - CUDA_HOST_SAFE_FREE(mInputBlock); - else - SAFE_DELETE_WAIFU2X(mInputBlock); - - CUDA_CHECK_WAIFU2X(cudaHostAlloc(&mInputBlock, InputMemorySize, cudaHostAllocWriteCombined)); - - mInputBlockSize = InputMemorySize; - } - const auto OutputMemorySize = net->GetOutputMemorySize(crop_w, crop_h, OuterPadding, batch_size); if (OutputMemorySize > mOutputBlockSize) { @@ -692,10 +679,10 @@ Waifu2x::eWaifu2xError Waifu2x::ProcessNet(std::shared_ptr net, const int CUDA_CHECK_WAIFU2X(cudaHostAlloc(&mOutputBlock, OutputMemorySize, cudaHostAllocDefault)); - mInputBlockSize = OutputMemorySize; + mOutputBlockSize = OutputMemorySize; } - ret = net->ReconstructImage(use_tta, crop_w, crop_h, OuterPadding, batch_size, mInputBlock, mOutputBlock, im, im); + ret = net->ReconstructImage(use_tta, crop_w, crop_h, OuterPadding, batch_size, mOutputBlock, im, im); if (ret != Waifu2x::eWaifu2xError_OK) return ret; @@ -720,12 +707,10 @@ void Waifu2x::Destroy() if (mIsCuda) { - CUDA_HOST_SAFE_FREE(mInputBlock); CUDA_HOST_SAFE_FREE(mOutputBlock); } else { - SAFE_DELETE_WAIFU2X(mInputBlock); SAFE_DELETE_WAIFU2X(mOutputBlock); } diff --git a/common/waifu2x.h b/common/waifu2x.h index 6f868e2..181a177 100644 --- a/common/waifu2x.h +++ b/common/waifu2x.h @@ -89,9 +89,6 @@ private: int mMaxNetOffset; // ネットに入力するとどれくらい削れるか bool mHasNoiseScale; - float *mInputBlock; - size_t mInputBlockSize; - float *mOutputBlock; size_t mOutputBlockSize;