cuDNNのバージョンチェックを入れて、Waifu2x::can_use_cuDNN()で失敗した原因を返すようにした

This commit is contained in:
lltcggie 2015-06-03 13:51:48 +09:00
parent 31b70fe277
commit d6979ff51d
3 changed files with 47 additions and 13 deletions

View File

@ -48,29 +48,42 @@ Waifu2x::~Waifu2x()
}
// cuDNNが使えるかチェック。現状Windowsのみ
bool Waifu2x::can_use_cuDNN()
Waifu2x::eWaifu2xcuDNNError Waifu2x::can_use_cuDNN()
{
static bool cuDNNFlag = false;
static eWaifu2xcuDNNError cuDNNFlag = eWaifu2xcuDNNError_NotFind;
std::call_once(waifu2x_cudnn_once_flag, [&]()
{
#if defined(WIN32) || defined(WIN64)
HMODULE hModule = LoadLibrary(TEXT("cudnn64_65.dll"));
if (hModule != NULL)
{
typedef cudnnStatus_t(*cudnnCreateType)(cudnnHandle_t *);
typedef cudnnStatus_t(*cudnnDestroyType)(cudnnHandle_t);
typedef cudnnStatus_t(*__stdcall cudnnCreateType)(cudnnHandle_t *);
typedef cudnnStatus_t(*__stdcall cudnnDestroyType)(cudnnHandle_t);
typedef uint64_t(*__stdcall cudnnGetVersionType)();
cudnnCreateType cudnnCreateFunc = (cudnnCreateType)GetProcAddress(hModule, "cudnnCreate");
cudnnDestroyType cudnnDestroyFunc = (cudnnDestroyType)GetProcAddress(hModule, "cudnnDestroy");
if (cudnnCreateFunc != nullptr && cudnnDestroyFunc != nullptr)
cudnnGetVersionType cudnnGetVersionFunc = (cudnnGetVersionType)GetProcAddress(hModule, "cudnnGetVersion");
if (cudnnCreateFunc != nullptr && cudnnDestroyFunc != nullptr && cudnnGetVersionFunc != nullptr)
{
cudnnHandle_t h;
if (cudnnCreateFunc(&h) == CUDNN_STATUS_SUCCESS)
if (cudnnGetVersionFunc() >= 2000)
{
if (cudnnDestroyFunc(h) == CUDNN_STATUS_SUCCESS)
cuDNNFlag = true;
cudnnHandle_t h;
if (cudnnCreateFunc(&h) == CUDNN_STATUS_SUCCESS)
{
if (cudnnDestroyFunc(h) == CUDNN_STATUS_SUCCESS)
cuDNNFlag = eWaifu2xcuDNNError_OK;
else
cuDNNFlag = eWaifu2xcuDNNError_CannotCreate;
}
else
cuDNNFlag = eWaifu2xcuDNNError_CannotCreate;
}
else
cuDNNFlag = eWaifu2xcuDNNError_OldVersion;
}
else
cuDNNFlag = eWaifu2xcuDNNError_NotFind;
FreeLibrary(hModule);
}
@ -505,7 +518,7 @@ Waifu2x::eWaifu2xError Waifu2x::init(int argc, char** argv, const std::string &M
if (process == "gpu")
{
// cuDNNが使えそうならcuDNNを使う
if (can_use_cuDNN())
if (can_use_cuDNN() == eWaifu2xcuDNNError_OK)
process = "cudnn";
}

View File

@ -32,6 +32,14 @@ public:
eWaifu2xError_FailedProcessCaffe,
};
enum eWaifu2xcuDNNError
{
eWaifu2xcuDNNError_OK = 0,
eWaifu2xcuDNNError_NotFind,
eWaifu2xcuDNNError_OldVersion,
eWaifu2xcuDNNError_CannotCreate,
};
typedef std::function<bool()> waifu2xCancelFunc;
private:
@ -74,7 +82,7 @@ public:
Waifu2x();
~Waifu2x();
static bool can_use_cuDNN();
static eWaifu2xcuDNNError can_use_cuDNN();
// mode: noise or scale or noise_scale or auto_scale
// process: cpu or gpu or cudnn

View File

@ -591,10 +591,23 @@ public:
void CheckCUDNN(HWND hWnd, WPARAM wParam, LPARAM lParam, LPVOID lpData)
{
if (Waifu2x::can_use_cuDNN())
switch (Waifu2x::can_use_cuDNN())
{
case Waifu2x::eWaifu2xcuDNNError_OK:
MessageBox(dh, TEXT("cuDNNが使えます"), TEXT("結果"), MB_OK | MB_ICONINFORMATION);
else
break;
case Waifu2x::eWaifu2xcuDNNError_NotFind:
MessageBox(dh, TEXT("cuDNNは使えません\r\n「cudnn64_65.dll」が見つかりません"), TEXT("結果"), MB_OK | MB_ICONERROR);
break;
case Waifu2x::eWaifu2xcuDNNError_OldVersion:
MessageBox(dh, TEXT("cuDNNは使えません\r\n「cudnn64_65.dll」のバージョンが古いです。v2を使って下さい。"), TEXT("結果"), MB_OK | MB_ICONERROR);
break;
case Waifu2x::eWaifu2xcuDNNError_CannotCreate:
MessageBox(dh, TEXT("cuDNNは使えません\r\ncuDNNを初期化出来ません"), TEXT("結果"), MB_OK | MB_ICONERROR);
break;
default:
MessageBox(dh, TEXT("cuDNNは使えません"), TEXT("結果"), MB_OK | MB_ICONERROR);
}
}
// ここで渡されるhWndはIDC_EDITのHWND(コントロールのイベントだから)