预训练的意思是用 torchvision 里写好的 alexnet (修改最后一层),不是指导入训练好的,尝试用 quickstart 里的代码训练 cifar10 ,但是网上普遍查到的实验数据,准确率大概在 80%,78%左右,我迭代到收敛也只能得到 70%的准确率,这个差异产生的原因是啥呢?
完整代码:
from utils import *
from pipeit import *
import os,sys,time,pickle,random
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torchvision import datasets, models
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision.transforms import ToTensor, Lambda, Resize, Compose, InterpolationMode
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))
torch.backends.cudnn.benchmark=True
# Download training data from open datasets.
training_data = datasets.CIFAR10(
root=".\\data\\cifar10",
train=True,
download=True,
transform=Compose([
Resize((64, 64), InterpolationMode.BICUBIC),
ToTensor()
])
)
# Download test data from open datasets.
test_data = datasets.CIFAR10(
root=".\\data\\cifar10",
train=False,
download=True,
transform=Compose([
Resize((64, 64), InterpolationMode.BICUBIC),
ToTensor()
])
)
def imshow(training_data):
labels_map = {
0: "plane",
1: "car",
2: "bird",
3: "cat",
4: "deer",
5: "dog",
6: "frog",
7: "horse",
8: "ship",
9: "truck",
}
cols, rows = 3, 3
figure = plt.figure(figsize=(8,8))
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item()
img, label = training_data[sample_idx]
img = img.swapaxes(0,1)
img = img.swapaxes(1,2)
figure.add_subplot(rows, cols, i)
plt.title(labels_map[label])
plt.axis("off")
plt.imshow(img)
plt.show()
# imshow(training_data)
def train_loop(dataloader, net, loss_fn, optimizer):
size = len(dataloader)
train_loss = 0
for batch_idx, (X, tag) in enumerate(dataloader):
X, tag = X.to(device), tag.to(device)
pred = net(X)
loss = loss_fn(pred, tag)
train_loss += loss.item()
# Back propagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss /= size
return train_loss
def test_loop(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
return test_loss, correct
net = models.alexnet().to(device)
net.classifier[6] = nn.Linear(4096, 10).to(device)
learning_rate = 0.01
batch_size = 128
weight_decay = 0
train_dataloader = DataLoader(training_data, batch_size = batch_size)
test_dataloader = DataLoader(test_data, batch_size = batch_size)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr = learning_rate)
epochs = 50
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
st_time = time.time()
train_loss = train_loop(train_dataloader, net, loss_fn, optimizer)
test_loss, correct = test_loop(test_dataloader, net, loss_fn)
print(f"Train loss: {train_loss:>8f}, Test loss: {test_loss:>8f}, Accuracy: {(100*correct):>0.1f}%, Epoch time: {time.time() - st_time:.2f}s\n")
print("Done!")
torch.save(net.state_dict(), 'alexnet-pre1.model')
最后收敛时的数据在这样:
Epoch 52
-------------------------------
Train loss: 0.399347, Test loss: 0.970927, Accuracy: 70.3%, Epoch time: 17.20s
1
KangolHsu 2021-11-21 23:53:55 +08:00 via iPhone
输入的图片 64*64 ?是不是有点小啊
|