モデルのおすすめのCropSizeを設定する機能追加、UpResNet10モデルにおすすめのCropSize設定

CropSizeによって出力が変わるモデル用の設定
This commit is contained in:
lltcggie 2018-10-25 04:15:53 +09:00
parent a17ad89116
commit 6effe0dfaa
6 changed files with 59 additions and 29 deletions

View File

@ -1,3 +1,3 @@
{"name":"UpResNet10","arch_name":"upresnet10","has_noise_scale":true,"channels":3,
"scale_factor":2,"offset":26
"scale_factor":2,"offset":26,"recommended_crop_size":38
}

View File

@ -174,7 +174,7 @@ cNet::cNet() : mModelScale(0), mInnerScale(0), mNetOffset(0), mInputPlane(0), mH
cNet::~cNet()
{}
Waifu2x::eWaifu2xError cNet::GetInfo(const boost::filesystem::path & info_path, stInfo &info)
Waifu2x::eWaifu2xError cNet::GetInfo(const boost::filesystem::path & info_path, Waifu2x::stInfo &info)
{
rapidjson::Document d;
std::vector<char> jsonBuf;
@ -191,11 +191,13 @@ Waifu2x::eWaifu2xError cNet::GetInfo(const boost::filesystem::path & info_path,
const auto arch_name = d["arch_name"].GetString();
const bool has_noise_scale = d.HasMember("has_noise_scale") && d["has_noise_scale"].GetBool() ? true : false;
const int channels = d["channels"].GetInt();
const int recommended_crop_size = d.HasMember("recommended_crop_size") ? d["recommended_crop_size"].GetInt() : -1;
info.name = name;
info.arch_name = arch_name;
info.has_noise_scale = has_noise_scale;
info.channels = channels;
info.recommended_crop_size = recommended_crop_size;
if (d.HasMember("offset"))
{
@ -261,7 +263,7 @@ Waifu2x::eWaifu2xError cNet::GetInfo(const boost::filesystem::path & info_path,
// モデルファイルからネットワークを構築
// processでcudnnが指定されなかった場合はcuDNNが呼び出されないように変更する
Waifu2x::eWaifu2xError cNet::ConstractNet(const Waifu2x::eWaifu2xModelType mode, const boost::filesystem::path &model_path, const boost::filesystem::path &param_path, const stInfo &info, const std::string &process)
Waifu2x::eWaifu2xError cNet::ConstractNet(const Waifu2x::eWaifu2xModelType mode, const boost::filesystem::path &model_path, const boost::filesystem::path &param_path, const Waifu2x::stInfo &info, const std::string &process)
{
Waifu2x::eWaifu2xError ret;
@ -321,11 +323,11 @@ Waifu2x::eWaifu2xError cNet::ConstractNet(const Waifu2x::eWaifu2xModelType mode,
return Waifu2x::eWaifu2xError_OK;
}
void cNet::LoadParamFromInfo(const Waifu2x::eWaifu2xModelType mode, const stInfo &info)
void cNet::LoadParamFromInfo(const Waifu2x::eWaifu2xModelType mode, const Waifu2x::stInfo &info)
{
mModelScale = 2; // TODO: 動的に設定するようにする
stInfo::stParam param;
Waifu2x::stInfo::stParam param;
switch (mode)
{
@ -824,7 +826,7 @@ std::string cNet::GetModelName(const boost::filesystem::path &info_path)
{
Waifu2x::eWaifu2xError ret;
stInfo info;
Waifu2x::stInfo info;
ret = GetInfo(info_path, info);
if (ret != Waifu2x::eWaifu2xError_OK)
return std::string();

View File

@ -4,24 +4,6 @@
#include "waifu2x.h"
struct stInfo
{
struct stParam
{
int scale_factor;
int offset;
};
std::string name;
std::string arch_name;
bool has_noise_scale;
int channels;
stParam noise;
stParam scale;
stParam noise_scale;
};
class cNet
{
private:
@ -36,7 +18,7 @@ private:
bool mHasNoiseScaleModel;
private:
void LoadParamFromInfo(const Waifu2x::eWaifu2xModelType mode, const stInfo &info);
void LoadParamFromInfo(const Waifu2x::eWaifu2xModelType mode, const Waifu2x::stInfo &info);
Waifu2x::eWaifu2xError LoadParameterFromJson(const boost::filesystem::path &model_path, const boost::filesystem::path &param_path
, const boost::filesystem::path &modelbin_path, const boost::filesystem::path &caffemodel_path, const std::string &process);
Waifu2x::eWaifu2xError SetParameter(caffe::NetParameter &param, const std::string &process) const;
@ -45,9 +27,9 @@ public:
cNet();
~cNet();
static Waifu2x::eWaifu2xError GetInfo(const boost::filesystem::path &info_path, stInfo &info);
static Waifu2x::eWaifu2xError GetInfo(const boost::filesystem::path &info_path, Waifu2x::stInfo &info);
Waifu2x::eWaifu2xError ConstractNet(const Waifu2x::eWaifu2xModelType mode, const boost::filesystem::path &model_path, const boost::filesystem::path &param_path, const stInfo &info, const std::string &process);
Waifu2x::eWaifu2xError ConstractNet(const Waifu2x::eWaifu2xModelType mode, const boost::filesystem::path &model_path, const boost::filesystem::path &param_path, const Waifu2x::stInfo &info, const std::string &process);
int GetInputPlane() const;
int GetInnerScale() const;

View File

@ -1149,3 +1149,14 @@ std::string Waifu2x::GetModelName(const boost::filesystem::path & model_dir)
return cNet::GetModelName(info_path);
}
bool Waifu2x::GetInfo(const boost::filesystem::path &model_dir, stInfo &info)
{
const boost::filesystem::path mode_dir_path(GetModeDirPath(model_dir));
if (!boost::filesystem::exists(mode_dir_path))
return false;
const boost::filesystem::path info_path = mode_dir_path / "info.json";
return cNet::GetInfo(info_path, info) == Waifu2x::eWaifu2xError_OK;
}

View File

@ -59,6 +59,25 @@ public:
class Waifu2x
{
public:
struct stInfo
{
struct stParam
{
int scale_factor;
int offset;
};
std::string name;
std::string arch_name;
bool has_noise_scale;
int channels;
int recommended_crop_size;
stParam noise;
stParam scale;
stParam noise_scale;
};
enum eWaifu2xModelType
{
eWaifu2xModelTypeNoise = 0,
@ -183,4 +202,5 @@ public:
const std::string& used_process() const;
static std::string GetModelName(const boost::filesystem::path &model_dir);
static bool GetInfo(const boost::filesystem::path &model_dir, stInfo &info);
};

View File

@ -527,6 +527,18 @@ void DialogEvent::SetCropSizeList(const boost::filesystem::path & input_path)
}
), list.end());
bool isRecommendedCropSize = false;
Waifu2x::stInfo info;
if (Waifu2x::GetInfo(model_dir, info) && info.recommended_crop_size > 0)
{
tstring str(to_tstring(info.recommended_crop_size));
SendMessage(hcrop, CB_ADDSTRING, 0, (LPARAM)str.c_str());
isRecommendedCropSize = true;
}
if (list.size() > 0)
SendMessage(hcrop, CB_ADDSTRING, 0, (LPARAM)TEXT("-----------------------"));
int mindiff = INT_MAX;
int defaultIndex = -1;
for (int i = 0; i < list.size(); i++)
@ -534,13 +546,13 @@ void DialogEvent::SetCropSizeList(const boost::filesystem::path & input_path)
const int n = list[i];
tstring str(to_tstring(n));
SendMessage(hcrop, CB_ADDSTRING, 0, (LPARAM)str.c_str());
const int index = SendMessage(hcrop, CB_ADDSTRING, 0, (LPARAM)str.c_str());
const int diff = abs(DefaultCommonDivisor - n);
if (DefaultCommonDivisorRange.first <= n && n <= DefaultCommonDivisorRange.second && diff < mindiff)
{
mindiff = diff;
defaultIndex = i;
defaultIndex = index;
}
}
@ -565,6 +577,9 @@ void DialogEvent::SetCropSizeList(const boost::filesystem::path & input_path)
if (defaultIndex == -1)
defaultIndex = defaultListIndex;
if(isRecommendedCropSize)
defaultIndex = 0;
if (GetWindowTextLength(hcrop) == 0)
SendMessage(hcrop, CB_SETCURSEL, defaultIndex, 0);
}