mirror of
https://github.com/lltcggie/waifu2x-caffe.git
synced 2025-06-28 14:42:48 +00:00
Test-Time Augmentation Mode実装
This commit is contained in:
parent
9d32a969bf
commit
fb6f80970a
@ -815,7 +815,7 @@ Waifu2x::eWaifu2xError Waifu2x::ReconstructImage(boost::shared_ptr<caffe::Net<fl
|
|||||||
}
|
}
|
||||||
|
|
||||||
Waifu2x::eWaifu2xError Waifu2x::init(int argc, char** argv, const std::string &Mode, const int NoiseLevel, const double ScaleRatio, const std::string &ModelDir, const std::string &Process,
|
Waifu2x::eWaifu2xError Waifu2x::init(int argc, char** argv, const std::string &Mode, const int NoiseLevel, const double ScaleRatio, const std::string &ModelDir, const std::string &Process,
|
||||||
const int CropSize, const int BatchSize)
|
const bool UseTTA, const int CropSize, const int BatchSize)
|
||||||
{
|
{
|
||||||
Waifu2x::eWaifu2xError ret;
|
Waifu2x::eWaifu2xError ret;
|
||||||
|
|
||||||
@ -832,6 +832,7 @@ Waifu2x::eWaifu2xError Waifu2x::init(int argc, char** argv, const std::string &M
|
|||||||
scale_ratio = ScaleRatio;
|
scale_ratio = ScaleRatio;
|
||||||
model_dir = ModelDir;
|
model_dir = ModelDir;
|
||||||
process = Process;
|
process = Process;
|
||||||
|
use_tta = UseTTA;
|
||||||
|
|
||||||
crop_size = CropSize;
|
crop_size = CropSize;
|
||||||
batch_size = BatchSize;
|
batch_size = BatchSize;
|
||||||
@ -1028,29 +1029,19 @@ Waifu2x::eWaifu2xError Waifu2x::WriteMat(const cv::Mat &im, const std::string &o
|
|||||||
return eWaifu2xError_FailedOpenOutputFile;
|
return eWaifu2xError_FailedOpenOutputFile;
|
||||||
}
|
}
|
||||||
|
|
||||||
Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std::string &output_file,
|
Waifu2x::eWaifu2xError Waifu2x::BeforeReconstructFloatMatProcess(const cv::Mat &in, cv::Mat &out)
|
||||||
const waifu2xCancelFunc cancel_func)
|
|
||||||
{
|
{
|
||||||
Waifu2x::eWaifu2xError ret;
|
Waifu2x::eWaifu2xError ret;
|
||||||
|
|
||||||
if (!is_inited)
|
|
||||||
return eWaifu2xError_NotInitialized;
|
|
||||||
|
|
||||||
cv::Mat float_image;
|
|
||||||
ret = LoadMat(float_image, input_file);
|
|
||||||
if (ret != eWaifu2xError_OK)
|
|
||||||
return ret;
|
|
||||||
|
|
||||||
cv::Mat im;
|
cv::Mat im;
|
||||||
if (input_plane == 1)
|
if (input_plane == 1)
|
||||||
CreateBrightnessImage(float_image, im);
|
CreateBrightnessImage(in, im);
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
|
||||||
std::vector<cv::Mat> planes;
|
std::vector<cv::Mat> planes;
|
||||||
cv::split(float_image, planes);
|
cv::split(in, planes);
|
||||||
|
|
||||||
if (float_image.channels() == 4)
|
if (in.channels() == 4)
|
||||||
planes.resize(3);
|
planes.resize(3);
|
||||||
|
|
||||||
// BGRからRGBにする
|
// BGRからRGBにする
|
||||||
@ -1058,13 +1049,19 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std
|
|||||||
|
|
||||||
cv::merge(planes, im);
|
cv::merge(planes, im);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
out = im;
|
||||||
|
|
||||||
|
return eWaifu2xError_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
Waifu2x::eWaifu2xError Waifu2x::ReconstructFloatMat(const bool isJpeg, const waifu2xCancelFunc cancel_func, const cv::Mat &in, cv::Mat &out)
|
||||||
|
{
|
||||||
|
Waifu2x::eWaifu2xError ret;
|
||||||
|
|
||||||
|
cv::Mat im(in);
|
||||||
cv::Size_<int> image_size = im.size();
|
cv::Size_<int> image_size = im.size();
|
||||||
|
|
||||||
const boost::filesystem::path ip(input_file);
|
|
||||||
const boost::filesystem::path ipext(ip.extension());
|
|
||||||
|
|
||||||
const bool isJpeg = boost::iequals(ipext.string(), ".jpg") || boost::iequals(ipext.string(), ".jpeg");
|
|
||||||
|
|
||||||
const bool isReconstructNoise = mode == "noise" || mode == "noise_scale" || (mode == "auto_scale" && isJpeg);
|
const bool isReconstructNoise = mode == "noise" || mode == "noise_scale" || (mode == "auto_scale" && isJpeg);
|
||||||
const bool isReconstructScale = mode == "scale" || mode == "noise_scale";
|
const bool isReconstructScale = mode == "scale" || mode == "noise_scale";
|
||||||
|
|
||||||
@ -1084,7 +1081,6 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std
|
|||||||
return eWaifu2xError_Cancel;
|
return eWaifu2xError_Cancel;
|
||||||
|
|
||||||
const int scale2 = ceil(log2(scale_ratio));
|
const int scale2 = ceil(log2(scale_ratio));
|
||||||
const double shrinkRatio = scale_ratio / std::pow(2.0, (double)scale2);
|
|
||||||
|
|
||||||
if (isReconstructScale)
|
if (isReconstructScale)
|
||||||
{
|
{
|
||||||
@ -1105,18 +1101,24 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std
|
|||||||
if (cancel_func && cancel_func())
|
if (cancel_func && cancel_func())
|
||||||
return eWaifu2xError_Cancel;
|
return eWaifu2xError_Cancel;
|
||||||
|
|
||||||
|
out = im;
|
||||||
|
|
||||||
|
return eWaifu2xError_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
Waifu2x::eWaifu2xError Waifu2x::AfterReconstructFloatMatProcess(const cv::Mat &floatim, const cv::Mat &in, cv::Mat &out)
|
||||||
|
{
|
||||||
|
cv::Size_<int> image_size = in.size();
|
||||||
|
|
||||||
cv::Mat process_image;
|
cv::Mat process_image;
|
||||||
if (input_plane == 1)
|
if (input_plane == 1)
|
||||||
{
|
{
|
||||||
// 再構築した輝度画像とCreateZoomColorImage()で作成した色情報をマージして通常の画像に変換し、書き込む
|
// 再構築した輝度画像とCreateZoomColorImage()で作成した色情報をマージして通常の画像に変換し、書き込む
|
||||||
|
|
||||||
std::vector<cv::Mat> color_planes;
|
std::vector<cv::Mat> color_planes;
|
||||||
CreateZoomColorImage(float_image, image_size, color_planes);
|
CreateZoomColorImage(floatim, image_size, color_planes);
|
||||||
|
|
||||||
float_image.release();
|
color_planes[0] = in;
|
||||||
|
|
||||||
color_planes[0] = im;
|
|
||||||
im.release();
|
|
||||||
|
|
||||||
cv::Mat converted_image;
|
cv::Mat converted_image;
|
||||||
cv::merge(color_planes, converted_image);
|
cv::merge(color_planes, converted_image);
|
||||||
@ -1128,7 +1130,7 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std
|
|||||||
else
|
else
|
||||||
{
|
{
|
||||||
std::vector<cv::Mat> planes;
|
std::vector<cv::Mat> planes;
|
||||||
cv::split(im, planes);
|
cv::split(in, planes);
|
||||||
|
|
||||||
// RGBからBGRに直す
|
// RGBからBGRに直す
|
||||||
std::swap(planes[0], planes[2]);
|
std::swap(planes[0], planes[2]);
|
||||||
@ -1137,10 +1139,10 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std
|
|||||||
}
|
}
|
||||||
|
|
||||||
cv::Mat alpha;
|
cv::Mat alpha;
|
||||||
if (float_image.channels() == 4)
|
if (floatim.channels() == 4)
|
||||||
{
|
{
|
||||||
std::vector<cv::Mat> planes;
|
std::vector<cv::Mat> planes;
|
||||||
cv::split(float_image, planes);
|
cv::split(floatim, planes);
|
||||||
alpha = planes[3];
|
alpha = planes[3];
|
||||||
|
|
||||||
cv::resize(alpha, alpha, image_size, 0.0, 0.0, cv::INTER_CUBIC);
|
cv::resize(alpha, alpha, image_size, 0.0, 0.0, cv::INTER_CUBIC);
|
||||||
@ -1164,10 +1166,117 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std
|
|||||||
cv::merge(planes, process_image);
|
cv::merge(planes, process_image);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const int scale2 = ceil(log2(scale_ratio));
|
||||||
|
const double shrinkRatio = scale_ratio / std::pow(2.0, (double)scale2);
|
||||||
|
|
||||||
const cv::Size_<int> ns(image_size.width * shrinkRatio, image_size.height * shrinkRatio);
|
const cv::Size_<int> ns(image_size.width * shrinkRatio, image_size.height * shrinkRatio);
|
||||||
if (image_size.width != ns.width || image_size.height != ns.height)
|
if (image_size.width != ns.width || image_size.height != ns.height)
|
||||||
cv::resize(process_image, process_image, ns, 0.0, 0.0, cv::INTER_LINEAR);
|
cv::resize(process_image, process_image, ns, 0.0, 0.0, cv::INTER_LINEAR);
|
||||||
|
|
||||||
|
out = process_image;
|
||||||
|
|
||||||
|
return eWaifu2xError_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std::string &output_file,
|
||||||
|
const waifu2xCancelFunc cancel_func)
|
||||||
|
{
|
||||||
|
Waifu2x::eWaifu2xError ret;
|
||||||
|
|
||||||
|
if (!is_inited)
|
||||||
|
return eWaifu2xError_NotInitialized;
|
||||||
|
|
||||||
|
const boost::filesystem::path ip(input_file);
|
||||||
|
const boost::filesystem::path ipext(ip.extension());
|
||||||
|
|
||||||
|
const bool isJpeg = boost::iequals(ipext.string(), ".jpg") || boost::iequals(ipext.string(), ".jpeg");
|
||||||
|
|
||||||
|
cv::Mat float_image;
|
||||||
|
ret = LoadMat(float_image, input_file);
|
||||||
|
if (ret != eWaifu2xError_OK)
|
||||||
|
return ret;
|
||||||
|
|
||||||
|
cv::Mat brfm;
|
||||||
|
ret = BeforeReconstructFloatMatProcess(float_image, brfm);
|
||||||
|
if (ret != eWaifu2xError_OK)
|
||||||
|
return ret;
|
||||||
|
|
||||||
|
cv::Mat reconstruct_image;
|
||||||
|
if (!use_tta) // 普通に処理
|
||||||
|
{
|
||||||
|
ret = ReconstructFloatMat(isJpeg, cancel_func, brfm, reconstruct_image);
|
||||||
|
if (ret != eWaifu2xError_OK)
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
else // Test-Time Augmentation Mode
|
||||||
|
{
|
||||||
|
const auto RotateClockwise90 = [](cv::Mat &mat)
|
||||||
|
{
|
||||||
|
cv::transpose(mat, mat);
|
||||||
|
cv::flip(mat, mat, 1);
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto RotateClockwise90N = [RotateClockwise90](cv::Mat &mat, const int rotateNum)
|
||||||
|
{
|
||||||
|
for (int i = 0; i < rotateNum; i++)
|
||||||
|
RotateClockwise90(mat);
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto RotateCounterclockwise90 = [](cv::Mat &mat)
|
||||||
|
{
|
||||||
|
cv::transpose(mat, mat);
|
||||||
|
cv::flip(mat, mat, 0);
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto RotateCounterclockwise90N = [RotateCounterclockwise90](cv::Mat &mat, const int rotateNum)
|
||||||
|
{
|
||||||
|
for (int i = 0; i < rotateNum; i++)
|
||||||
|
RotateCounterclockwise90(mat);
|
||||||
|
};
|
||||||
|
|
||||||
|
cv::Mat ri[8];
|
||||||
|
for (int i = 0; i < 8; i++)
|
||||||
|
{
|
||||||
|
cv::Mat in(brfm.clone());
|
||||||
|
|
||||||
|
cv::imwrite("0.png", in * 255.0);
|
||||||
|
|
||||||
|
const int rotateNum = i % 4;
|
||||||
|
RotateClockwise90N(in, rotateNum);
|
||||||
|
cv::imwrite("1.png", in * 255.0);
|
||||||
|
|
||||||
|
if(i >= 4)
|
||||||
|
cv::flip(in, in, 1); // 垂直軸反転
|
||||||
|
cv::imwrite("2.png", in * 255.0);
|
||||||
|
|
||||||
|
ret = ReconstructFloatMat(isJpeg, cancel_func, in, in);
|
||||||
|
if (ret != eWaifu2xError_OK)
|
||||||
|
return ret;
|
||||||
|
cv::imwrite("3.png", in * 255.0);
|
||||||
|
if (i >= 4)
|
||||||
|
cv::flip(in, in, 1); // 垂直軸反転
|
||||||
|
cv::imwrite("4.png", in * 255.0);
|
||||||
|
RotateCounterclockwise90N(in, rotateNum);
|
||||||
|
cv::imwrite("5.png", in * 255.0);
|
||||||
|
ri[i] = in;
|
||||||
|
}
|
||||||
|
|
||||||
|
reconstruct_image = ri[0];
|
||||||
|
for (int i = 1; i < 8; i++)
|
||||||
|
reconstruct_image += ri[i];
|
||||||
|
|
||||||
|
reconstruct_image /= 8.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
brfm.release();
|
||||||
|
|
||||||
|
cv::Mat process_image;
|
||||||
|
ret = AfterReconstructFloatMatProcess(float_image, reconstruct_image, process_image);
|
||||||
|
if (ret != eWaifu2xError_OK)
|
||||||
|
return ret;
|
||||||
|
|
||||||
|
float_image.release();
|
||||||
|
|
||||||
cv::Mat write_iamge;
|
cv::Mat write_iamge;
|
||||||
process_image.convertTo(write_iamge, CV_8U, 255.0);
|
process_image.convertTo(write_iamge, CV_8U, 255.0);
|
||||||
process_image.release();
|
process_image.release();
|
||||||
|
@ -90,6 +90,8 @@ private:
|
|||||||
float *dummy_data;
|
float *dummy_data;
|
||||||
float *output_block;
|
float *output_block;
|
||||||
|
|
||||||
|
bool use_tta;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static eWaifu2xError LoadMat(cv::Mat &float_image, const std::string &input_file);
|
static eWaifu2xError LoadMat(cv::Mat &float_image, const std::string &input_file);
|
||||||
static eWaifu2xError LoadMatBySTBI(cv::Mat &float_image, const std::string &input_file);
|
static eWaifu2xError LoadMatBySTBI(cv::Mat &float_image, const std::string &input_file);
|
||||||
@ -103,6 +105,10 @@ private:
|
|||||||
eWaifu2xError ReconstructImage(boost::shared_ptr<caffe::Net<float>> net, cv::Mat &im);
|
eWaifu2xError ReconstructImage(boost::shared_ptr<caffe::Net<float>> net, cv::Mat &im);
|
||||||
eWaifu2xError WriteMat(const cv::Mat &im, const std::string &output_file);
|
eWaifu2xError WriteMat(const cv::Mat &im, const std::string &output_file);
|
||||||
|
|
||||||
|
eWaifu2xError BeforeReconstructFloatMatProcess(const cv::Mat &in, cv::Mat &out);
|
||||||
|
eWaifu2xError ReconstructFloatMat(const bool isJpeg, const waifu2xCancelFunc cancel_func, const cv::Mat &in, cv::Mat &out);
|
||||||
|
eWaifu2xError AfterReconstructFloatMatProcess(const cv::Mat &floatim, const cv::Mat &in, cv::Mat &out);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Waifu2x();
|
Waifu2x();
|
||||||
~Waifu2x();
|
~Waifu2x();
|
||||||
@ -113,7 +119,7 @@ public:
|
|||||||
// mode: noise or scale or noise_scale or auto_scale
|
// mode: noise or scale or noise_scale or auto_scale
|
||||||
// process: cpu or gpu or cudnn
|
// process: cpu or gpu or cudnn
|
||||||
eWaifu2xError init(int argc, char** argv, const std::string &mode, const int noise_level, const double scale_ratio, const std::string &model_dir, const std::string &process,
|
eWaifu2xError init(int argc, char** argv, const std::string &mode, const int noise_level, const double scale_ratio, const std::string &model_dir, const std::string &process,
|
||||||
const int crop_size = 128, const int batch_size = 1);
|
const bool use_tta = false, const int crop_size = 128, const int batch_size = 1);
|
||||||
|
|
||||||
void destroy();
|
void destroy();
|
||||||
|
|
||||||
|
@ -107,6 +107,13 @@ int main(int argc, char** argv)
|
|||||||
"input batch size", false,
|
"input batch size", false,
|
||||||
1, "int", cmd);
|
1, "int", cmd);
|
||||||
|
|
||||||
|
std::vector<int> cmdTTAConstraintV;
|
||||||
|
cmdTTAConstraintV.push_back(0);
|
||||||
|
cmdTTAConstraintV.push_back(1);
|
||||||
|
TCLAP::ValuesConstraint<int> cmdTTAConstraint(cmdTTAConstraintV);
|
||||||
|
TCLAP::ValueArg<int> cmdTTALevel("t", "tta", "8x slower and slightly high quality",
|
||||||
|
false, 0, &cmdTTAConstraint, cmd);
|
||||||
|
|
||||||
// definition of command line argument : end
|
// definition of command line argument : end
|
||||||
|
|
||||||
TCLAP::Arg::enableIgnoreMismatched();
|
TCLAP::Arg::enableIgnoreMismatched();
|
||||||
@ -237,7 +244,7 @@ int main(int argc, char** argv)
|
|||||||
|
|
||||||
Waifu2x::eWaifu2xError ret;
|
Waifu2x::eWaifu2xError ret;
|
||||||
Waifu2x w;
|
Waifu2x w;
|
||||||
ret = w.init(argc, argv, cmdMode.getValue(), cmdNRLevel.getValue(), cmdScaleRatio.getValue(), cmdModelPath.getValue(), cmdProcess.getValue(),
|
ret = w.init(argc, argv, cmdMode.getValue(), cmdNRLevel.getValue(), cmdScaleRatio.getValue(), cmdModelPath.getValue(), cmdProcess.getValue(), cmdTTALevel.getValue() == 1,
|
||||||
cmdCropSizeFile.getValue(), cmdBatchSizeFile.getValue());
|
cmdCropSizeFile.getValue(), cmdBatchSizeFile.getValue());
|
||||||
switch (ret)
|
switch (ret)
|
||||||
{
|
{
|
||||||
|
Loading…
x
Reference in New Issue
Block a user