白云岛资源网 Design By www.pvray.com
模型为 VGG,数据集为 CIFAR-10。对照这份代码走一遍,大致就能了解 PyTorch 训练的整体运行机制。
来源
定义模型:
'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn
from torch.autograd import Variable

# Layer configurations. An int is the output-channel count of a 3x3 conv
# (padding 1) followed by BatchNorm and ReLU; 'M' is a 2x2 max-pool, stride 2.
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M',
              512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
              512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
              512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    """VGG-style classifier with batch norm (a model must subclass nn.Module).

    Args:
        vgg_name: key into ``cfg`` selecting the layer layout
            ('VGG11', 'VGG13', 'VGG16' or 'VGG19').
        num_classes: number of output logits (default 10).
        in_channels: number of channels of the input images (default 3).

    Every layout in ``cfg`` contains five 2x2 max-pools, so a 32x32 input is
    reduced to a 1x1 feature map before the linear classifier.
    """

    def __init__(self, vgg_name, num_classes=10, in_channels=3):
        super(VGG, self).__init__()
        layer_cfg = cfg[vgg_name]
        self.features = self._make_layers(layer_cfg, in_channels)
        # With a 1x1 final feature map the flattened vector has one entry per
        # channel of the last conv layer (512 for every layout above).
        last_channels = [v for v in layer_cfg if v != 'M'][-1]
        self.classifier = nn.Linear(last_channels, num_classes)

    def forward(self, x):
        """Forward pass: conv features, flatten, then linear classifier."""
        out = self.features(x)
        # Flatten everything except the first (batch) dimension.
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, layer_cfg, in_channels=3):
        """Build the convolutional trunk described by ``layer_cfg``."""
        layers = []
        for v in layer_cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, v, kernel_size=3, padding=1),
                           nn.BatchNorm2d(v),
                           nn.ReLU(inplace=True)]
                in_channels = v
        # A 1x1 average pool with stride 1 is an identity op; it is kept so the
        # module layout matches previously built/saved models.
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

# Quick shape check:
#   net = VGG('VGG11')
#   print(net(torch.randn(2, 3, 32, 32)).size())  # torch.Size([2, 10])
定义训练过程:
'''Train CIFAR10 with PyTorch.''' from __future__ import print_function import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F import torch.backends.cudnn as cudnn import torchvision import torchvision.transforms as transforms import os import argparse from models import * from utils import progress_bar from torch.autograd import Variable # 获取参数 parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') parser.add_argument('--lr', default=0.1, type=float, help='learning rate') parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') args = parser.parse_args() use_cuda = torch.cuda.is_available() best_acc = 0 # best test accuracy start_epoch = 0 # start from epoch 0 or last checkpoint epoch # 获取数据集,并先进行预处理 print('==> Preparing data..') # 图像预处理和增强 transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') # 继续训练模型或新建一个模型 if args.resume: # Load checkpoint. print('==> Resuming from checkpoint..') assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 
checkpoint = torch.load('./checkpoint/ckpt.t7') net = checkpoint['net'] best_acc = checkpoint['acc'] start_epoch = checkpoint['epoch'] else: print('==> Building model..') net = VGG('VGG16') # net = ResNet18() # net = PreActResNet18() # net = GoogLeNet() # net = DenseNet121() # net = ResNeXt29_2x64d() # net = MobileNet() # net = MobileNetV2() # net = DPN92() # net = ShuffleNetG2() # net = SENet18() # 如果GPU可用,使用GPU if use_cuda: # move param and buffer to GPU net.cuda() # parallel use GPU net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()-1)) # speed up slightly cudnn.benchmark = True # 定义度量和优化 criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) # 训练阶段 def train(epoch): print('\nEpoch: %d' % epoch) # switch to train mode net.train() train_loss = 0 correct = 0 total = 0 # batch 数据 for batch_idx, (inputs, targets) in enumerate(trainloader): # 将数据移到GPU上 if use_cuda: inputs, targets = inputs.cuda(), targets.cuda() # 先将optimizer梯度先置为0 optimizer.zero_grad() # Variable表示该变量属于计算图的一部分,此处是图计算的开始处。图的leaf variable inputs, targets = Variable(inputs), Variable(targets) # 模型输出 outputs = net(inputs) # 计算loss,图的终点处 loss = criterion(outputs, targets) # 反向传播,计算梯度 loss.backward() # 更新参数 optimizer.step() # 注意如果你想统计loss,切勿直接使用loss相加,而是使用loss.data[0]。因为loss是计算图的一部分,如果你直接加loss,代表total loss同样属于模型一部分,那么图就越来越大 train_loss += loss.data[0] # 数据统计 _, predicted = torch.max(outputs.data, 1) total += targets.size(0) correct += predicted.eq(targets.data).cpu().sum() progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss/(batch_idx+1), 100.*correct/total, correct, total)) # 测试阶段 def test(epoch): global best_acc # 先切到测试模型 net.eval() test_loss = 0 correct = 0 total = 0 for batch_idx, (inputs, targets) in enumerate(testloader): if use_cuda: inputs, targets = inputs.cuda(), targets.cuda() inputs, targets = Variable(inputs, volatile=True), Variable(targets) outputs = net(inputs) 
loss = criterion(outputs, targets) # loss is variable , if add it(+=loss) directly, there will be a bigger ang bigger graph. test_loss += loss.data[0] _, predicted = torch.max(outputs.data, 1) total += targets.size(0) correct += predicted.eq(targets.data).cpu().sum() progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (test_loss/(batch_idx+1), 100.*correct/total, correct, total)) # Save checkpoint. # 保存模型 acc = 100.*correct/total if acc > best_acc: print('Saving..') state = { 'net': net.module if use_cuda else net, 'acc': acc, 'epoch': epoch, } if not os.path.isdir('checkpoint'): os.mkdir('checkpoint') torch.save(state, './checkpoint/ckpt.t7') best_acc = acc # 运行模型 for epoch in range(start_epoch, start_epoch+200): train(epoch) test(epoch) # 清除部分无用变量 torch.cuda.empty_cache()
运行:
从头训练新模型:
python main.py --lr=0.01
从检查点恢复,继续训练旧模型:
python main.py --resume --lr=0.01
一些工具函数(utils.py):
'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of dataset.
    - init_params: net parameter initialization.
    - progress_bar: progress bar mimic xlua.progress.
    - format_time: compact human-readable duration string.
'''
import math
import os
import shutil
import sys
import time

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.utils.data


def get_mean_and_std(dataset):
    '''Compute per-channel mean and std of a 3-channel image dataset.

    Samples are loaded one at a time (batch_size=1). The returned mean is the
    average of per-image channel means; the returned std is the average of
    per-image channel standard deviations, which is typically smaller than the
    std of all pixels pooled together, so treat it as an approximation.

    Returns:
        (mean, std): two tensors of shape (3,).
    '''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std


def init_params(net):
    '''Initialize Conv2d, BatchNorm2d and Linear parameters of `net` in place.'''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # `if m.bias:` would call bool() on a multi-element tensor and
            # raise; test for the parameter's presence instead.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # weight/bias are None when the layer was built with affine=False.
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)


# Terminal width in columns. `stty size` crashed this import whenever no TTY
# was attached; shutil falls back to 80 columns in that case.
term_width = shutil.get_terminal_size().columns

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time


def progress_bar(current, total, msg=None):
    '''Draw an in-place progress bar for step `current` (0-based) of `total`.

    Shows per-step and total elapsed time plus the optional `msg`; the bar is
    rewritten on the same line until the last step, which ends with a newline.
    '''
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    sys.stdout.write('=' * cur_len)
    sys.stdout.write('>')
    sys.stdout.write('.' * rest_len)
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    parts = [' Step: %s' % format_time(step_time),
             ' | Tot: %s' % format_time(tot_time)]
    if msg:
        parts.append(' | ' + msg)
    msg = ''.join(parts)
    sys.stdout.write(msg)

    # Pad to the terminal edge so leftovers from a longer previous line vanish.
    sys.stdout.write(' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3))

    # Go back to the center of the bar.
    sys.stdout.write('\b' * (term_width - int(TOTAL_BAR_LENGTH/2) + 2))
    sys.stdout.write(' %d/%d ' % (current+1, total))

    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()


def format_time(seconds):
    '''Format a duration as its first two non-zero units, e.g. '1h1m'.

    Units are D (days), h, m, s and ms; returns '0ms' when all are zero.
    '''
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    units = ((days, 'D'), (hours, 'h'), (minutes, 'm'),
             (secondsf, 's'), (millis, 'ms'))
    shown = []
    for value, suffix in units:
        if value > 0 and len(shown) < 2:
            shown.append(str(value) + suffix)
    return ''.join(shown) or '0ms'
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。
白云岛资源网 Design By www.pvray.com
广告合作:本站广告合作请联系QQ:858582 申请时备注:广告合作(否则不回)
免责声明:本站资源来自互联网收集,仅供用于学习和交流,请遵循相关法律法规,本站一切资源不代表本站立场,如有侵权、后门、不妥请联系本站删除!
免责声明:本站资源来自互联网收集,仅供用于学习和交流,请遵循相关法律法规,本站一切资源不代表本站立场,如有侵权、后门、不妥请联系本站删除!
白云岛资源网 Design By www.pvray.com
暂无评论...