パラメータファイルの読み込みを高速化

This commit is contained in:
lltcggie 2015-06-18 05:13:41 +09:00
parent 097a3c8da3
commit 138e9f1c43

View File

@ -255,6 +255,17 @@ Waifu2x::eWaifu2xError Waifu2x::CreateZoomColorImage(const cv::Mat &float_image,
// 学習したパラメータをファイルから読み込む
Waifu2x::eWaifu2xError Waifu2x::LoadParameter(boost::shared_ptr<caffe::Net<float>> net, const std::string &param_path)
{
const std::string caffemodel_path = param_path + ".caffemodel";
FILE *fp = fopen(caffemodel_path.c_str(), "rb");
const bool isModelExist = fp != nullptr;
if (fp) fclose(fp);
caffe::NetParameter param;
if (isModelExist && caffe::ReadProtoFromBinaryFile(caffemodel_path, &param))
net->CopyTrainedLayersFrom(param);
else
{
rapidjson::Document d;
std::vector<char> jsonBuf;
@ -295,6 +306,9 @@ Waifu2x::eWaifu2xError Waifu2x::LoadParameter(boost::shared_ptr<caffe::Net<float
try
{
std::vector<float> weightList;
std::vector<float> biasList;
int count = 0;
for (auto it = d.Begin(); it != d.End(); ++it)
{
@ -334,8 +348,10 @@ Waifu2x::eWaifu2xError Waifu2x::LoadParameter(boost::shared_ptr<caffe::Net<float
if (!(b1->count() == bias.Size()))
return eWaifu2xError_FailedConstructModel;
weightList.resize(0);
biasList.resize(0);
size_t weightCount = 0;
std::vector<float> weightList;
for (auto it2 = weight.Begin(); it2 != weight.End(); ++it2)
{
for (auto it3 = (*it2).Begin(); it3 != (*it2).End(); ++it3)
@ -350,7 +366,6 @@ Waifu2x::eWaifu2xError Waifu2x::LoadParameter(boost::shared_ptr<caffe::Net<float
caffe::caffe_copy(b0->count(), weightList.data(), b0Ptr);
std::vector<float> biasList;
for (auto it2 = bias.Begin(); it2 != bias.End(); ++it2)
biasList.push_back((float)it2->GetDouble());
@ -358,11 +373,16 @@ Waifu2x::eWaifu2xError Waifu2x::LoadParameter(boost::shared_ptr<caffe::Net<float
count++;
}
net->ToProto(&param);
caffe::WriteProtoToBinaryFile(param, caffemodel_path);
}
catch (...)
{
return eWaifu2xError_FailedConstructModel;
}
}
return eWaifu2xError_OK;
}