mirror of
https://github.com/lltcggie/waifu2x-caffe.git
synced 2025-06-26 05:32:47 +00:00
モデルの短縮名をinfo.jsonから取得するようにした
This commit is contained in:
parent
9fa4d905fd
commit
9052df543f
@ -184,11 +184,10 @@ Waifu2x::eWaifu2xError cNet::ConstractNet(const boost::filesystem::path &model_p
|
||||
return Waifu2x::eWaifu2xError_OK;
|
||||
}
|
||||
|
||||
Waifu2x::eWaifu2xError cNet::LoadInfoFromJson(const boost::filesystem::path &info_path)
|
||||
namespace
|
||||
{
|
||||
rapidjson::Document d;
|
||||
std::vector<char> jsonBuf;
|
||||
|
||||
Waifu2x::eWaifu2xError ReadJson(const boost::filesystem::path &info_path, rapidjson::Document &d, std::vector<char> &jsonBuf)
|
||||
{
|
||||
try
|
||||
{
|
||||
boost::iostreams::stream<boost::iostreams::file_descriptor_source> is;
|
||||
@ -214,6 +213,28 @@ Waifu2x::eWaifu2xError cNet::LoadInfoFromJson(const boost::filesystem::path &inf
|
||||
jsonBuf[jsonBuf.size() - 1] = '\0';
|
||||
|
||||
d.Parse(jsonBuf.data());
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
return Waifu2x::eWaifu2xError_FailedParseModelFile;
|
||||
}
|
||||
|
||||
return Waifu2x::eWaifu2xError_OK;
|
||||
}
|
||||
};
|
||||
|
||||
Waifu2x::eWaifu2xError cNet::LoadInfoFromJson(const boost::filesystem::path &info_path)
|
||||
{
|
||||
rapidjson::Document d;
|
||||
std::vector<char> jsonBuf;
|
||||
|
||||
try
|
||||
{
|
||||
Waifu2x::eWaifu2xError ret;
|
||||
|
||||
ret = ReadJson(info_path, d, jsonBuf);
|
||||
if (ret != Waifu2x::eWaifu2xError_OK)
|
||||
return ret;
|
||||
|
||||
const bool resize = d.HasMember("resize") && d["resize"].GetBool() ? true : false;
|
||||
const auto name = d["name"].GetString();
|
||||
@ -708,3 +729,29 @@ Waifu2x::eWaifu2xError cNet::ReconstructImage(const bool UseTTA, const int crop_
|
||||
|
||||
return Waifu2x::eWaifu2xError_OK;
|
||||
}
|
||||
|
||||
std::string cNet::GetModelName(const boost::filesystem::path &info_path)
|
||||
{
|
||||
rapidjson::Document d;
|
||||
std::vector<char> jsonBuf;
|
||||
std::string str;
|
||||
|
||||
try
|
||||
{
|
||||
Waifu2x::eWaifu2xError ret;
|
||||
|
||||
ret = ReadJson(info_path, d, jsonBuf);
|
||||
if (ret != Waifu2x::eWaifu2xError_OK)
|
||||
return str;
|
||||
|
||||
const auto name = d["name"].GetString();
|
||||
|
||||
str = name;
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
}
|
||||
|
||||
return str;
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "waifu2x.h"
|
||||
|
||||
|
||||
@ -34,4 +35,6 @@ public:
|
||||
int GetOutputMemorySize(const int crop_w, const int crop_h, const int outer_padding, const int batch_size) const;
|
||||
|
||||
Waifu2x::eWaifu2xError ReconstructImage(const bool UseTTA, const int crop_w, const int crop_h, const int outer_padding, const int batch_size, float *inputBlockBuf, float *outputBlockBuf, const cv::Mat &inMat, cv::Mat &outMat);
|
||||
|
||||
static std::string GetModelName(const boost::filesystem::path &info_path);
|
||||
};
|
||||
|
@ -285,19 +285,7 @@ Waifu2x::eWaifu2xError Waifu2x::Init(const std::string &mode, const int noise_le
|
||||
|
||||
const auto cuDNNCheckEndTime = std::chrono::system_clock::now();
|
||||
|
||||
boost::filesystem::path mode_dir_path(model_dir);
|
||||
if (!mode_dir_path.is_absolute()) // model_dirが相対パスなら絶対パスに直す
|
||||
{
|
||||
// まずはカレントディレクトリ下にあるか探す
|
||||
mode_dir_path = boost::filesystem::absolute(model_dir);
|
||||
if (!boost::filesystem::exists(mode_dir_path) && !ExeDir.empty()) // 無かったらargv[0]から実行ファイルのあるフォルダを推定し、そのフォルダ下にあるか探す
|
||||
{
|
||||
boost::filesystem::path a0(ExeDir);
|
||||
if (a0.is_absolute())
|
||||
mode_dir_path = a0.branch_path() / model_dir;
|
||||
}
|
||||
}
|
||||
|
||||
const boost::filesystem::path mode_dir_path(GetModeDirPath(model_dir));
|
||||
if (!boost::filesystem::exists(mode_dir_path))
|
||||
return Waifu2x::eWaifu2xError_FailedOpenModelFile;
|
||||
|
||||
@ -317,13 +305,14 @@ Waifu2x::eWaifu2xError Waifu2x::Init(const std::string &mode, const int noise_le
|
||||
|
||||
// TODO: ノイズ除去と拡大を同時に行うネットワークへの対処を考える
|
||||
|
||||
const boost::filesystem::path info_path = GetInfoPath(mode_dir_path);
|
||||
|
||||
if (mode == "noise" || mode == "noise_scale" || mode == "auto_scale")
|
||||
{
|
||||
const std::string base_name = "noise" + std::to_string(noise_level) + "_model";
|
||||
|
||||
const boost::filesystem::path model_path = mode_dir_path / (base_name + ".prototxt");
|
||||
const boost::filesystem::path param_path = mode_dir_path / (base_name + ".json");
|
||||
const boost::filesystem::path info_path = mode_dir_path / "info.json";
|
||||
|
||||
mNoiseNet.reset(new cNet);
|
||||
|
||||
@ -341,7 +330,6 @@ Waifu2x::eWaifu2xError Waifu2x::Init(const std::string &mode, const int noise_le
|
||||
|
||||
const boost::filesystem::path model_path = mode_dir_path / (base_name + ".prototxt");
|
||||
const boost::filesystem::path param_path = mode_dir_path / (base_name + ".json");
|
||||
const boost::filesystem::path info_path = mode_dir_path / "info.json";
|
||||
|
||||
mScaleNet.reset(new cNet);
|
||||
|
||||
@ -369,6 +357,31 @@ Waifu2x::eWaifu2xError Waifu2x::Init(const std::string &mode, const int noise_le
|
||||
return Waifu2x::eWaifu2xError_OK;
|
||||
}
|
||||
|
||||
boost::filesystem::path Waifu2x::GetModeDirPath(const boost::filesystem::path &model_dir)
|
||||
{
|
||||
boost::filesystem::path mode_dir_path(model_dir);
|
||||
if (!mode_dir_path.is_absolute()) // model_dirが相対パスなら絶対パスに直す
|
||||
{
|
||||
// まずはカレントディレクトリ下にあるか探す
|
||||
mode_dir_path = boost::filesystem::absolute(model_dir);
|
||||
if (!boost::filesystem::exists(mode_dir_path) && !ExeDir.empty()) // 無かったらargv[0]から実行ファイルのあるフォルダを推定し、そのフォルダ下にあるか探す
|
||||
{
|
||||
boost::filesystem::path a0(ExeDir);
|
||||
if (a0.is_absolute())
|
||||
mode_dir_path = a0.branch_path() / model_dir;
|
||||
}
|
||||
}
|
||||
|
||||
return mode_dir_path;
|
||||
}
|
||||
|
||||
boost::filesystem::path Waifu2x::GetInfoPath(const boost::filesystem::path &mode_dir_path)
|
||||
{
|
||||
const boost::filesystem::path info_path = mode_dir_path / "info.json";
|
||||
|
||||
return info_path;
|
||||
}
|
||||
|
||||
Waifu2x::eWaifu2xError Waifu2x::waifu2x(const boost::filesystem::path &input_file, const boost::filesystem::path &output_file,
|
||||
const double factor, const waifu2xCancelFunc cancel_func, const int crop_w, const int crop_h,
|
||||
const boost::optional<int> output_quality, const int output_depth, const bool use_tta,
|
||||
@ -652,3 +665,14 @@ const std::string& Waifu2x::used_process() const
|
||||
{
|
||||
return mProcess;
|
||||
}
|
||||
|
||||
std::string Waifu2x::GetModelName(const boost::filesystem::path & model_dir)
|
||||
{
|
||||
const boost::filesystem::path mode_dir_path(GetModeDirPath(model_dir));
|
||||
if (!boost::filesystem::exists(mode_dir_path))
|
||||
return std::string();
|
||||
|
||||
const boost::filesystem::path info_path = mode_dir_path / "info.json";
|
||||
|
||||
return cNet::GetModelName(info_path);
|
||||
}
|
||||
|
@ -87,7 +87,8 @@ private:
|
||||
size_t mOutputBlockSize;
|
||||
|
||||
private:
|
||||
eWaifu2xError ReconstructImage(boost::shared_ptr<caffe::Net<float>> net, const int reconstructed_scale, cv::Mat &im);
|
||||
static boost::filesystem::path GetModeDirPath(const boost::filesystem::path &model_dir);
|
||||
static boost::filesystem::path GetInfoPath(const boost::filesystem::path &model_dir);
|
||||
|
||||
Waifu2x::eWaifu2xError ReconstructImage(const double factor, const int crop_w, const int crop_h, const bool use_tta, const int batch_size,
|
||||
const bool isReconstructNoise, const bool isReconstructScale, const Waifu2x::waifu2xCancelFunc cancel_func, stImage &image);
|
||||
@ -131,4 +132,6 @@ public:
|
||||
void Destroy();
|
||||
|
||||
const std::string& used_process() const;
|
||||
|
||||
static std::string GetModelName(const boost::filesystem::path &model_dir);
|
||||
};
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
#include <codecvt>
|
||||
#include <cblas.h>
|
||||
#include <dlgs.h>
|
||||
#include <boost/tokenizer.hpp>
|
||||
@ -104,27 +105,21 @@ std::vector<int> CommonDivisorList(const int N)
|
||||
|
||||
tstring DialogEvent::AddName() const
|
||||
{
|
||||
tstring addstr;
|
||||
tstring addstr(TEXT("("));
|
||||
|
||||
addstr += TEXT("(");
|
||||
switch (modelType)
|
||||
const std::string ModelName = Waifu2x::GetModelName(model_dir);
|
||||
|
||||
#ifdef UNICODE
|
||||
{
|
||||
case eModelTypeRGB:
|
||||
addstr += TEXT("RGB");
|
||||
break;
|
||||
std::wstring_convert<std::codecvt_utf8<wchar_t>, wchar_t> cv;
|
||||
const std::wstring wModelName = cv.from_bytes(ModelName);
|
||||
|
||||
case eModelTypePhoto:
|
||||
addstr += TEXT("Photo");
|
||||
break;
|
||||
|
||||
case eModelTypeY:
|
||||
addstr += TEXT("Y");
|
||||
break;
|
||||
|
||||
case eModelTypeUpConvRGB:
|
||||
addstr += TEXT("UpConvRGB");
|
||||
break;
|
||||
addstr += wModelName;
|
||||
}
|
||||
#else
|
||||
addstr += ModelName;
|
||||
#endif
|
||||
|
||||
addstr += TEXT(")");
|
||||
|
||||
addstr += TEXT("(");
|
||||
@ -276,20 +271,20 @@ bool DialogEvent::SyncMember(const bool NotSyncCropSize, const bool silent)
|
||||
break;
|
||||
|
||||
case 1:
|
||||
model_dir = TEXT("models/anime_style_art");
|
||||
modelType = eModelTypeY;
|
||||
break;
|
||||
|
||||
case 2:
|
||||
model_dir = TEXT("models/photo");
|
||||
modelType = eModelTypePhoto;
|
||||
break;
|
||||
|
||||
case 3:
|
||||
case 2:
|
||||
model_dir = TEXT("models/upconv_7_anime_style_art_rgb");
|
||||
modelType = eModelTypeUpConvRGB;
|
||||
break;
|
||||
|
||||
case 3:
|
||||
model_dir = TEXT("models/anime_style_art");
|
||||
modelType = eModelTypeY;
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -1817,7 +1812,7 @@ void DialogEvent::Create(HWND hWnd, WPARAM wParam, LPARAM lParam, LPVOID lpData)
|
||||
index = 0;
|
||||
else if (modelType == eModelTypePhoto)
|
||||
index = 1;
|
||||
else if (modelType == eModelTypeY)
|
||||
else if (modelType == eModelTypeUpConvRGB)
|
||||
index = 2;
|
||||
else
|
||||
index = 3;
|
||||
|
Loading…
x
Reference in New Issue
Block a user