时间:2025-06-23 09:11
人气:
作者:admin
基于之前的博客 pytorch入门 - AlexNet神经网络,并借助Kaggle 的 Dogs vs Cats Redux 数据集,实现一个基于 AlexNet 的二分类模型识别猫与狗。
完整流程涵盖数据准备、归一化、模型定义、训练增强、验证并可视化结果。
import os
import shutil
def split_data(ROOT_TRAIN):
cat_dir = os.path.join(ROOT_TRAIN, "cat")
dog_dir = os.path.join(ROOT_TRAIN, "dog")
os.makedirs(cat_dir, exist_ok=True)
os.makedirs(dog_dir, exist_ok=True)
for filename in os.listdir(ROOT_TRAIN):
if filename.startswith("cat") and filename.endswith(".jpg"):
shutil.move(os.path.join(ROOT_TRAIN, filename),
os.path.join(cat_dir, filename))
elif filename.startswith("dog") and filename.endswith(".jpg"):
shutil.move(os.path.join(ROOT_TRAIN, filename),
os.path.join(dog_dir, filename))
优化原因:
分类任务需明确标签与数据的对应关系。通过创建cat/dog子目录并移动图片,可直接利用PyTorch的ImageFolder自动生成标签,避免手动标注错误。
def compute_normalization_params(dataset_path):
transform = transforms.Compose([
transforms.Resize((227, 227)),
transforms.ToTensor()
])
dataset = ImageFolder(dataset_path, transform=transform)
loader = DataLoader(dataset, batch_size=32, num_workers=4, shuffle=False)
# 计算各通道均值和标准差
mean = 0.0
std = 0.0
for data, _ in loader:
batch_samples = data.size(0)
data = data.view(batch_samples, data.size(1), -1)
mean += data.mean(2).sum(0)
std += data.std(2).sum(0)
return mean / len(dataset), std / len(dataset)
关键点:
227×227,需提前调整class AlexNet(nn.Module):
def __init__(self):
super().__init__()
# 修改1:输入通道调整为3 (RGB)
self.conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4)
# ... (中间层省略)
# 修改2:输出层调整为2分类
self.fc3 = nn.Linear(4096, 2)
# 修改3:降低Dropout比例
self.dropout = nn.Dropout(0.2) # 原论文为0.5
优化逻辑:
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(10),
transforms.RandomResizedCrop(227, scale=(0.8, 1.0)),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor(),
transforms.Normalize(mean=[0.488, 0.455, 0.417],
std=[0.226, 0.221, 0.221])
])
增强目的:
# 1. 学习率调整
optimizer = optim.Adam(model.parameters(), lr=1e-4) # 原常用值0.001
# 2. 训练-验证集拆分
train_data, val_data = random_split(dataset, [0.8, 0.2])
# 3. 早停机制
if val_acc > best_acc:
best_model_wts = copy.deepcopy(model.state_dict())
关键技术点:
torch.cuda.amp自动混合精度,提升训练速度30%+(需GPU支持)| 优化点 | 原始值 | 调整值 | 作用 |
|---|---|---|---|
| 输入通道 | 1 (灰度) | 3 (RGB) | 适配彩色图像 |
| 输出维度 | 1000 | 2 | 二分类需求 |
| Dropout率 | 0.5 | 0.2 | 防欠拟合 |
| 学习率 | 0.001 | 0.0001 | 稳定微调 |
| 数据增强 | 无 | 5种变换 | 提升泛化性 |