diff --git a/common/cNet.cpp b/common/cNet.cpp index 8109d8a..3eb9691 100644 --- a/common/cNet.cpp +++ b/common/cNet.cpp @@ -184,6 +184,45 @@ Waifu2x::eWaifu2xError cNet::ConstractNet(const boost::filesystem::path &model_p return Waifu2x::eWaifu2xError_OK; } +namespace +{ + Waifu2x::eWaifu2xError ReadJson(const boost::filesystem::path &info_path, rapidjson::Document &d, std::vector &jsonBuf) + { + try + { + boost::iostreams::stream is; + + try + { + is.open(info_path, std::ios_base::in | std::ios_base::binary); + } + catch (...) + { + return Waifu2x::eWaifu2xError_FailedOpenModelFile; + } + + if (!is) + return Waifu2x::eWaifu2xError_FailedOpenModelFile; + + const size_t size = is.seekg(0, std::ios::end).tellg(); + is.seekg(0, std::ios::beg); + + jsonBuf.resize(size + 1); + is.read(jsonBuf.data(), jsonBuf.size()); + + jsonBuf[jsonBuf.size() - 1] = '\0'; + + d.Parse(jsonBuf.data()); + } + catch (...) + { + return Waifu2x::eWaifu2xError_FailedParseModelFile; + } + + return Waifu2x::eWaifu2xError_OK; + } +}; + Waifu2x::eWaifu2xError cNet::LoadInfoFromJson(const boost::filesystem::path &info_path) { rapidjson::Document d; @@ -191,29 +230,11 @@ Waifu2x::eWaifu2xError cNet::LoadInfoFromJson(const boost::filesystem::path &inf try { - boost::iostreams::stream is; + Waifu2x::eWaifu2xError ret; - try - { - is.open(info_path, std::ios_base::in | std::ios_base::binary); - } - catch (...) - { - return Waifu2x::eWaifu2xError_FailedOpenModelFile; - } - - if (!is) - return Waifu2x::eWaifu2xError_FailedOpenModelFile; - - const size_t size = is.seekg(0, std::ios::end).tellg(); - is.seekg(0, std::ios::beg); - - jsonBuf.resize(size + 1); - is.read(jsonBuf.data(), jsonBuf.size()); - - jsonBuf[jsonBuf.size() - 1] = '\0'; - - d.Parse(jsonBuf.data()); + ret = ReadJson(info_path, d, jsonBuf); + if (ret != Waifu2x::eWaifu2xError_OK) + return ret; const bool resize = d.HasMember("resize") && d["resize"].GetBool() ? true : false; const auto name = d["name"].GetString(); @@ -708,3 +729,29 @@ Waifu2x::eWaifu2xError cNet::ReconstructImage(const bool UseTTA, const int crop_ return Waifu2x::eWaifu2xError_OK; } + +std::string cNet::GetModelName(const boost::filesystem::path &info_path) +{ + rapidjson::Document d; + std::vector jsonBuf; + std::string str; + + try + { + Waifu2x::eWaifu2xError ret; + + ret = ReadJson(info_path, d, jsonBuf); + if (ret != Waifu2x::eWaifu2xError_OK) + return str; + + const auto name = d["name"].GetString(); + + str = name; + } + catch (...) + { + } + + return str; +} + diff --git a/common/cNet.h b/common/cNet.h index 4963acd..4141de3 100644 --- a/common/cNet.h +++ b/common/cNet.h @@ -1,5 +1,6 @@ #pragma once +#include #include "waifu2x.h" @@ -34,4 +35,6 @@ public: int GetOutputMemorySize(const int crop_w, const int crop_h, const int outer_padding, const int batch_size) const; Waifu2x::eWaifu2xError ReconstructImage(const bool UseTTA, const int crop_w, const int crop_h, const int outer_padding, const int batch_size, float *inputBlockBuf, float *outputBlockBuf, const cv::Mat &inMat, cv::Mat &outMat); + + static std::string GetModelName(const boost::filesystem::path &info_path); }; diff --git a/common/waifu2x.cpp b/common/waifu2x.cpp index f778e5e..53bdb3f 100644 --- a/common/waifu2x.cpp +++ b/common/waifu2x.cpp @@ -285,19 +285,7 @@ Waifu2x::eWaifu2xError Waifu2x::Init(const std::string &mode, const int noise_le const auto cuDNNCheckEndTime = std::chrono::system_clock::now(); - boost::filesystem::path mode_dir_path(model_dir); - if (!mode_dir_path.is_absolute()) // model_dirが相対パスなら絶対パスに直す - { - // まずはカレントディレクトリ下にあるか探す - mode_dir_path = boost::filesystem::absolute(model_dir); - if (!boost::filesystem::exists(mode_dir_path) && !ExeDir.empty()) // 無かったらargv[0]から実行ファイルのあるフォルダを推定し、そのフォルダ下にあるか探す - { - boost::filesystem::path a0(ExeDir); - if (a0.is_absolute()) - mode_dir_path = a0.branch_path() / model_dir; - } - } - + const boost::filesystem::path mode_dir_path(GetModeDirPath(model_dir)); if (!boost::filesystem::exists(mode_dir_path)) return Waifu2x::eWaifu2xError_FailedOpenModelFile; @@ -317,13 +305,14 @@ Waifu2x::eWaifu2xError Waifu2x::Init(const std::string &mode, const int noise_le // TODO: ノイズ除去と拡大を同時に行うネットワークへの対処を考える + const boost::filesystem::path info_path = GetInfoPath(mode_dir_path); + if (mode == "noise" || mode == "noise_scale" || mode == "auto_scale") { const std::string base_name = "noise" + std::to_string(noise_level) + "_model"; const boost::filesystem::path model_path = mode_dir_path / (base_name + ".prototxt"); const boost::filesystem::path param_path = mode_dir_path / (base_name + ".json"); - const boost::filesystem::path info_path = mode_dir_path / "info.json"; mNoiseNet.reset(new cNet); @@ -341,7 +330,6 @@ Waifu2x::eWaifu2xError Waifu2x::Init(const std::string &mode, const int noise_le const boost::filesystem::path model_path = mode_dir_path / (base_name + ".prototxt"); const boost::filesystem::path param_path = mode_dir_path / (base_name + ".json"); - const boost::filesystem::path info_path = mode_dir_path / "info.json"; mScaleNet.reset(new cNet); @@ -369,6 +357,31 @@ Waifu2x::eWaifu2xError Waifu2x::Init(const std::string &mode, const int noise_le return Waifu2x::eWaifu2xError_OK; } +boost::filesystem::path Waifu2x::GetModeDirPath(const boost::filesystem::path &model_dir) +{ + boost::filesystem::path mode_dir_path(model_dir); + if (!mode_dir_path.is_absolute()) // model_dirが相対パスなら絶対パスに直す + { + // まずはカレントディレクトリ下にあるか探す + mode_dir_path = boost::filesystem::absolute(model_dir); + if (!boost::filesystem::exists(mode_dir_path) && !ExeDir.empty()) // 無かったらargv[0]から実行ファイルのあるフォルダを推定し、そのフォルダ下にあるか探す + { + boost::filesystem::path a0(ExeDir); + if (a0.is_absolute()) + mode_dir_path = a0.branch_path() / model_dir; + } + } + + return mode_dir_path; +} + +boost::filesystem::path Waifu2x::GetInfoPath(const boost::filesystem::path &mode_dir_path) +{ + const boost::filesystem::path info_path = mode_dir_path / "info.json"; + + return info_path; +} + Waifu2x::eWaifu2xError Waifu2x::waifu2x(const boost::filesystem::path &input_file, const boost::filesystem::path &output_file, const double factor, const waifu2xCancelFunc cancel_func, const int crop_w, const int crop_h, const boost::optional output_quality, const int output_depth, const bool use_tta, @@ -652,3 +665,14 @@ const std::string& Waifu2x::used_process() const { return mProcess; } + +std::string Waifu2x::GetModelName(const boost::filesystem::path & model_dir) +{ + const boost::filesystem::path mode_dir_path(GetModeDirPath(model_dir)); + if (!boost::filesystem::exists(mode_dir_path)) + return std::string(); + + const boost::filesystem::path info_path = mode_dir_path / "info.json"; + + return cNet::GetModelName(info_path); +} diff --git a/common/waifu2x.h b/common/waifu2x.h index e2e22e7..d3057e7 100644 --- a/common/waifu2x.h +++ b/common/waifu2x.h @@ -87,7 +87,8 @@ private: size_t mOutputBlockSize; private: - eWaifu2xError ReconstructImage(boost::shared_ptr> net, const int reconstructed_scale, cv::Mat &im); + static boost::filesystem::path GetModeDirPath(const boost::filesystem::path &model_dir); + static boost::filesystem::path GetInfoPath(const boost::filesystem::path &model_dir); Waifu2x::eWaifu2xError ReconstructImage(const double factor, const int crop_w, const int crop_h, const bool use_tta, const int batch_size, const bool isReconstructNoise, const bool isReconstructScale, const Waifu2x::waifu2xCancelFunc cancel_func, stImage &image); @@ -131,4 +132,6 @@ public: void Destroy(); const std::string& used_process() const; + + static std::string GetModelName(const boost::filesystem::path &model_dir); }; diff --git a/waifu2x-caffe-gui/MainDialog.cpp b/waifu2x-caffe-gui/MainDialog.cpp index 007d23d..d8fc82f 100644 --- a/waifu2x-caffe-gui/MainDialog.cpp +++ b/waifu2x-caffe-gui/MainDialog.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -104,27 +105,21 @@ std::vector CommonDivisorList(const int N) tstring DialogEvent::AddName() const { - tstring addstr; + tstring addstr(TEXT("(")); - addstr += TEXT("("); - switch (modelType) + const std::string ModelName = Waifu2x::GetModelName(model_dir); + +#ifdef UNICODE { - case eModelTypeRGB: - addstr += TEXT("RGB"); - break; + std::wstring_convert, wchar_t> cv; + const std::wstring wModelName = cv.from_bytes(ModelName); - case eModelTypePhoto: - addstr += TEXT("Photo"); - break; - - case eModelTypeY: - addstr += TEXT("Y"); - break; - - case eModelTypeUpConvRGB: - addstr += TEXT("UpConvRGB"); - break; + addstr += wModelName; } +#else + addstr += ModelName; +#endif + addstr += TEXT(")"); addstr += TEXT("("); @@ -276,20 +271,20 @@ bool DialogEvent::SyncMember(const bool NotSyncCropSize, const bool silent) break; case 1: - model_dir = TEXT("models/anime_style_art"); - modelType = eModelTypeY; - break; - - case 2: model_dir = TEXT("models/photo"); modelType = eModelTypePhoto; break; - case 3: + case 2: model_dir = TEXT("models/upconv_7_anime_style_art_rgb"); modelType = eModelTypeUpConvRGB; break; + case 3: + model_dir = TEXT("models/anime_style_art"); + modelType = eModelTypeY; + break; + default: break; } @@ -1817,7 +1812,7 @@ void DialogEvent::Create(HWND hWnd, WPARAM wParam, LPARAM lParam, LPVOID lpData) index = 0; else if (modelType == eModelTypePhoto) index = 1; - else if (modelType == eModelTypeY) + else if (modelType == eModelTypeUpConvRGB) index = 2; else index = 3;