mirror of
https://github.com/lltcggie/waifu2x-caffe.git
synced 2025-06-26 13:42:48 +00:00
upresnet10_3.prototxt生成スクリプトが正しくないネットワークを吐き出していたのを修正、スクリプトのファイル名を変更
This commit is contained in:
parent
706ee0d32b
commit
65ad2de604
@ -81,35 +81,35 @@ def DeConv(name, bottom, num_output, kernel_size, stride = 1, pad = 0, nobias =
|
||||
|
||||
|
||||
def Sigmoid(name, bottom):
|
||||
top_name = name
|
||||
top_name = name + '_sigmoid'
|
||||
# ReLU
|
||||
sigmoid_layer = caffe_pb2.LayerParameter()
|
||||
sigmoid_layer.name = name + '_sigmoid'
|
||||
sigmoid_layer.type = 'Sigmoid'
|
||||
sigmoid_layer.bottom.extend([top_name])
|
||||
sigmoid_layer.bottom.extend([bottom])
|
||||
sigmoid_layer.top.extend([top_name])
|
||||
return sigmoid_layer
|
||||
|
||||
|
||||
def Relu(name, bottom):
|
||||
top_name = name
|
||||
top_name = name + '_relu'
|
||||
# ReLU
|
||||
relu_layer = caffe_pb2.LayerParameter()
|
||||
relu_layer.name = name + '_relu'
|
||||
relu_layer.type = 'ReLU'
|
||||
relu_layer.bottom.extend([top_name])
|
||||
relu_layer.bottom.extend([bottom])
|
||||
relu_layer.top.extend([top_name])
|
||||
return relu_layer
|
||||
|
||||
|
||||
def LeakyRelu(name, bottom, negative_slope):
|
||||
top_name = name
|
||||
top_name = name + '_relu'
|
||||
# LeakyRelu
|
||||
relu_layer = caffe_pb2.LayerParameter()
|
||||
relu_layer.name = name + '_relu'
|
||||
relu_layer.type = 'ReLU'
|
||||
relu_layer.relu_param.negative_slope = negative_slope
|
||||
relu_layer.bottom.extend([top_name])
|
||||
relu_layer.bottom.extend([bottom])
|
||||
relu_layer.top.extend([top_name])
|
||||
return relu_layer
|
||||
|
||||
@ -122,11 +122,12 @@ def ConvLeakyRelu(name, bottom, num_output, kernel_size, stride = 1, pad = 0, ne
|
||||
|
||||
|
||||
def GlobalAvgPool(name, bottom, stride = 1, pad = 0):
|
||||
top_name = name + '_globalavgpool'
|
||||
layer = caffe_pb2.LayerParameter()
|
||||
layer.name = name + '_globalavgpool'
|
||||
layer.type = 'Pooling'
|
||||
layer.bottom.extend([bottom])
|
||||
layer.top.extend([name])
|
||||
layer.top.extend([top_name])
|
||||
layer.pooling_param.pool = caffe_pb2.PoolingParameter.AVE
|
||||
layer.pooling_param.stride = stride
|
||||
layer.pooling_param.pad = pad
|
||||
@ -240,15 +241,6 @@ def main(args):
|
||||
with open(args.output, 'w') as f:
|
||||
f.write(pb.text_format.MessageToString(model))
|
||||
|
||||
# caffe.set_mode_cpu()
|
||||
# net = caffe.Net(args.output, caffe.TEST)
|
||||
# input_data = np.random.random_sample(net.blobs['input'].data.shape)
|
||||
# net.blobs['input'].data[...] = input_data
|
||||
# ret = net.forward()
|
||||
# print input_data
|
||||
# print ret
|
||||
# print input_data.shape
|
||||
# print ret['/conv_post'].shape
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = ArgumentParser()
|
Loading…
x
Reference in New Issue
Block a user