noise_scaleに対応

This commit is contained in:
lltcggie 2016-07-03 17:13:02 +09:00
parent ebc71b317c
commit 1e834f2c7b
13 changed files with 356 additions and 710 deletions

File diff suppressed because one or more lines are too long

View File

@ -1,187 +0,0 @@
# Caffe network definition "srcnn".
# Structure: seven 3x3, stride-1 convolutions with 32, 32, 64, 64, 128, 128
# and 3 output channels. The first six are each followed by a ReLU with
# negative_slope 0.1 that writes back into the same blob (bottom == top).
# A MemoryData "target" layer and a EuclideanLoss layer are only included in
# the TRAIN phase.
name: "srcnn"
# Network input blob, declared with shape 1 x 3 x 142 x 142.
layer {
name: "input"
type: "Input"
top: "input"
input_param { shape: { dim: 1 dim: 3 dim: 142 dim: 142 } }
}
# conv1: 32 output channels. No pad is specified for any convolution in this
# file; with Caffe's default pad of 0 each 3x3/stride-1 convolution trims
# 2 pixels from the width and the height.
layer {
name: "conv1_layer"
type: "Convolution"
bottom: "input"
top: "conv1"
convolution_param {
num_output: 32
kernel_size: 3
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
}
}
# Leaky ReLU (negative_slope 0.1) applied to conv1, output stored in conv1.
layer {
name: "conv1_relu_layer"
type: "ReLU"
bottom: "conv1"
top: "conv1"
relu_param {
negative_slope: 0.1
}
}
# conv2: 32 output channels.
layer {
name: "conv2_layer"
type: "Convolution"
bottom: "conv1"
top: "conv2"
convolution_param {
num_output: 32
kernel_size: 3
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
}
}
layer {
name: "conv2_relu_layer"
type: "ReLU"
bottom: "conv2"
top: "conv2"
relu_param {
negative_slope: 0.1
}
}
# conv3: 64 output channels.
layer {
name: "conv3_layer"
type: "Convolution"
bottom: "conv2"
top: "conv3"
convolution_param {
num_output: 64
kernel_size: 3
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
}
}
layer {
name: "conv3_relu_layer"
type: "ReLU"
bottom: "conv3"
top: "conv3"
relu_param {
negative_slope: 0.1
}
}
# conv4: 64 output channels.
layer {
name: "conv4_layer"
type: "Convolution"
bottom: "conv3"
top: "conv4"
convolution_param {
num_output: 64
kernel_size: 3
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
}
}
layer {
name: "conv4_relu_layer"
type: "ReLU"
bottom: "conv4"
top: "conv4"
relu_param {
negative_slope: 0.1
}
}
# conv5: 128 output channels.
layer {
name: "conv5_layer"
type: "Convolution"
bottom: "conv4"
top: "conv5"
convolution_param {
num_output: 128
kernel_size: 3
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
}
}
layer {
name: "conv5_relu_layer"
type: "ReLU"
bottom: "conv5"
top: "conv5"
relu_param {
negative_slope: 0.1
}
}
# conv6: 128 output channels.
layer {
name: "conv6_layer"
type: "Convolution"
bottom: "conv5"
top: "conv6"
convolution_param {
num_output: 128
kernel_size: 3
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
}
}
layer {
name: "conv6_relu_layer"
type: "ReLU"
bottom: "conv6"
top: "conv6"
relu_param {
negative_slope: 0.1
}
}
# conv7: final convolution, 3 output channels, no activation after it.
layer {
name: "conv7_layer"
type: "Convolution"
bottom: "conv6"
top: "conv7"
convolution_param {
num_output: 3
kernel_size: 3
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
}
}
# TRAIN-phase only: in-memory source for the "target" blob (plus a label top
# named "dummy_label2"), declared as batch 1, 1 channel, 142 x 142.
layer {
name: "target"
type: "MemoryData"
top: "target"
top: "dummy_label2"
memory_data_param {
batch_size: 1
channels: 1
height: 142
width: 142
}
include: { phase: TRAIN }
}
# TRAIN-phase only: Euclidean (L2) loss between conv7 and target.
# NOTE(review): with seven unpadded 3x3 convolutions conv7 would be
# 3 x 128 x 128, while target is declared 1 x 142 x 142; EuclideanLoss expects
# both bottoms to hold the same number of elements - confirm this TRAIN wiring
# is actually usable.
layer {
name: "loss"
type: "EuclideanLoss"
bottom: "conv7"
bottom: "target"
top: "loss"
include: { phase: TRAIN }
}

File diff suppressed because one or more lines are too long

View File

@ -1,187 +0,0 @@
# Caffe network definition "srcnn".
# Structure: seven 3x3, stride-1 convolutions with 32, 32, 64, 64, 128, 128
# and 3 output channels. The first six are each followed by a ReLU with
# negative_slope 0.1 that writes back into the same blob (bottom == top).
# A MemoryData "target" layer and a EuclideanLoss layer are only included in
# the TRAIN phase.
name: "srcnn"
# Network input blob, declared with shape 1 x 3 x 142 x 142.
layer {
name: "input"
type: "Input"
top: "input"
input_param { shape: { dim: 1 dim: 3 dim: 142 dim: 142 } }
}
# conv1: 32 output channels. No pad is specified for any convolution in this
# file; with Caffe's default pad of 0 each 3x3/stride-1 convolution trims
# 2 pixels from the width and the height.
layer {
name: "conv1_layer"
type: "Convolution"
bottom: "input"
top: "conv1"
convolution_param {
num_output: 32
kernel_size: 3
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
}
}
# Leaky ReLU (negative_slope 0.1) applied to conv1, output stored in conv1.
layer {
name: "conv1_relu_layer"
type: "ReLU"
bottom: "conv1"
top: "conv1"
relu_param {
negative_slope: 0.1
}
}
# conv2: 32 output channels.
layer {
name: "conv2_layer"
type: "Convolution"
bottom: "conv1"
top: "conv2"
convolution_param {
num_output: 32
kernel_size: 3
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
}
}
layer {
name: "conv2_relu_layer"
type: "ReLU"
bottom: "conv2"
top: "conv2"
relu_param {
negative_slope: 0.1
}
}
# conv3: 64 output channels.
layer {
name: "conv3_layer"
type: "Convolution"
bottom: "conv2"
top: "conv3"
convolution_param {
num_output: 64
kernel_size: 3
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
}
}
layer {
name: "conv3_relu_layer"
type: "ReLU"
bottom: "conv3"
top: "conv3"
relu_param {
negative_slope: 0.1
}
}
# conv4: 64 output channels.
layer {
name: "conv4_layer"
type: "Convolution"
bottom: "conv3"
top: "conv4"
convolution_param {
num_output: 64
kernel_size: 3
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
}
}
layer {
name: "conv4_relu_layer"
type: "ReLU"
bottom: "conv4"
top: "conv4"
relu_param {
negative_slope: 0.1
}
}
# conv5: 128 output channels.
layer {
name: "conv5_layer"
type: "Convolution"
bottom: "conv4"
top: "conv5"
convolution_param {
num_output: 128
kernel_size: 3
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
}
}
layer {
name: "conv5_relu_layer"
type: "ReLU"
bottom: "conv5"
top: "conv5"
relu_param {
negative_slope: 0.1
}
}
# conv6: 128 output channels.
layer {
name: "conv6_layer"
type: "Convolution"
bottom: "conv5"
top: "conv6"
convolution_param {
num_output: 128
kernel_size: 3
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
}
}
layer {
name: "conv6_relu_layer"
type: "ReLU"
bottom: "conv6"
top: "conv6"
relu_param {
negative_slope: 0.1
}
}
# conv7: final convolution, 3 output channels, no activation after it.
layer {
name: "conv7_layer"
type: "Convolution"
bottom: "conv6"
top: "conv7"
convolution_param {
num_output: 3
kernel_size: 3
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
}
}
# TRAIN-phase only: in-memory source for the "target" blob (plus a label top
# named "dummy_label2"), declared as batch 1, 1 channel, 142 x 142.
layer {
name: "target"
type: "MemoryData"
top: "target"
top: "dummy_label2"
memory_data_param {
batch_size: 1
channels: 1
height: 142
width: 142
}
include: { phase: TRAIN }
}
# TRAIN-phase only: Euclidean (L2) loss between conv7 and target.
# NOTE(review): with seven unpadded 3x3 convolutions conv7 would be
# 3 x 128 x 128, while target is declared 1 x 142 x 142; EuclideanLoss expects
# both bottoms to hold the same number of elements - confirm this TRAIN wiring
# is actually usable.
layer {
name: "loss"
type: "EuclideanLoss"
bottom: "conv7"
bottom: "target"
top: "loss"
include: { phase: TRAIN }
}

File diff suppressed because one or more lines are too long

View File

@ -1,187 +0,0 @@
# Caffe network definition "srcnn".
# Structure: seven 3x3, stride-1 convolutions with 32, 32, 64, 64, 128, 128
# and 3 output channels. The first six are each followed by a ReLU with
# negative_slope 0.1 that writes back into the same blob (bottom == top).
# A MemoryData "target" layer and a EuclideanLoss layer are only included in
# the TRAIN phase.
name: "srcnn"
# Network input blob, declared with shape 1 x 3 x 142 x 142.
layer {
name: "input"
type: "Input"
top: "input"
input_param { shape: { dim: 1 dim: 3 dim: 142 dim: 142 } }
}
# conv1: 32 output channels. No pad is specified for any convolution in this
# file; with Caffe's default pad of 0 each 3x3/stride-1 convolution trims
# 2 pixels from the width and the height.
layer {
name: "conv1_layer"
type: "Convolution"
bottom: "input"
top: "conv1"
convolution_param {
num_output: 32
kernel_size: 3
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
}
}
# Leaky ReLU (negative_slope 0.1) applied to conv1, output stored in conv1.
layer {
name: "conv1_relu_layer"
type: "ReLU"
bottom: "conv1"
top: "conv1"
relu_param {
negative_slope: 0.1
}
}
# conv2: 32 output channels.
layer {
name: "conv2_layer"
type: "Convolution"
bottom: "conv1"
top: "conv2"
convolution_param {
num_output: 32
kernel_size: 3
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
}
}
layer {
name: "conv2_relu_layer"
type: "ReLU"
bottom: "conv2"
top: "conv2"
relu_param {
negative_slope: 0.1
}
}
# conv3: 64 output channels.
layer {
name: "conv3_layer"
type: "Convolution"
bottom: "conv2"
top: "conv3"
convolution_param {
num_output: 64
kernel_size: 3
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
}
}
layer {
name: "conv3_relu_layer"
type: "ReLU"
bottom: "conv3"
top: "conv3"
relu_param {
negative_slope: 0.1
}
}
# conv4: 64 output channels.
layer {
name: "conv4_layer"
type: "Convolution"
bottom: "conv3"
top: "conv4"
convolution_param {
num_output: 64
kernel_size: 3
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
}
}
layer {
name: "conv4_relu_layer"
type: "ReLU"
bottom: "conv4"
top: "conv4"
relu_param {
negative_slope: 0.1
}
}
# conv5: 128 output channels.
layer {
name: "conv5_layer"
type: "Convolution"
bottom: "conv4"
top: "conv5"
convolution_param {
num_output: 128
kernel_size: 3
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
}
}
layer {
name: "conv5_relu_layer"
type: "ReLU"
bottom: "conv5"
top: "conv5"
relu_param {
negative_slope: 0.1
}
}
# conv6: 128 output channels.
layer {
name: "conv6_layer"
type: "Convolution"
bottom: "conv5"
top: "conv6"
convolution_param {
num_output: 128
kernel_size: 3
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
}
}
layer {
name: "conv6_relu_layer"
type: "ReLU"
bottom: "conv6"
top: "conv6"
relu_param {
negative_slope: 0.1
}
}
# conv7: final convolution, 3 output channels, no activation after it.
layer {
name: "conv7_layer"
type: "Convolution"
bottom: "conv6"
top: "conv7"
convolution_param {
num_output: 3
kernel_size: 3
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
}
}
# TRAIN-phase only: in-memory source for the "target" blob (plus a label top
# named "dummy_label2"), declared as batch 1, 1 channel, 142 x 142.
layer {
name: "target"
type: "MemoryData"
top: "target"
top: "dummy_label2"
memory_data_param {
batch_size: 1
channels: 1
height: 142
width: 142
}
include: { phase: TRAIN }
}
# TRAIN-phase only: Euclidean (L2) loss between conv7 and target.
# NOTE(review): with seven unpadded 3x3 convolutions conv7 would be
# 3 x 128 x 128, while target is declared 1 x 142 x 142; EuclideanLoss expects
# both bottoms to hold the same number of elements - confirm this TRAIN wiring
# is actually usable.
layer {
name: "loss"
type: "EuclideanLoss"
bottom: "conv7"
bottom: "target"
top: "loss"
include: { phase: TRAIN }
}

View File

@ -128,21 +128,146 @@ static Waifu2x::eWaifu2xError readProtoBinary(const boost::filesystem::path &pat
return Waifu2x::eWaifu2xError_OK; return Waifu2x::eWaifu2xError_OK;
} }
cNet::cNet() : mModelScale(0), mInnerScale(0), mNetOffset(0), mInputPlane(0) namespace
{
// Loads the JSON file at info_path and parses it into d.
//
// info_path: file to read (opened in binary mode).
// d:         receives the parsed document.
// jsonBuf:   receives the raw file bytes followed by a terminating '\0'.
//
// Returns eWaifu2xError_FailedOpenModelFile when the file cannot be opened,
// eWaifu2xError_FailedParseModelFile when its size cannot be determined, it
// cannot be read completely, it is not valid JSON, or any exception is
// thrown on the way, and eWaifu2xError_OK otherwise.
Waifu2x::eWaifu2xError ReadJson(const boost::filesystem::path &info_path, rapidjson::Document &d, std::vector<char> &jsonBuf)
{
	try
	{
		boost::iostreams::stream<boost::iostreams::file_descriptor_source> is;

		try
		{
			is.open(info_path, std::ios_base::in | std::ios_base::binary);
		}
		catch (...)
		{
			return Waifu2x::eWaifu2xError_FailedOpenModelFile;
		}

		if (!is)
			return Waifu2x::eWaifu2xError_FailedOpenModelFile;

		// tellg() reports -1 on failure. Check it while it is still signed:
		// converted to size_t it would wrap around, "size + 1" would become 0
		// and the terminator write below would land outside the buffer.
		const std::streamoff end_pos = is.seekg(0, std::ios::end).tellg();
		if (end_pos < 0)
			return Waifu2x::eWaifu2xError_FailedParseModelFile;

		is.seekg(0, std::ios::beg);

		const size_t size = static_cast<size_t>(end_pos);

		// One extra zero byte so the buffer can be handed to Parse() as a
		// C string.
		jsonBuf.assign(size + 1, '\0');

		// Read exactly the bytes the file contains (not size + 1) and make
		// sure all of them arrived.
		is.read(jsonBuf.data(), static_cast<std::streamsize>(size));
		if (static_cast<size_t>(is.gcount()) != size)
			return Waifu2x::eWaifu2xError_FailedParseModelFile;

		d.Parse(jsonBuf.data());

		// Parse() reports syntax errors through the document state, not by
		// throwing, so the result has to be checked explicitly.
		if (d.HasParseError())
			return Waifu2x::eWaifu2xError_FailedParseModelFile;
	}
	catch (...)
	{
		return Waifu2x::eWaifu2xError_FailedParseModelFile;
	}

	return Waifu2x::eWaifu2xError_OK;
}
};
cNet::cNet() : mModelScale(0), mInnerScale(0), mNetOffset(0), mInputPlane(0), mHasNoiseScaleModel(false)
{} {}
cNet::~cNet() cNet::~cNet()
{} {}
// Reads the model description JSON at info_path and stores it in info.
//
// Required members: "name" (string), "arch_name" (string), "channels" (int).
// Optional members: "has_noise_scale" (bool, treated as false when absent),
// "offset" / "scale_factor" (int, applied to the noise, scale and noise_scale
// parameter sets) and the per-set overrides "offset_noise",
// "scale_factor_noise", "offset_scale", "scale_factor_scale",
// "offset_noise_scale", "scale_factor_noise_scale" (int).
// Parameter fields whose keys are absent keep the value info had on entry.
//
// Returns the error from ReadJson when the file cannot be loaded,
// eWaifu2xError_FailedParseModelFile when the document is not a JSON object
// or a member is missing / has the wrong type, and eWaifu2xError_OK
// otherwise. info is only modified on success.
Waifu2x::eWaifu2xError cNet::GetInfo(const boost::filesystem::path & info_path, stInfo &info)
{
	rapidjson::Document d;
	std::vector<char> jsonBuf;

	try
	{
		const Waifu2x::eWaifu2xError ret = ReadJson(info_path, d, jsonBuf);
		if (ret != Waifu2x::eWaifu2xError_OK)
			return ret;

		// operator[] and Get*() on a missing or mistyped member go through
		// RAPIDJSON_ASSERT, which by default does not throw, so the
		// catch (...) below cannot be relied on to reject malformed files.
		// Validate everything explicitly instead.
		if (d.HasParseError() || !d.IsObject())
			return Waifu2x::eWaifu2xError_FailedParseModelFile;

		const auto name = d.FindMember("name");
		const auto arch_name = d.FindMember("arch_name");
		const auto channels = d.FindMember("channels");
		const auto has_noise_scale = d.FindMember("has_noise_scale");

		if (name == d.MemberEnd() || !name->value.IsString())
			return Waifu2x::eWaifu2xError_FailedParseModelFile;
		if (arch_name == d.MemberEnd() || !arch_name->value.IsString())
			return Waifu2x::eWaifu2xError_FailedParseModelFile;
		if (channels == d.MemberEnd() || !channels->value.IsInt())
			return Waifu2x::eWaifu2xError_FailedParseModelFile;
		if (has_noise_scale != d.MemberEnd() && !has_noise_scale->value.IsBool())
			return Waifu2x::eWaifu2xError_FailedParseModelFile;

		// Work on a copy so the caller's object is untouched on failure.
		stInfo parsed = info;

		parsed.name = name->value.GetString();
		parsed.arch_name = arch_name->value.GetString();
		parsed.has_noise_scale = has_noise_scale != d.MemberEnd() && has_noise_scale->value.GetBool();
		parsed.channels = channels->value.GetInt();

		// Copies the integer member `key`, when present, into every target.
		// Returns false only when the member exists but is not an integer.
		const auto applyInt = [&d](const char *key, std::initializer_list<int *> targets) -> bool
		{
			const auto it = d.FindMember(key);
			if (it == d.MemberEnd())
				return true;

			if (!it->value.IsInt())
				return false;

			const int value = it->value.GetInt();
			for (int *target : targets)
				*target = value;

			return true;
		};

		// The shared keys come first so that the per-set keys override them.
		const bool ok =
			applyInt("offset", { &parsed.noise.offset, &parsed.scale.offset, &parsed.noise_scale.offset })
			&& applyInt("scale_factor", { &parsed.noise.scale_factor, &parsed.scale.scale_factor, &parsed.noise_scale.scale_factor })
			&& applyInt("offset_noise", { &parsed.noise.offset })
			&& applyInt("scale_factor_noise", { &parsed.noise.scale_factor })
			&& applyInt("offset_scale", { &parsed.scale.offset })
			&& applyInt("scale_factor_scale", { &parsed.scale.scale_factor })
			&& applyInt("offset_noise_scale", { &parsed.noise_scale.offset })
			&& applyInt("scale_factor_noise_scale", { &parsed.noise_scale.scale_factor });

		if (!ok)
			return Waifu2x::eWaifu2xError_FailedParseModelFile;

		info = parsed;
	}
	catch (...)
	{
		return Waifu2x::eWaifu2xError_FailedParseModelFile;
	}

	return Waifu2x::eWaifu2xError_OK;
}
// モデルファイルからネットワークを構築 // モデルファイルからネットワークを構築
// processでcudnnが指定されなかった場合はcuDNNが呼び出されないように変更する // processでcudnnが指定されなかった場合はcuDNNが呼び出されないように変更する
Waifu2x::eWaifu2xError cNet::ConstractNet(const boost::filesystem::path &model_path, const boost::filesystem::path &param_path, const boost::filesystem::path &info_path, const std::string &process) Waifu2x::eWaifu2xError cNet::ConstractNet(const Waifu2x::eWaifu2xModelType mode, const boost::filesystem::path &model_path, const boost::filesystem::path &param_path, const stInfo &info, const std::string &process)
{ {
Waifu2x::eWaifu2xError ret; Waifu2x::eWaifu2xError ret;
ret = LoadInfoFromJson(info_path); mMode = mode;
if (ret != Waifu2x::eWaifu2xError_OK)
return ret; LoadParamFromInfo(mode, info);
boost::filesystem::path modelbin_path = model_path; boost::filesystem::path modelbin_path = model_path;
modelbin_path += ".protobin"; modelbin_path += ".protobin";
@ -184,75 +309,35 @@ Waifu2x::eWaifu2xError cNet::ConstractNet(const boost::filesystem::path &model_p
return Waifu2x::eWaifu2xError_OK; return Waifu2x::eWaifu2xError_OK;
} }
namespace void cNet::LoadParamFromInfo(const Waifu2x::eWaifu2xModelType mode, const stInfo &info)
{ {
Waifu2x::eWaifu2xError ReadJson(const boost::filesystem::path &info_path, rapidjson::Document &d, std::vector<char> &jsonBuf) mModelScale = 2; // TODO: 動的に設定するようにする
stInfo::stParam param;
switch (mode)
{ {
try case Waifu2x::eWaifu2xModelTypeNoise:
{ param = info.noise;
boost::iostreams::stream<boost::iostreams::file_descriptor_source> is; break;
try case Waifu2x::eWaifu2xModelTypeScale:
{ param = info.scale;
is.open(info_path, std::ios_base::in | std::ios_base::binary); break;
}
catch (...)
{
return Waifu2x::eWaifu2xError_FailedOpenModelFile;
}
if (!is) case Waifu2x::eWaifu2xModelTypeNoiseScale:
return Waifu2x::eWaifu2xError_FailedOpenModelFile; param = info.noise_scale;
break;
const size_t size = is.seekg(0, std::ios::end).tellg(); case Waifu2x::eWaifu2xModelTypeAutoScale:
is.seekg(0, std::ios::beg); param = info.noise_scale;
break;
jsonBuf.resize(size + 1);
is.read(jsonBuf.data(), jsonBuf.size());
jsonBuf[jsonBuf.size() - 1] = '\0';
d.Parse(jsonBuf.data());
}
catch (...)
{
return Waifu2x::eWaifu2xError_FailedParseModelFile;
}
return Waifu2x::eWaifu2xError_OK;
}
};
Waifu2x::eWaifu2xError cNet::LoadInfoFromJson(const boost::filesystem::path &info_path)
{
rapidjson::Document d;
std::vector<char> jsonBuf;
try
{
Waifu2x::eWaifu2xError ret;
ret = ReadJson(info_path, d, jsonBuf);
if (ret != Waifu2x::eWaifu2xError_OK)
return ret;
const bool resize = d.HasMember("resize") && d["resize"].GetBool() ? true : false;
const auto name = d["name"].GetString();
const int channels = d["channels"].GetInt();
const int net_offset = d["offset"].GetInt();
const int inner_scale = d["scale_factor"].GetInt();
mModelScale = 2; // TODO: 動的に設定するようにする
mInnerScale = inner_scale;
mNetOffset = net_offset;
mInputPlane = channels;
}
catch (...)
{
return Waifu2x::eWaifu2xError_FailedParseModelFile;
} }
return Waifu2x::eWaifu2xError_OK; mInnerScale = param.scale_factor;
mNetOffset = param.offset;
mInputPlane = info.channels;
mHasNoiseScaleModel = info.has_noise_scale;
} }
Waifu2x::eWaifu2xError cNet::SetParameter(caffe::NetParameter &param, const std::string &process) const Waifu2x::eWaifu2xError cNet::SetParameter(caffe::NetParameter &param, const std::string &process) const
@ -732,26 +817,12 @@ Waifu2x::eWaifu2xError cNet::ReconstructImage(const bool UseTTA, const int crop_
std::string cNet::GetModelName(const boost::filesystem::path &info_path) std::string cNet::GetModelName(const boost::filesystem::path &info_path)
{ {
rapidjson::Document d; Waifu2x::eWaifu2xError ret;
std::vector<char> jsonBuf;
std::string str;
try stInfo info;
{ ret = GetInfo(info_path, info);
Waifu2x::eWaifu2xError ret; if (ret != Waifu2x::eWaifu2xError_OK)
return std::string();
ret = ReadJson(info_path, d, jsonBuf); return info.name;
if (ret != Waifu2x::eWaifu2xError_OK)
return str;
const auto name = d["name"].GetString();
str = name;
}
catch (...)
{
}
return str;
} }

View File

@ -4,27 +4,50 @@
#include "waifu2x.h" #include "waifu2x.h"
// Description of a model set as read from its info JSON (see cNet::GetInfo).
//
// Every scalar member has an initialiser so that a default-constructed stInfo
// holds defined values even when the JSON omits the optional keys.
struct stInfo
{
	// Geometry of one network variant.
	struct stParam
	{
		int scale_factor = 1; // copied to cNet::mInnerScale; 1 by default (no scaling inside the net)
		int offset = 0;       // copied to cNet::mNetOffset
	};

	std::string name;             // "name" member of the JSON
	std::string arch_name;        // "arch_name" member of the JSON
	bool has_noise_scale = false; // "has_noise_scale" member; false when absent
	int channels = 0;             // "channels" member; copied to cNet::mInputPlane

	stParam noise;       // used for eWaifu2xModelTypeNoise
	stParam scale;       // used for eWaifu2xModelTypeScale
	stParam noise_scale; // used for eWaifu2xModelTypeNoiseScale and eWaifu2xModelTypeAutoScale
};
class cNet class cNet
{ {
private: private:
Waifu2x::eWaifu2xModelType mMode;
boost::shared_ptr<caffe::Net<float>> mNet; boost::shared_ptr<caffe::Net<float>> mNet;
int mModelScale; // モデルが対象とする拡大率 int mModelScale; // モデルが対象とする拡大率
int mInnerScale; // ネット内部で拡大される倍率 int mInnerScale; // ネット内部で拡大される倍率
int mNetOffset; // ネットに入力するとどれくらい削れるか int mNetOffset; // ネットに入力するとどれくらい削れるか
int mInputPlane; // ネットへの入力チャンネル数 int mInputPlane; // ネットへの入力チャンネル数
bool mHasNoiseScaleModel;
private: private:
void LoadParamFromInfo(const Waifu2x::eWaifu2xModelType mode, const stInfo &info);
Waifu2x::eWaifu2xError LoadParameterFromJson(const boost::filesystem::path &model_path, const boost::filesystem::path &param_path Waifu2x::eWaifu2xError LoadParameterFromJson(const boost::filesystem::path &model_path, const boost::filesystem::path &param_path
, const boost::filesystem::path &modelbin_path, const boost::filesystem::path &caffemodel_path, const std::string &process); , const boost::filesystem::path &modelbin_path, const boost::filesystem::path &caffemodel_path, const std::string &process);
Waifu2x::eWaifu2xError LoadInfoFromJson(const boost::filesystem::path &info_path);
Waifu2x::eWaifu2xError SetParameter(caffe::NetParameter &param, const std::string &process) const; Waifu2x::eWaifu2xError SetParameter(caffe::NetParameter &param, const std::string &process) const;
public: public:
cNet(); cNet();
~cNet(); ~cNet();
Waifu2x::eWaifu2xError ConstractNet(const boost::filesystem::path &model_path, const boost::filesystem::path &param_path, const boost::filesystem::path &info_path, const std::string &process); static Waifu2x::eWaifu2xError GetInfo(const boost::filesystem::path &info_path, stInfo &info);
Waifu2x::eWaifu2xError ConstractNet(const Waifu2x::eWaifu2xModelType mode, const boost::filesystem::path &model_path, const boost::filesystem::path &param_path, const stInfo &info, const std::string &process);
int GetInputPlane() const; int GetInputPlane() const;
int GetInnerScale() const; int GetInnerScale() const;

View File

@ -633,14 +633,14 @@ void stImage::ShrinkImage(const double scale)
const double shrinkRatio = scale >= 1.0 ? scale / std::pow(scaleBase, scaleNum) : scale; const double shrinkRatio = scale >= 1.0 ? scale / std::pow(scaleBase, scaleNum) : scale;
const cv::Size_<int> ns(mOrgSize.width * scale, mOrgSize.height * scale); const cv::Size_<int> ns(mOrgSize.width * scale, mOrgSize.height * scale);
//if (mEndImage.size().width != ns.width || mEndImage.size().height != ns.height) if (mEndImage.size().width != ns.width || mEndImage.size().height != ns.height)
//{ {
// int argo = cv::INTER_CUBIC; int argo = cv::INTER_CUBIC;
// if (scale < 0.5) if (scale < 0.5)
// argo = cv::INTER_AREA; argo = cv::INTER_AREA;
// cv::resize(mEndImage, mEndImage, ns, 0.0, 0.0, argo); cv::resize(mEndImage, mEndImage, ns, 0.0, 0.0, argo);
//} }
} }
cv::Mat stImage::DeconvertFromFloat(const cv::Mat &im, const int depth) cv::Mat stImage::DeconvertFromFloat(const cv::Mat &im, const int depth)

View File

@ -257,7 +257,7 @@ Waifu2x::~Waifu2x()
Destroy(); Destroy();
} }
Waifu2x::eWaifu2xError Waifu2x::Init(const std::string &mode, const int noise_level, Waifu2x::eWaifu2xError Waifu2x::Init(const eWaifu2xModelType mode, const int noise_level,
const boost::filesystem::path &model_dir, const std::string &process) const boost::filesystem::path &model_dir, const std::string &process)
{ {
Waifu2x::eWaifu2xError ret; Waifu2x::eWaifu2xError ret;
@ -303,28 +303,47 @@ Waifu2x::eWaifu2xError Waifu2x::Init(const std::string &mode, const int noise_le
mInputPlane = 0; mInputPlane = 0;
mMaxNetOffset = 0; mMaxNetOffset = 0;
// TODO: ノイズ除去と拡大を同時に行うネットワークへの対処を考える
const boost::filesystem::path info_path = GetInfoPath(mode_dir_path); const boost::filesystem::path info_path = GetInfoPath(mode_dir_path);
if (mode == "noise" || mode == "noise_scale" || mode == "auto_scale") stInfo info;
ret = cNet::GetInfo(info_path, info);
if (ret != Waifu2x::eWaifu2xError_OK)
return ret;
mHasNoiseScale = info.has_noise_scale;
mInputPlane = info.channels;
if (mode == eWaifu2xModelTypeNoise || mode == eWaifu2xModelTypeNoiseScale || mode == eWaifu2xModelTypeAutoScale)
{ {
const std::string base_name = "noise" + std::to_string(noise_level) + "_model"; std::string base_name;
mNoiseNet.reset(new cNet);
eWaifu2xModelType Mode = mode;
if (info.has_noise_scale) // ノイズ除去と拡大を同時に行う
{
// Building the combined denoise + upscale network requires eWaifu2xModelTypeNoiseScale to be specified
Mode = eWaifu2xModelTypeNoiseScale;
base_name = "noise" + std::to_string(noise_level) + "_scale2.0x_model";
}
else // ノイズ除去だけ
{
Mode = eWaifu2xModelTypeNoise;
base_name = "noise" + std::to_string(noise_level) + "_model";
}
const boost::filesystem::path model_path = mode_dir_path / (base_name + ".prototxt"); const boost::filesystem::path model_path = mode_dir_path / (base_name + ".prototxt");
const boost::filesystem::path param_path = mode_dir_path / (base_name + ".json"); const boost::filesystem::path param_path = mode_dir_path / (base_name + ".json");
mNoiseNet.reset(new cNet); ret = mNoiseNet->ConstractNet(Mode, model_path, param_path, info, mProcess);
ret = mNoiseNet->ConstractNet(model_path, param_path, info_path, mProcess);
if (ret != Waifu2x::eWaifu2xError_OK) if (ret != Waifu2x::eWaifu2xError_OK)
return ret; return ret;
mInputPlane = mNoiseNet->GetInputPlane();
mMaxNetOffset = mNoiseNet->GetNetOffset(); mMaxNetOffset = mNoiseNet->GetNetOffset();
} }
if (mode == "scale" || mode == "noise_scale" || mode == "auto_scale") // noise_scaleを持っている場合はαチャンネルの拡大のためにmScaleNetも構築する必要がある
if (info.has_noise_scale || mode == eWaifu2xModelTypeScale || mode == eWaifu2xModelTypeNoiseScale || mode == eWaifu2xModelTypeAutoScale)
{ {
const std::string base_name = "scale2.0x_model"; const std::string base_name = "scale2.0x_model";
@ -333,13 +352,12 @@ Waifu2x::eWaifu2xError Waifu2x::Init(const std::string &mode, const int noise_le
mScaleNet.reset(new cNet); mScaleNet.reset(new cNet);
ret = mScaleNet->ConstractNet(model_path, param_path, info_path, mProcess); ret = mScaleNet->ConstractNet(eWaifu2xModelTypeScale, model_path, param_path, info, mProcess);
if (ret != Waifu2x::eWaifu2xError_OK) if (ret != Waifu2x::eWaifu2xError_OK)
return ret; return ret;
assert(mInputPlane == 0 || mInputPlane == mScaleNet->GetInputPlane()); assert(mInputPlane == 0 || mInputPlane == mScaleNet->GetInputPlane());
mInputPlane = mScaleNet->GetInputPlane();
mMaxNetOffset = std::max(mScaleNet->GetNetOffset(), mMaxNetOffset); mMaxNetOffset = std::max(mScaleNet->GetNetOffset(), mMaxNetOffset);
} }
else else
@ -399,15 +417,19 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const boost::filesystem::path &input_fil
image.Preprocess(mInputPlane, mMaxNetOffset); image.Preprocess(mInputPlane, mMaxNetOffset);
const bool isReconstructNoise = mMode == "noise" || mMode == "noise_scale" || (mMode == "auto_scale" && image.RequestDenoise()); const bool isReconstructNoise = mMode == eWaifu2xModelTypeNoise || mMode == eWaifu2xModelTypeNoiseScale || (mMode == eWaifu2xModelTypeAutoScale && image.RequestDenoise());
const bool isReconstructScale = mMode == "scale" || mMode == "noise_scale" || mMode == "auto_scale"; const bool isReconstructScale = mMode == eWaifu2xModelTypeScale || mMode == eWaifu2xModelTypeNoiseScale || mMode == eWaifu2xModelTypeAutoScale;
double Factor = factor;
if (!isReconstructScale)
Factor = 1.0;
cv::Mat reconstruct_image; cv::Mat reconstruct_image;
ret = ReconstructImage(factor, crop_w, crop_h, use_tta, batch_size, isReconstructNoise, isReconstructScale, cancel_func, image); ret = ReconstructImage(Factor, crop_w, crop_h, use_tta, batch_size, isReconstructNoise, isReconstructScale, cancel_func, image);
if (ret != Waifu2x::eWaifu2xError_OK) if (ret != Waifu2x::eWaifu2xError_OK)
return ret; return ret;
image.Postprocess(mInputPlane, factor, output_depth); image.Postprocess(mInputPlane, Factor, output_depth);
ret = image.Save(output_file, output_quality); ret = image.Save(output_file, output_quality);
if (ret != Waifu2x::eWaifu2xError_OK) if (ret != Waifu2x::eWaifu2xError_OK)
@ -424,6 +446,7 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const double factor, const void* source,
if (!mIsInited) if (!mIsInited)
return Waifu2x::eWaifu2xError_NotInitialized; return Waifu2x::eWaifu2xError_NotInitialized;
stImage image; stImage image;
ret = image.Load(source, width, height, in_channel, in_stride); ret = image.Load(source, width, height, in_channel, in_stride);
if (ret != Waifu2x::eWaifu2xError_OK) if (ret != Waifu2x::eWaifu2xError_OK)
@ -431,15 +454,19 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const double factor, const void* source,
image.Preprocess(mInputPlane, mMaxNetOffset); image.Preprocess(mInputPlane, mMaxNetOffset);
const bool isReconstructNoise = mMode == "noise" || mMode == "noise_scale"; const bool isReconstructNoise = mMode == eWaifu2xModelTypeNoise || mMode == eWaifu2xModelTypeNoiseScale;
const bool isReconstructScale = mMode == "scale" || mMode == "noise_scale" || mMode == "auto_scale"; const bool isReconstructScale = mMode == eWaifu2xModelTypeScale || mMode == eWaifu2xModelTypeNoiseScale || mMode == eWaifu2xModelTypeAutoScale;
double Factor = factor;
if (!isReconstructScale)
Factor = 1.0;
cv::Mat reconstruct_image; cv::Mat reconstruct_image;
ret = ReconstructImage(factor, crop_w, crop_h, use_tta, batch_size, isReconstructNoise, isReconstructScale, nullptr, image); ret = ReconstructImage(Factor, crop_w, crop_h, use_tta, batch_size, isReconstructNoise, isReconstructScale, nullptr, image);
if (ret != Waifu2x::eWaifu2xError_OK) if (ret != Waifu2x::eWaifu2xError_OK)
return ret; return ret;
image.Postprocess(mInputPlane, factor, 8); image.Postprocess(mInputPlane, Factor, 8);
cv::Mat out_image = image.GetEndImage(); cv::Mat out_image = image.GetEndImage();
image.Clear(); image.Clear();
@ -460,29 +487,39 @@ Waifu2x::eWaifu2xError Waifu2x::ReconstructImage(const double factor, const int
{ {
Waifu2x::eWaifu2xError ret; Waifu2x::eWaifu2xError ret;
// TODO: ノイズ除去と拡大を同時に行うネットワークへの対処を考える double Factor = factor;
if (isReconstructNoise) if (isReconstructNoise)
{ {
cv::Mat im; if (!mHasNoiseScale) // ノイズ除去だけ
cv::Size_<int> size; {
image.GetScalePaddingedRGB(im, size, mNoiseNet->GetNetOffset(), OuterPadding, crop_w, crop_h, 1); cv::Mat im;
cv::Size_<int> size;
image.GetScalePaddingedRGB(im, size, mNoiseNet->GetNetOffset(), OuterPadding, crop_w, crop_h, 1);
ret = ProcessNet(mNoiseNet, crop_w, crop_h, use_tta, batch_size, im); ret = ProcessNet(mNoiseNet, crop_w, crop_h, use_tta, batch_size, im);
if (ret != Waifu2x::eWaifu2xError_OK) if (ret != Waifu2x::eWaifu2xError_OK)
return ret; return ret;
image.SetReconstructedRGB(im, size, 1); image.SetReconstructedRGB(im, size, 1);
}
else // ノイズ除去と拡大
{
ret = ReconstructNoiseScale(crop_w, crop_h, use_tta, batch_size, cancel_func, image);
if (ret != Waifu2x::eWaifu2xError_OK)
return ret;
Factor /= mNoiseNet->GetInnerScale();
}
} }
if (cancel_func && cancel_func()) if (cancel_func && cancel_func())
return Waifu2x::eWaifu2xError_Cancel; return Waifu2x::eWaifu2xError_Cancel;
const int scaleNum = ceil(log(factor) / log(ScaleBase)); const int scaleNum = ceil(log(Factor) / log(ScaleBase));
if (isReconstructScale) if (isReconstructScale)
{ {
bool isError = false;
for (int i = 0; i < scaleNum; i++) for (int i = 0; i < scaleNum; i++)
{ {
ret = ReconstructScale(crop_w, crop_h, use_tta, batch_size, cancel_func, image); ret = ReconstructScale(crop_w, crop_h, use_tta, batch_size, cancel_func, image);
@ -525,6 +562,40 @@ Waifu2x::eWaifu2xError Waifu2x::ReconstructScale(const int crop_w, const int cro
return Waifu2x::eWaifu2xError_OK; return Waifu2x::eWaifu2xError_OK;
} }
// Runs the combined denoise + upscale pass over the image.
//
// When the image has an alpha channel, that channel is first sent through
// mScaleNet (alpha is only upscaled, never denoised). The RGB data is then
// sent through mNoiseNet. The first error returned by ReconstructByNet is
// passed straight back to the caller; otherwise eWaifu2xError_OK is returned.
Waifu2x::eWaifu2xError Waifu2x::ReconstructNoiseScale(const int crop_w, const int crop_h, const bool use_tta, const int batch_size,
	const Waifu2x::waifu2xCancelFunc cancel_func, stImage &image)
{
	if (image.HasAlpha())
	{
		// No denoising for the alpha channel: it goes through the scale net.
		cv::Mat alphaPlane;
		cv::Size_<int> alphaSize;
		const auto alphaPreScale = mScaleNet->GetScale() / mScaleNet->GetInnerScale();

		image.GetScalePaddingedA(alphaPlane, alphaSize, mScaleNet->GetNetOffset(), OuterPadding, crop_w, crop_h, alphaPreScale);

		const Waifu2x::eWaifu2xError alphaResult = ReconstructByNet(mScaleNet, crop_w, crop_h, use_tta, batch_size, cancel_func, alphaPlane);
		if (alphaResult != Waifu2x::eWaifu2xError_OK)
			return alphaResult;

		image.SetReconstructedA(alphaPlane, alphaSize, mScaleNet->GetInnerScale());
	}

	cv::Mat rgbPlanes;
	cv::Size_<int> rgbSize;
	const auto rgbPreScale = mNoiseNet->GetScale() / mNoiseNet->GetInnerScale();

	image.GetScalePaddingedRGB(rgbPlanes, rgbSize, mNoiseNet->GetNetOffset(), OuterPadding, crop_w, crop_h, rgbPreScale);

	const Waifu2x::eWaifu2xError rgbResult = ReconstructByNet(mNoiseNet, crop_w, crop_h, use_tta, batch_size, cancel_func, rgbPlanes);
	if (rgbResult != Waifu2x::eWaifu2xError_OK)
		return rgbResult;

	image.SetReconstructedRGB(rgbPlanes, rgbSize, mNoiseNet->GetInnerScale());

	return Waifu2x::eWaifu2xError_OK;
}
Waifu2x::eWaifu2xError Waifu2x::ReconstructByNet(std::shared_ptr<cNet> net, const int crop_w, const int crop_h, const bool use_tta, const int batch_size, Waifu2x::eWaifu2xError Waifu2x::ReconstructByNet(std::shared_ptr<cNet> net, const int crop_w, const int crop_h, const bool use_tta, const int batch_size,
const Waifu2x::waifu2xCancelFunc cancel_func, cv::Mat &im) const Waifu2x::waifu2xCancelFunc cancel_func, cv::Mat &im)
{ {

View File

@ -28,6 +28,14 @@ class stImage;
class Waifu2x class Waifu2x
{ {
public: public:
// Processing mode passed to Init().
// The numeric values are written out explicitly; they match the implicit
// numbering the enumerators would otherwise receive.
enum eWaifu2xModelType
{
	eWaifu2xModelTypeNoise = 0,      // noise reduction only
	eWaifu2xModelTypeScale = 1,      // upscaling only
	eWaifu2xModelTypeNoiseScale = 2, // noise reduction and upscaling
	eWaifu2xModelTypeAutoScale = 3,  // upscaling; the file-based waifu2x() also denoises when the image requests it
};
enum eWaifu2xError enum eWaifu2xError
{ {
eWaifu2xError_OK = 0, eWaifu2xError_OK = 0,
@ -68,7 +76,7 @@ public:
private: private:
bool mIsInited; bool mIsInited;
std::string mMode; eWaifu2xModelType mMode;
int mNoiseLevel; int mNoiseLevel;
std::string mProcess; std::string mProcess;
@ -79,6 +87,7 @@ private:
int mInputPlane; // ネットへの入力チャンネル数 int mInputPlane; // ネットへの入力チャンネル数
int mMaxNetOffset; // ネットに入力するとどれくらい削れるか int mMaxNetOffset; // ネットに入力するとどれくらい削れるか
bool mHasNoiseScale;
float *mInputBlock; float *mInputBlock;
size_t mInputBlockSize; size_t mInputBlockSize;
@ -94,6 +103,8 @@ private:
const bool isReconstructNoise, const bool isReconstructScale, const Waifu2x::waifu2xCancelFunc cancel_func, stImage &image); const bool isReconstructNoise, const bool isReconstructScale, const Waifu2x::waifu2xCancelFunc cancel_func, stImage &image);
Waifu2x::eWaifu2xError ReconstructScale(const int crop_w, const int crop_h, const bool use_tta, const int batch_size, Waifu2x::eWaifu2xError ReconstructScale(const int crop_w, const int crop_h, const bool use_tta, const int batch_size,
const Waifu2x::waifu2xCancelFunc cancel_func, stImage &image); const Waifu2x::waifu2xCancelFunc cancel_func, stImage &image);
Waifu2x::eWaifu2xError ReconstructNoiseScale(const int crop_w, const int crop_h, const bool use_tta, const int batch_size,
const Waifu2x::waifu2xCancelFunc cancel_func, stImage &image);
Waifu2x::eWaifu2xError ReconstructByNet(std::shared_ptr<cNet> net, const int crop_w, const int crop_h, const bool use_tta, const int batch_size, Waifu2x::eWaifu2xError ReconstructByNet(std::shared_ptr<cNet> net, const int crop_w, const int crop_h, const bool use_tta, const int batch_size,
const Waifu2x::waifu2xCancelFunc cancel_func, cv::Mat &im); const Waifu2x::waifu2xCancelFunc cancel_func, cv::Mat &im);
Waifu2x::eWaifu2xError ProcessNet(std::shared_ptr<cNet> net, const int crop_w, const int crop_h, const bool use_tta, const int batch_size, cv::Mat &im); Waifu2x::eWaifu2xError ProcessNet(std::shared_ptr<cNet> net, const int crop_w, const int crop_h, const bool use_tta, const int batch_size, cv::Mat &im);
@ -112,7 +123,7 @@ public:
// mode: noise or scale or noise_scale or auto_scale // mode: noise or scale or noise_scale or auto_scale
// process: cpu or gpu or cudnn // process: cpu or gpu or cudnn
eWaifu2xError Init(const std::string &mode, const int noise_level, eWaifu2xError Init(const eWaifu2xModelType mode, const int noise_level,
const boost::filesystem::path &model_dir, const std::string &process); const boost::filesystem::path &model_dir, const std::string &process);
eWaifu2xError waifu2x(const boost::filesystem::path &input_file, const boost::filesystem::path &output_file, eWaifu2xError waifu2x(const boost::filesystem::path &input_file, const boost::filesystem::path &output_file,

View File

@ -123,22 +123,32 @@ tstring DialogEvent::AddName() const
addstr += TEXT(")"); addstr += TEXT(")");
addstr += TEXT("("); addstr += TEXT("(");
if (mode == "noise") switch (mode)
{
case Waifu2x::eWaifu2xModelTypeNoise:
addstr += TEXT("noise"); addstr += TEXT("noise");
else if (mode == "scale") break;
case Waifu2x::eWaifu2xModelTypeScale:
addstr += TEXT("scale"); addstr += TEXT("scale");
else if (mode == "noise_scale") break;
case Waifu2x::eWaifu2xModelTypeNoiseScale:
addstr += TEXT("noise_scale"); addstr += TEXT("noise_scale");
else if (mode == "auto_scale") break;
case Waifu2x::eWaifu2xModelTypeAutoScale:
addstr += TEXT("auto_scale"); addstr += TEXT("auto_scale");
break;
}
addstr += TEXT(")"); addstr += TEXT(")");
if (mode.find("noise") != mode.npos || mode.find("auto_scale") != mode.npos) if (mode == Waifu2x::eWaifu2xModelTypeNoise || mode == Waifu2x::eWaifu2xModelTypeNoiseScale || mode == Waifu2x::eWaifu2xModelTypeAutoScale)
addstr += TEXT("(Level") + to_tstring(noise_level) + TEXT(")"); addstr += TEXT("(Level") + to_tstring(noise_level) + TEXT(")");
if (use_tta) if (use_tta)
addstr += TEXT("(tta)"); addstr += TEXT("(tta)");
if (mode.find("scale") != mode.npos) if (mode == Waifu2x::eWaifu2xModelTypeScale || mode == Waifu2x::eWaifu2xModelTypeNoiseScale || mode == Waifu2x::eWaifu2xModelTypeAutoScale)
{ {
if (scaleType == eScaleTypeRatio) if (scaleType == eScaleTypeRatio)
addstr += TEXT("(x") + to_tstring(scale_ratio) + TEXT(")"); addstr += TEXT("(x") + to_tstring(scale_ratio) + TEXT(")");
@ -176,13 +186,25 @@ bool DialogEvent::SyncMember(const bool NotSyncCropSize, const bool silent)
} }
if (SendMessage(GetDlgItem(dh, IDC_RADIO_MODE_NOISE), BM_GETCHECK, 0, 0)) if (SendMessage(GetDlgItem(dh, IDC_RADIO_MODE_NOISE), BM_GETCHECK, 0, 0))
mode = "noise"; {
mode = Waifu2x::eWaifu2xModelTypeNoise;
modeStr = "noise";
}
else if (SendMessage(GetDlgItem(dh, IDC_RADIO_MODE_SCALE), BM_GETCHECK, 0, 0)) else if (SendMessage(GetDlgItem(dh, IDC_RADIO_MODE_SCALE), BM_GETCHECK, 0, 0))
mode = "scale"; {
mode = Waifu2x::eWaifu2xModelTypeScale;
modeStr = "scale";
}
else if (SendMessage(GetDlgItem(dh, IDC_RADIO_MODE_NOISE_SCALE), BM_GETCHECK, 0, 0)) else if (SendMessage(GetDlgItem(dh, IDC_RADIO_MODE_NOISE_SCALE), BM_GETCHECK, 0, 0))
mode = "noise_scale"; {
mode = Waifu2x::eWaifu2xModelTypeNoiseScale;
modeStr = "noise_scale";
}
else else
mode = "auto_scale"; {
mode = Waifu2x::eWaifu2xModelTypeAutoScale;
modeStr = "auto_scale";
}
if (SendMessage(GetDlgItem(dh, IDC_RADIONOISE_LEVEL1), BM_GETCHECK, 0, 0)) if (SendMessage(GetDlgItem(dh, IDC_RADIONOISE_LEVEL1), BM_GETCHECK, 0, 0))
noise_level = 1; noise_level = 1;
@ -895,14 +917,24 @@ void DialogEvent::SaveIni(const bool isSyncMember)
else else
tScaleHeight = TEXT(""); tScaleHeight = TEXT("");
if (mode == ("noise")) switch (mode)
{
case Waifu2x::eWaifu2xModelTypeNoise:
tmode = TEXT("noise"); tmode = TEXT("noise");
else if (mode == ("scale")) break;
case Waifu2x::eWaifu2xModelTypeScale:
tmode = TEXT("scale"); tmode = TEXT("scale");
else if (mode == ("auto_scale")) break;
tmode = TEXT("auto_scale");
else // noise_scale case Waifu2x::eWaifu2xModelTypeNoiseScale:
tmode = TEXT("noise_scale"); tmode = TEXT("noise_scale");
break;
case Waifu2x::eWaifu2xModelTypeAutoScale:
tmode = TEXT("auto_scale");
break;
}
if (process == "gpu") if (process == "gpu")
tprcess = TEXT("gpu"); tprcess = TEXT("gpu");
@ -1142,7 +1174,7 @@ UINT_PTR DialogEvent::OFNHookProcOut(HWND hdlg, UINT uiMsg, WPARAM wParam, LPARA
return 0L; return 0L;
} }
DialogEvent::DialogEvent() : dh(nullptr), mode("noise_scale"), noise_level(1), scale_ratio(2.0), scale_width(0), scale_height(0), model_dir(TEXT("models/anime_style_art_rgb")), DialogEvent::DialogEvent() : dh(nullptr), mode(Waifu2x::eWaifu2xModelTypeNoiseScale), modeStr("noise_scale"), noise_level(1), scale_ratio(2.0), scale_width(0), scale_height(0), model_dir(TEXT("models/anime_style_art_rgb")),
process("gpu"), outputExt(TEXT(".png")), inputFileExt(TEXT("png:jpg:jpeg:tif:tiff:bmp:tga")), process("gpu"), outputExt(TEXT(".png")), inputFileExt(TEXT("png:jpg:jpeg:tif:tiff:bmp:tga")),
use_tta(false), output_depth(8), crop_size(128), batch_size(1), isLastError(false), scaleType(eScaleTypeEnd), use_tta(false), output_depth(8), crop_size(128), batch_size(1), isLastError(false), scaleType(eScaleTypeEnd),
TimeLeftThread(-1), TimeLeftGetTimeThread(0), isCommandLineStart(false), tAutoMode(TEXT("none")), TimeLeftThread(-1), TimeLeftGetTimeThread(0), isCommandLineStart(false), tAutoMode(TEXT("none")),
@ -1423,8 +1455,8 @@ void DialogEvent::SetWindowTextLang()
SendMessage(hwndCombo, CB_ADDSTRING, 0, (LPARAM)langStringList.GetString(L"IDC_RADIO_MODEL_RGB").c_str()); SendMessage(hwndCombo, CB_ADDSTRING, 0, (LPARAM)langStringList.GetString(L"IDC_RADIO_MODEL_RGB").c_str());
SendMessage(hwndCombo, CB_ADDSTRING, 0, (LPARAM)langStringList.GetString(L"IDC_RADIO_MODEL_PHOTO").c_str()); SendMessage(hwndCombo, CB_ADDSTRING, 0, (LPARAM)langStringList.GetString(L"IDC_RADIO_MODEL_PHOTO").c_str());
SendMessage(hwndCombo, CB_ADDSTRING, 0, (LPARAM)langStringList.GetString(L"IDC_RADIO_MODEL_Y").c_str());
SendMessage(hwndCombo, CB_ADDSTRING, 0, (LPARAM)langStringList.GetString(L"IDC_RADIO_MODEL_UPCONV_RGB").c_str()); SendMessage(hwndCombo, CB_ADDSTRING, 0, (LPARAM)langStringList.GetString(L"IDC_RADIO_MODEL_UPCONV_RGB").c_str());
SendMessage(hwndCombo, CB_ADDSTRING, 0, (LPARAM)langStringList.GetString(L"IDC_RADIO_MODEL_Y").c_str());
SendMessage(GetDlgItem(dh, IDC_COMBO_MODEL), CB_SETCURSEL, cur, 0); SendMessage(GetDlgItem(dh, IDC_COMBO_MODEL), CB_SETCURSEL, cur, 0);
} }

View File

@ -9,6 +9,7 @@
#include <atomic> #include <atomic>
#include <boost/filesystem.hpp> #include <boost/filesystem.hpp>
#include <boost/optional.hpp> #include <boost/optional.hpp>
#include "../common/waifu2x.h"
#include "resource.h" #include "resource.h"
#include "tstring.h" #include "tstring.h"
#include "LangStringList.h" #include "LangStringList.h"
@ -33,7 +34,8 @@ private:
tstring input_str; tstring input_str;
std::vector<tstring> input_str_multi; std::vector<tstring> input_str_multi;
tstring output_str; tstring output_str;
std::string mode; std::string modeStr;
Waifu2x::eWaifu2xModelType mode;
int noise_level; int noise_level;
double scale_ratio; double scale_ratio;
int scale_width; int scale_width;