1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
| def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay): net = nn.DataParallel(net, device_ids=devices).to(devices[0]) trainer = torch.optim.SGD( (param for param in net.parameters() if param.requires_grad), lr=lr, momentum=0.9, weight_decay=wd) scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay) num_batches, timer = len(train_iter), d2l.Timer() legend = ['train loss'] if valid_iter is not None: legend.append('valid loss') animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], legend=legend) for epoch in range(num_epochs): metric = d2l.Accumulator(2) for i, (features, labels) in enumerate(train_iter): timer.start() features, labels = features.to(devices[0]), labels.to(devices[0]) trainer.zero_grad() output = net(features) l = loss(output, labels).sum() l.backward() trainer.step() metric.add(l, labels.shape[0]) timer.stop() if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1: animator.add(epoch + (i + 1) / num_batches, (metric[0] / metric[1], None)) measures = f'train loss {metric[0] / metric[1]:.3f}' if valid_iter is not None:
|