mirror of
https://github.com/lltcggie/waifu2x-caffe.git
synced 2025-06-26 13:42:48 +00:00
パラメータファイルの読み込みを高速化
This commit is contained in:
parent
097a3c8da3
commit
138e9f1c43
@ -256,112 +256,132 @@ Waifu2x::eWaifu2xError Waifu2x::CreateZoomColorImage(const cv::Mat &float_image,
|
|||||||
// 学習したパラメータをファイルから読み込む
|
// 学習したパラメータをファイルから読み込む
|
||||||
Waifu2x::eWaifu2xError Waifu2x::LoadParameter(boost::shared_ptr<caffe::Net<float>> net, const std::string ¶m_path)
|
Waifu2x::eWaifu2xError Waifu2x::LoadParameter(boost::shared_ptr<caffe::Net<float>> net, const std::string ¶m_path)
|
||||||
{
|
{
|
||||||
rapidjson::Document d;
|
const std::string caffemodel_path = param_path + ".caffemodel";
|
||||||
std::vector<char> jsonBuf;
|
|
||||||
|
|
||||||
try
|
FILE *fp = fopen(caffemodel_path.c_str(), "rb");
|
||||||
|
const bool isModelExist = fp != nullptr;
|
||||||
|
if (fp) fclose(fp);
|
||||||
|
|
||||||
|
caffe::NetParameter param;
|
||||||
|
if (isModelExist && caffe::ReadProtoFromBinaryFile(caffemodel_path, ¶m))
|
||||||
|
net->CopyTrainedLayersFrom(param);
|
||||||
|
else
|
||||||
{
|
{
|
||||||
FILE *fp = fopen(param_path.c_str(), "rb");
|
rapidjson::Document d;
|
||||||
if (fp == nullptr)
|
std::vector<char> jsonBuf;
|
||||||
return eWaifu2xError_FailedOpenModelFile;
|
|
||||||
|
|
||||||
fseek(fp, 0, SEEK_END);
|
try
|
||||||
const auto size = ftell(fp);
|
|
||||||
fseek(fp, 0, SEEK_SET);
|
|
||||||
|
|
||||||
jsonBuf.resize(size + 1);
|
|
||||||
fread(jsonBuf.data(), 1, size, fp);
|
|
||||||
|
|
||||||
fclose(fp);
|
|
||||||
|
|
||||||
jsonBuf[jsonBuf.size() - 1] = '\0';
|
|
||||||
|
|
||||||
d.Parse(jsonBuf.data());
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
return eWaifu2xError_FailedParseModelFile;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<boost::shared_ptr<caffe::Layer<float>>> list;
|
|
||||||
auto &v = net->layers();
|
|
||||||
for (auto &l : v)
|
|
||||||
{
|
|
||||||
auto lk = l->type();
|
|
||||||
auto &bv = l->blobs();
|
|
||||||
if (bv.size() > 0)
|
|
||||||
list.push_back(l);
|
|
||||||
}
|
|
||||||
|
|
||||||
try
|
|
||||||
{
|
|
||||||
int count = 0;
|
|
||||||
for (auto it = d.Begin(); it != d.End(); ++it)
|
|
||||||
{
|
{
|
||||||
const auto &weight = (*it)["weight"];
|
FILE *fp = fopen(param_path.c_str(), "rb");
|
||||||
const auto nInputPlane = (*it)["nInputPlane"].GetInt();
|
if (fp == nullptr)
|
||||||
const auto nOutputPlane = (*it)["nOutputPlane"].GetInt();
|
return eWaifu2xError_FailedOpenModelFile;
|
||||||
const auto kW = (*it)["kW"].GetInt();
|
|
||||||
const auto &bias = (*it)["bias"];
|
|
||||||
|
|
||||||
auto leyer = list[count];
|
fseek(fp, 0, SEEK_END);
|
||||||
|
const auto size = ftell(fp);
|
||||||
|
fseek(fp, 0, SEEK_SET);
|
||||||
|
|
||||||
auto &b0 = leyer->blobs()[0];
|
jsonBuf.resize(size + 1);
|
||||||
auto &b1 = leyer->blobs()[1];
|
fread(jsonBuf.data(), 1, size, fp);
|
||||||
|
|
||||||
float *b0Ptr = nullptr;
|
fclose(fp);
|
||||||
float *b1Ptr = nullptr;
|
|
||||||
|
|
||||||
if (caffe::Caffe::mode() == caffe::Caffe::CPU)
|
jsonBuf[jsonBuf.size() - 1] = '\0';
|
||||||
{
|
|
||||||
b0Ptr = b0->mutable_cpu_data();
|
|
||||||
b1Ptr = b1->mutable_cpu_data();
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
b0Ptr = b0->mutable_gpu_data();
|
|
||||||
b1Ptr = b1->mutable_gpu_data();
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto WeightSize1 = weight.Size();
|
d.Parse(jsonBuf.data());
|
||||||
const auto WeightSize2 = weight[0].Size();
|
}
|
||||||
const auto KernelHeight = weight[0][0].Size();
|
catch (...)
|
||||||
const auto KernelWidth = weight[0][0][0].Size();
|
{
|
||||||
|
return eWaifu2xError_FailedParseModelFile;
|
||||||
|
}
|
||||||
|
|
||||||
if (!(b0->count() == WeightSize1 * WeightSize2 * KernelHeight * KernelWidth))
|
std::vector<boost::shared_ptr<caffe::Layer<float>>> list;
|
||||||
return eWaifu2xError_FailedConstructModel;
|
auto &v = net->layers();
|
||||||
|
for (auto &l : v)
|
||||||
|
{
|
||||||
|
auto lk = l->type();
|
||||||
|
auto &bv = l->blobs();
|
||||||
|
if (bv.size() > 0)
|
||||||
|
list.push_back(l);
|
||||||
|
}
|
||||||
|
|
||||||
if (!(b1->count() == bias.Size()))
|
try
|
||||||
return eWaifu2xError_FailedConstructModel;
|
{
|
||||||
|
|
||||||
size_t weightCount = 0;
|
|
||||||
std::vector<float> weightList;
|
std::vector<float> weightList;
|
||||||
for (auto it2 = weight.Begin(); it2 != weight.End(); ++it2)
|
std::vector<float> biasList;
|
||||||
|
|
||||||
|
int count = 0;
|
||||||
|
for (auto it = d.Begin(); it != d.End(); ++it)
|
||||||
{
|
{
|
||||||
for (auto it3 = (*it2).Begin(); it3 != (*it2).End(); ++it3)
|
const auto &weight = (*it)["weight"];
|
||||||
|
const auto nInputPlane = (*it)["nInputPlane"].GetInt();
|
||||||
|
const auto nOutputPlane = (*it)["nOutputPlane"].GetInt();
|
||||||
|
const auto kW = (*it)["kW"].GetInt();
|
||||||
|
const auto &bias = (*it)["bias"];
|
||||||
|
|
||||||
|
auto leyer = list[count];
|
||||||
|
|
||||||
|
auto &b0 = leyer->blobs()[0];
|
||||||
|
auto &b1 = leyer->blobs()[1];
|
||||||
|
|
||||||
|
float *b0Ptr = nullptr;
|
||||||
|
float *b1Ptr = nullptr;
|
||||||
|
|
||||||
|
if (caffe::Caffe::mode() == caffe::Caffe::CPU)
|
||||||
{
|
{
|
||||||
for (auto it4 = (*it3).Begin(); it4 != (*it3).End(); ++it4)
|
b0Ptr = b0->mutable_cpu_data();
|
||||||
|
b1Ptr = b1->mutable_cpu_data();
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
b0Ptr = b0->mutable_gpu_data();
|
||||||
|
b1Ptr = b1->mutable_gpu_data();
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto WeightSize1 = weight.Size();
|
||||||
|
const auto WeightSize2 = weight[0].Size();
|
||||||
|
const auto KernelHeight = weight[0][0].Size();
|
||||||
|
const auto KernelWidth = weight[0][0][0].Size();
|
||||||
|
|
||||||
|
if (!(b0->count() == WeightSize1 * WeightSize2 * KernelHeight * KernelWidth))
|
||||||
|
return eWaifu2xError_FailedConstructModel;
|
||||||
|
|
||||||
|
if (!(b1->count() == bias.Size()))
|
||||||
|
return eWaifu2xError_FailedConstructModel;
|
||||||
|
|
||||||
|
weightList.resize(0);
|
||||||
|
biasList.resize(0);
|
||||||
|
|
||||||
|
size_t weightCount = 0;
|
||||||
|
for (auto it2 = weight.Begin(); it2 != weight.End(); ++it2)
|
||||||
|
{
|
||||||
|
for (auto it3 = (*it2).Begin(); it3 != (*it2).End(); ++it3)
|
||||||
{
|
{
|
||||||
for (auto it5 = (*it4).Begin(); it5 != (*it4).End(); ++it5)
|
for (auto it4 = (*it3).Begin(); it4 != (*it3).End(); ++it4)
|
||||||
weightList.push_back((float)it5->GetDouble());
|
{
|
||||||
|
for (auto it5 = (*it4).Begin(); it5 != (*it4).End(); ++it5)
|
||||||
|
weightList.push_back((float)it5->GetDouble());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
caffe::caffe_copy(b0->count(), weightList.data(), b0Ptr);
|
||||||
|
|
||||||
|
for (auto it2 = bias.Begin(); it2 != bias.End(); ++it2)
|
||||||
|
biasList.push_back((float)it2->GetDouble());
|
||||||
|
|
||||||
|
caffe::caffe_copy(b1->count(), biasList.data(), b1Ptr);
|
||||||
|
|
||||||
|
count++;
|
||||||
}
|
}
|
||||||
|
|
||||||
caffe::caffe_copy(b0->count(), weightList.data(), b0Ptr);
|
net->ToProto(¶m);
|
||||||
|
|
||||||
std::vector<float> biasList;
|
caffe::WriteProtoToBinaryFile(param, caffemodel_path);
|
||||||
for (auto it2 = bias.Begin(); it2 != bias.End(); ++it2)
|
}
|
||||||
biasList.push_back((float)it2->GetDouble());
|
catch (...)
|
||||||
|
{
|
||||||
caffe::caffe_copy(b1->count(), biasList.data(), b1Ptr);
|
return eWaifu2xError_FailedConstructModel;
|
||||||
|
|
||||||
count++;
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
return eWaifu2xError_FailedConstructModel;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return eWaifu2xError_OK;
|
return eWaifu2xError_OK;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user