パラメータファイルの読み込みを高速化

This commit is contained in:
lltcggie 2015-06-18 05:13:41 +09:00
parent 097a3c8da3
commit 138e9f1c43

View File

@ -256,112 +256,132 @@ Waifu2x::eWaifu2xError Waifu2x::CreateZoomColorImage(const cv::Mat &float_image,
// 学習したパラメータをファイルから読み込む // 学習したパラメータをファイルから読み込む
Waifu2x::eWaifu2xError Waifu2x::LoadParameter(boost::shared_ptr<caffe::Net<float>> net, const std::string &param_path) Waifu2x::eWaifu2xError Waifu2x::LoadParameter(boost::shared_ptr<caffe::Net<float>> net, const std::string &param_path)
{ {
rapidjson::Document d; const std::string caffemodel_path = param_path + ".caffemodel";
std::vector<char> jsonBuf;
try FILE *fp = fopen(caffemodel_path.c_str(), "rb");
const bool isModelExist = fp != nullptr;
if (fp) fclose(fp);
caffe::NetParameter param;
if (isModelExist && caffe::ReadProtoFromBinaryFile(caffemodel_path, &param))
net->CopyTrainedLayersFrom(param);
else
{ {
FILE *fp = fopen(param_path.c_str(), "rb"); rapidjson::Document d;
if (fp == nullptr) std::vector<char> jsonBuf;
return eWaifu2xError_FailedOpenModelFile;
fseek(fp, 0, SEEK_END); try
const auto size = ftell(fp);
fseek(fp, 0, SEEK_SET);
jsonBuf.resize(size + 1);
fread(jsonBuf.data(), 1, size, fp);
fclose(fp);
jsonBuf[jsonBuf.size() - 1] = '\0';
d.Parse(jsonBuf.data());
}
catch (...)
{
return eWaifu2xError_FailedParseModelFile;
}
std::vector<boost::shared_ptr<caffe::Layer<float>>> list;
auto &v = net->layers();
for (auto &l : v)
{
auto lk = l->type();
auto &bv = l->blobs();
if (bv.size() > 0)
list.push_back(l);
}
try
{
int count = 0;
for (auto it = d.Begin(); it != d.End(); ++it)
{ {
const auto &weight = (*it)["weight"]; FILE *fp = fopen(param_path.c_str(), "rb");
const auto nInputPlane = (*it)["nInputPlane"].GetInt(); if (fp == nullptr)
const auto nOutputPlane = (*it)["nOutputPlane"].GetInt(); return eWaifu2xError_FailedOpenModelFile;
const auto kW = (*it)["kW"].GetInt();
const auto &bias = (*it)["bias"];
auto leyer = list[count]; fseek(fp, 0, SEEK_END);
const auto size = ftell(fp);
fseek(fp, 0, SEEK_SET);
auto &b0 = leyer->blobs()[0]; jsonBuf.resize(size + 1);
auto &b1 = leyer->blobs()[1]; fread(jsonBuf.data(), 1, size, fp);
float *b0Ptr = nullptr; fclose(fp);
float *b1Ptr = nullptr;
if (caffe::Caffe::mode() == caffe::Caffe::CPU) jsonBuf[jsonBuf.size() - 1] = '\0';
{
b0Ptr = b0->mutable_cpu_data();
b1Ptr = b1->mutable_cpu_data();
}
else
{
b0Ptr = b0->mutable_gpu_data();
b1Ptr = b1->mutable_gpu_data();
}
const auto WeightSize1 = weight.Size(); d.Parse(jsonBuf.data());
const auto WeightSize2 = weight[0].Size(); }
const auto KernelHeight = weight[0][0].Size(); catch (...)
const auto KernelWidth = weight[0][0][0].Size(); {
return eWaifu2xError_FailedParseModelFile;
}
if (!(b0->count() == WeightSize1 * WeightSize2 * KernelHeight * KernelWidth)) std::vector<boost::shared_ptr<caffe::Layer<float>>> list;
return eWaifu2xError_FailedConstructModel; auto &v = net->layers();
for (auto &l : v)
{
auto lk = l->type();
auto &bv = l->blobs();
if (bv.size() > 0)
list.push_back(l);
}
if (!(b1->count() == bias.Size())) try
return eWaifu2xError_FailedConstructModel; {
size_t weightCount = 0;
std::vector<float> weightList; std::vector<float> weightList;
for (auto it2 = weight.Begin(); it2 != weight.End(); ++it2) std::vector<float> biasList;
int count = 0;
for (auto it = d.Begin(); it != d.End(); ++it)
{ {
for (auto it3 = (*it2).Begin(); it3 != (*it2).End(); ++it3) const auto &weight = (*it)["weight"];
const auto nInputPlane = (*it)["nInputPlane"].GetInt();
const auto nOutputPlane = (*it)["nOutputPlane"].GetInt();
const auto kW = (*it)["kW"].GetInt();
const auto &bias = (*it)["bias"];
auto leyer = list[count];
auto &b0 = leyer->blobs()[0];
auto &b1 = leyer->blobs()[1];
float *b0Ptr = nullptr;
float *b1Ptr = nullptr;
if (caffe::Caffe::mode() == caffe::Caffe::CPU)
{ {
for (auto it4 = (*it3).Begin(); it4 != (*it3).End(); ++it4) b0Ptr = b0->mutable_cpu_data();
b1Ptr = b1->mutable_cpu_data();
}
else
{
b0Ptr = b0->mutable_gpu_data();
b1Ptr = b1->mutable_gpu_data();
}
const auto WeightSize1 = weight.Size();
const auto WeightSize2 = weight[0].Size();
const auto KernelHeight = weight[0][0].Size();
const auto KernelWidth = weight[0][0][0].Size();
if (!(b0->count() == WeightSize1 * WeightSize2 * KernelHeight * KernelWidth))
return eWaifu2xError_FailedConstructModel;
if (!(b1->count() == bias.Size()))
return eWaifu2xError_FailedConstructModel;
weightList.resize(0);
biasList.resize(0);
size_t weightCount = 0;
for (auto it2 = weight.Begin(); it2 != weight.End(); ++it2)
{
for (auto it3 = (*it2).Begin(); it3 != (*it2).End(); ++it3)
{ {
for (auto it5 = (*it4).Begin(); it5 != (*it4).End(); ++it5) for (auto it4 = (*it3).Begin(); it4 != (*it3).End(); ++it4)
weightList.push_back((float)it5->GetDouble()); {
for (auto it5 = (*it4).Begin(); it5 != (*it4).End(); ++it5)
weightList.push_back((float)it5->GetDouble());
}
} }
} }
caffe::caffe_copy(b0->count(), weightList.data(), b0Ptr);
for (auto it2 = bias.Begin(); it2 != bias.End(); ++it2)
biasList.push_back((float)it2->GetDouble());
caffe::caffe_copy(b1->count(), biasList.data(), b1Ptr);
count++;
} }
caffe::caffe_copy(b0->count(), weightList.data(), b0Ptr); net->ToProto(&param);
std::vector<float> biasList; caffe::WriteProtoToBinaryFile(param, caffemodel_path);
for (auto it2 = bias.Begin(); it2 != bias.End(); ++it2) }
biasList.push_back((float)it2->GetDouble()); catch (...)
{
caffe::caffe_copy(b1->count(), biasList.data(), b1Ptr); return eWaifu2xError_FailedConstructModel;
count++;
} }
}
catch (...)
{
return eWaifu2xError_FailedConstructModel;
} }
return eWaifu2xError_OK; return eWaifu2xError_OK;