From 138e9f1c43ed2b88c4f819d25af564937091120a Mon Sep 17 00:00:00 2001 From: lltcggie Date: Thu, 18 Jun 2015 05:13:41 +0900 Subject: [PATCH] =?UTF-8?q?=E3=83=91=E3=83=A9=E3=83=A1=E3=83=BC=E3=82=BF?= =?UTF-8?q?=E3=83=95=E3=82=A1=E3=82=A4=E3=83=AB=E3=81=AE=E8=AA=AD=E3=81=BF?= =?UTF-8?q?=E8=BE=BC=E3=81=BF=E3=82=92=E9=AB=98=E9=80=9F=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/waifu2x.cpp | 190 +++++++++++++++++++++++++-------------------- 1 file changed, 105 insertions(+), 85 deletions(-) diff --git a/common/waifu2x.cpp b/common/waifu2x.cpp index 7c86d73..58bbf38 100644 --- a/common/waifu2x.cpp +++ b/common/waifu2x.cpp @@ -256,112 +256,132 @@ Waifu2x::eWaifu2xError Waifu2x::CreateZoomColorImage(const cv::Mat &float_image, // 学習したパラメータをファイルから読み込む Waifu2x::eWaifu2xError Waifu2x::LoadParameter(boost::shared_ptr> net, const std::string ¶m_path) { - rapidjson::Document d; - std::vector jsonBuf; + const std::string caffemodel_path = param_path + ".caffemodel"; - try + FILE *fp = fopen(caffemodel_path.c_str(), "rb"); + const bool isModelExist = fp != nullptr; + if (fp) fclose(fp); + + caffe::NetParameter param; + if (isModelExist && caffe::ReadProtoFromBinaryFile(caffemodel_path, ¶m)) + net->CopyTrainedLayersFrom(param); + else { - FILE *fp = fopen(param_path.c_str(), "rb"); - if (fp == nullptr) - return eWaifu2xError_FailedOpenModelFile; + rapidjson::Document d; + std::vector jsonBuf; - fseek(fp, 0, SEEK_END); - const auto size = ftell(fp); - fseek(fp, 0, SEEK_SET); - - jsonBuf.resize(size + 1); - fread(jsonBuf.data(), 1, size, fp); - - fclose(fp); - - jsonBuf[jsonBuf.size() - 1] = '\0'; - - d.Parse(jsonBuf.data()); - } - catch (...) - { - return eWaifu2xError_FailedParseModelFile; - } - - std::vector>> list; - auto &v = net->layers(); - for (auto &l : v) - { - auto lk = l->type(); - auto &bv = l->blobs(); - if (bv.size() > 0) - list.push_back(l); - } - - try - { - int count = 0; - for (auto it = d.Begin(); it != d.End(); ++it) + try { - const auto &weight = (*it)["weight"]; - const auto nInputPlane = (*it)["nInputPlane"].GetInt(); - const auto nOutputPlane = (*it)["nOutputPlane"].GetInt(); - const auto kW = (*it)["kW"].GetInt(); - const auto &bias = (*it)["bias"]; + FILE *fp = fopen(param_path.c_str(), "rb"); + if (fp == nullptr) + return eWaifu2xError_FailedOpenModelFile; - auto leyer = list[count]; + fseek(fp, 0, SEEK_END); + const auto size = ftell(fp); + fseek(fp, 0, SEEK_SET); - auto &b0 = leyer->blobs()[0]; - auto &b1 = leyer->blobs()[1]; + jsonBuf.resize(size + 1); + fread(jsonBuf.data(), 1, size, fp); - float *b0Ptr = nullptr; - float *b1Ptr = nullptr; + fclose(fp); - if (caffe::Caffe::mode() == caffe::Caffe::CPU) - { - b0Ptr = b0->mutable_cpu_data(); - b1Ptr = b1->mutable_cpu_data(); - } - else - { - b0Ptr = b0->mutable_gpu_data(); - b1Ptr = b1->mutable_gpu_data(); - } + jsonBuf[jsonBuf.size() - 1] = '\0'; - const auto WeightSize1 = weight.Size(); - const auto WeightSize2 = weight[0].Size(); - const auto KernelHeight = weight[0][0].Size(); - const auto KernelWidth = weight[0][0][0].Size(); + d.Parse(jsonBuf.data()); + } + catch (...) + { + return eWaifu2xError_FailedParseModelFile; + } - if (!(b0->count() == WeightSize1 * WeightSize2 * KernelHeight * KernelWidth)) - return eWaifu2xError_FailedConstructModel; + std::vector>> list; + auto &v = net->layers(); + for (auto &l : v) + { + auto lk = l->type(); + auto &bv = l->blobs(); + if (bv.size() > 0) + list.push_back(l); + } - if (!(b1->count() == bias.Size())) - return eWaifu2xError_FailedConstructModel; - - size_t weightCount = 0; + try + { std::vector weightList; - for (auto it2 = weight.Begin(); it2 != weight.End(); ++it2) + std::vector biasList; + + int count = 0; + for (auto it = d.Begin(); it != d.End(); ++it) { - for (auto it3 = (*it2).Begin(); it3 != (*it2).End(); ++it3) + const auto &weight = (*it)["weight"]; + const auto nInputPlane = (*it)["nInputPlane"].GetInt(); + const auto nOutputPlane = (*it)["nOutputPlane"].GetInt(); + const auto kW = (*it)["kW"].GetInt(); + const auto &bias = (*it)["bias"]; + + auto leyer = list[count]; + + auto &b0 = leyer->blobs()[0]; + auto &b1 = leyer->blobs()[1]; + + float *b0Ptr = nullptr; + float *b1Ptr = nullptr; + + if (caffe::Caffe::mode() == caffe::Caffe::CPU) { - for (auto it4 = (*it3).Begin(); it4 != (*it3).End(); ++it4) + b0Ptr = b0->mutable_cpu_data(); + b1Ptr = b1->mutable_cpu_data(); + } + else + { + b0Ptr = b0->mutable_gpu_data(); + b1Ptr = b1->mutable_gpu_data(); + } + + const auto WeightSize1 = weight.Size(); + const auto WeightSize2 = weight[0].Size(); + const auto KernelHeight = weight[0][0].Size(); + const auto KernelWidth = weight[0][0][0].Size(); + + if (!(b0->count() == WeightSize1 * WeightSize2 * KernelHeight * KernelWidth)) + return eWaifu2xError_FailedConstructModel; + + if (!(b1->count() == bias.Size())) + return eWaifu2xError_FailedConstructModel; + + weightList.resize(0); + biasList.resize(0); + + size_t weightCount = 0; + for (auto it2 = weight.Begin(); it2 != weight.End(); ++it2) + { + for (auto it3 = (*it2).Begin(); it3 != (*it2).End(); ++it3) { - for (auto it5 = (*it4).Begin(); it5 != (*it4).End(); ++it5) - weightList.push_back((float)it5->GetDouble()); + for (auto it4 = (*it3).Begin(); it4 != (*it3).End(); ++it4) + { + for (auto it5 = (*it4).Begin(); it5 != (*it4).End(); ++it5) + weightList.push_back((float)it5->GetDouble()); + } } } + + caffe::caffe_copy(b0->count(), weightList.data(), b0Ptr); + + for (auto it2 = bias.Begin(); it2 != bias.End(); ++it2) + biasList.push_back((float)it2->GetDouble()); + + caffe::caffe_copy(b1->count(), biasList.data(), b1Ptr); + + count++; } - caffe::caffe_copy(b0->count(), weightList.data(), b0Ptr); + net->ToProto(¶m); - std::vector biasList; - for (auto it2 = bias.Begin(); it2 != bias.End(); ++it2) - biasList.push_back((float)it2->GetDouble()); - - caffe::caffe_copy(b1->count(), biasList.data(), b1Ptr); - - count++; + caffe::WriteProtoToBinaryFile(param, caffemodel_path); + } + catch (...) + { + return eWaifu2xError_FailedConstructModel; } - } - catch (...) - { - return eWaifu2xError_FailedConstructModel; } return eWaifu2xError_OK;