crop_sizeが128より小さい時に強制終了するバグを修正

This commit is contained in:
lltcggie 2016-07-03 20:56:00 +09:00
parent 1e834f2c7b
commit 04a9fe915d
4 changed files with 6 additions and 27 deletions

View File

@ -596,7 +596,7 @@ int cNet::GetOutputMemorySize(const int crop_w, const int crop_h, const int oute
} }
// ネットワークを使って画像を再構築する // ネットワークを使って画像を再構築する
Waifu2x::eWaifu2xError cNet::ReconstructImage(const bool UseTTA, const int crop_w, const int crop_h, const int outer_padding, const int batch_size, float *inputBlockBuf, float *outputBlockBuf, const cv::Mat &inMat, cv::Mat &outMat) Waifu2x::eWaifu2xError cNet::ReconstructImage(const bool UseTTA, const int crop_w, const int crop_h, const int outer_padding, const int batch_size, float *outputBlockBuf, const cv::Mat &inMat, cv::Mat &outMat)
{ {
const auto InputHeight = inMat.size().height; const auto InputHeight = inMat.size().height;
const auto InputWidth = inMat.size().width; const auto InputWidth = inMat.size().width;
@ -672,7 +672,7 @@ Waifu2x::eWaifu2xError cNet::ReconstructImage(const bool UseTTA, const int crop_
// 画像を直列に変換 // 画像を直列に変換
{ {
float *fptr = inputBlockBuf + (input_block_plane_size * n); float *fptr = input_blob->mutable_cpu_data() + (input_block_plane_size * n);
const float *uptr = (const float *)someimg.data; const float *uptr = (const float *)someimg.data;
const auto Line = someimg.step1(); const auto Line = someimg.step1();
@ -712,9 +712,6 @@ Waifu2x::eWaifu2xError cNet::ReconstructImage(const bool UseTTA, const int crop_
assert(input_blob->count() == input_block_plane_size * processNum); assert(input_blob->count() == input_block_plane_size * processNum);
// ネットワークに画像を入力
input_blob->set_cpu_data(inputBlockBuf);
// 計算 // 計算
auto out = mNet->Forward(); auto out = mNet->Forward();

View File

@ -57,7 +57,7 @@ public:
int GetInputMemorySize(const int crop_w, const int crop_h, const int outer_padding, const int batch_size) const; int GetInputMemorySize(const int crop_w, const int crop_h, const int outer_padding, const int batch_size) const;
int GetOutputMemorySize(const int crop_w, const int crop_h, const int outer_padding, const int batch_size) const; int GetOutputMemorySize(const int crop_w, const int crop_h, const int outer_padding, const int batch_size) const;
Waifu2x::eWaifu2xError ReconstructImage(const bool UseTTA, const int crop_w, const int crop_h, const int outer_padding, const int batch_size, float *inputBlockBuf, float *outputBlockBuf, const cv::Mat &inMat, cv::Mat &outMat); Waifu2x::eWaifu2xError ReconstructImage(const bool UseTTA, const int crop_w, const int crop_h, const int outer_padding, const int batch_size, float *outputBlockBuf, const cv::Mat &inMat, cv::Mat &outMat);
static std::string GetModelName(const boost::filesystem::path &info_path); static std::string GetModelName(const boost::filesystem::path &info_path);
}; };

View File

@ -249,7 +249,7 @@ void Waifu2x::quit_liblary()
{} {}
Waifu2x::Waifu2x() : mIsInited(false), mNoiseLevel(0), mIsCuda(false), mInputBlock(nullptr), mInputBlockSize(0), mOutputBlock(nullptr), mOutputBlockSize(0) Waifu2x::Waifu2x() : mIsInited(false), mNoiseLevel(0), mIsCuda(false), mOutputBlock(nullptr), mOutputBlockSize(0)
{} {}
Waifu2x::~Waifu2x() Waifu2x::~Waifu2x()
@ -669,19 +669,6 @@ Waifu2x::eWaifu2xError Waifu2x::ProcessNet(std::shared_ptr<cNet> net, const int
{ {
Waifu2x::eWaifu2xError ret; Waifu2x::eWaifu2xError ret;
const auto InputMemorySize = net->GetInputMemorySize(crop_w, crop_h, OuterPadding, batch_size);
if (InputMemorySize > mInputBlockSize)
{
if (mIsCuda)
CUDA_HOST_SAFE_FREE(mInputBlock);
else
SAFE_DELETE_WAIFU2X(mInputBlock);
CUDA_CHECK_WAIFU2X(cudaHostAlloc(&mInputBlock, InputMemorySize, cudaHostAllocWriteCombined));
mInputBlockSize = InputMemorySize;
}
const auto OutputMemorySize = net->GetOutputMemorySize(crop_w, crop_h, OuterPadding, batch_size); const auto OutputMemorySize = net->GetOutputMemorySize(crop_w, crop_h, OuterPadding, batch_size);
if (OutputMemorySize > mOutputBlockSize) if (OutputMemorySize > mOutputBlockSize)
{ {
@ -692,10 +679,10 @@ Waifu2x::eWaifu2xError Waifu2x::ProcessNet(std::shared_ptr<cNet> net, const int
CUDA_CHECK_WAIFU2X(cudaHostAlloc(&mOutputBlock, OutputMemorySize, cudaHostAllocDefault)); CUDA_CHECK_WAIFU2X(cudaHostAlloc(&mOutputBlock, OutputMemorySize, cudaHostAllocDefault));
mInputBlockSize = OutputMemorySize; mOutputBlockSize = OutputMemorySize;
} }
ret = net->ReconstructImage(use_tta, crop_w, crop_h, OuterPadding, batch_size, mInputBlock, mOutputBlock, im, im); ret = net->ReconstructImage(use_tta, crop_w, crop_h, OuterPadding, batch_size, mOutputBlock, im, im);
if (ret != Waifu2x::eWaifu2xError_OK) if (ret != Waifu2x::eWaifu2xError_OK)
return ret; return ret;
@ -720,12 +707,10 @@ void Waifu2x::Destroy()
if (mIsCuda) if (mIsCuda)
{ {
CUDA_HOST_SAFE_FREE(mInputBlock);
CUDA_HOST_SAFE_FREE(mOutputBlock); CUDA_HOST_SAFE_FREE(mOutputBlock);
} }
else else
{ {
SAFE_DELETE_WAIFU2X(mInputBlock);
SAFE_DELETE_WAIFU2X(mOutputBlock); SAFE_DELETE_WAIFU2X(mOutputBlock);
} }

View File

@ -89,9 +89,6 @@ private:
int mMaxNetOffset; // ネットに入力するとどれくらい削れるか int mMaxNetOffset; // ネットに入力するとどれくらい削れるか
bool mHasNoiseScale; bool mHasNoiseScale;
float *mInputBlock;
size_t mInputBlockSize;
float *mOutputBlock; float *mOutputBlock;
size_t mOutputBlockSize; size_t mOutputBlockSize;