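A Brief Look at the torchvision Source Code [1]

PyTorch has a companion Python package called torchvision, which contains many ready-to-use computer vision models and datasets.

GitHub: https://github.com/pytorch/vision

This series of articles briefly walks through its internal code and what each part does.

This first article is an overview.

As the project's front page shows, the package is organized into four main modules.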

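torchvision.datasets: Data loaders for popular vision datasets.

 Built-in datasets that can be loaded directly.

torchvision.models: Definitions for popular model architectures, such as AlexNet, VGG, and ResNet, and pre-trained models.

 Mainstream model architectures, with pre-trained weights available for download.

torchvision.transforms: Common image transformations such as random crop, rotations, etc.

 Transformations used for image data augmentation.

torchvision.utils: Useful stuff such as saving a tensor (3 x H x W) as an image to disk, creating a grid of images from a mini-batch, etc.

 Image utilities, e.g. saving tensors as image files and arranging a batch into a grid.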
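To make the utils description concrete, here is a minimal NumPy sketch of the grid idea: a mini-batch of shape (N, C, H, W) is tiled into a single C x H' x W' image, `nrow` images per row, with `padding` pixels between and around them. The function name `make_grid_sketch` is hypothetical and this is a simplified illustration, not torchvision's own `make_grid`; the defaults (nrow=8, padding=2) are chosen to mirror that function's defaults.

```python
import math

import numpy as np


def make_grid_sketch(batch, nrow=8, padding=2, pad_value=0.0):
    """Tile a batch of shape (N, C, H, W) into one (C, H', W') array.

    Hypothetical, simplified illustration; not torchvision's implementation.
    """
    n, c, h, w = batch.shape
    ncols = min(nrow, n)              # images per row
    nrows = math.ceil(n / ncols)      # rows needed to fit all N images
    cell_h, cell_w = h + padding, w + padding
    # Start with a canvas filled with the padding value
    grid = np.full((c, nrows * cell_h + padding, ncols * cell_w + padding),
                   pad_value, dtype=batch.dtype)
    for k in range(n):
        row, col = divmod(k, ncols)   # position of image k in the grid
        top = row * cell_h + padding
        left = col * cell_w + padding
        grid[:, top:top + h, left:left + w] = batch[k]
    return grid


batch = np.ones((10, 3, 32, 32), dtype=np.float32)
print(make_grid_sketch(batch).shape)
```

For a batch of 10 images of size 3 x 32 x 32, this yields a 3 x 70 x 274 array: 8 images in the first row, 2 in the second, and the padding value everywhere else.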

Datasets

MNIST
COCO (Captioning and Detection)
LSUN Classification
ImageFolder
Imagenet-12
CIFAR10 and CIFAR100
STL10
SVHN
PhotoTour
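The dataset classes above (e.g. MNIST, CIFAR10, ImageFolder) subclass `torch.utils.data.Dataset`: they implement `__len__` and `__getitem__`, return a `(sample, target)` pair, and accept optional `transform` and `target_transform` callables. The sketch below shows that pattern with a hypothetical in-memory `ToyDataset`; to stay dependency-free it does not inherit from torch's `Dataset`, which a real implementation would.

```python
class ToyDataset:
    """Hypothetical in-memory dataset following the torchvision pattern."""

    def __init__(self, samples, targets, transform=None, target_transform=None):
        if len(samples) != len(targets):
            raise ValueError("samples and targets must have the same length")
        self.samples = samples
        self.targets = targets
        self.transform = transform                # applied to each sample
        self.target_transform = target_transform  # applied to each label

    def __len__(self):
        # Number of (sample, target) pairs; DataLoader relies on this
        return len(self.samples)

    def __getitem__(self, index):
        sample, target = self.samples[index], self.targets[index]
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target


ds = ToyDataset([[0, 1], [2, 3]], [5, 7], transform=lambda x: [v * 10 for v in x])
print(len(ds), ds[1])
```

Because every dataset exposes the same interface, a data loader can index into MNIST, CIFAR10, or a custom folder of images in exactly the same way.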

Models

AlexNet: AlexNet variant from the "One weird trick" paper.
VGG: VGG-11, VGG-13, VGG-16, VGG-19 (with and without batch normalization)
ResNet: ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-152
SqueezeNet: SqueezeNet 1.0, and SqueezeNet 1.1
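The number in each ResNet name counts its weighted layers. As a worked check, the per-stage block counts below are the ones from the ResNet paper, which torchvision's resnet constructors also use: shallower variants use a basic block with 2 convolutions, deeper ones a bottleneck block with 3. Adding the stem convolution and the final fully connected layer gives the name (the 1x1 projection shortcuts are not counted). The dictionary and function names are illustrative only.

```python
# Per-stage residual block counts (illustrative names, values from the ResNet paper)
RESNET_CONFIGS = {
    "resnet18": ("basic", [2, 2, 2, 2]),
    "resnet34": ("basic", [3, 4, 6, 3]),
    "resnet50": ("bottleneck", [3, 4, 6, 3]),
    "resnet101": ("bottleneck", [3, 4, 23, 3]),
    "resnet152": ("bottleneck", [3, 8, 36, 3]),
}

# Convolutions on the main path of each block type
CONVS_PER_BLOCK = {"basic": 2, "bottleneck": 3}


def resnet_depth(name):
    block, layers = RESNET_CONFIGS[name]
    # convs inside residual blocks + stem conv1 + final fc layer
    return CONVS_PER_BLOCK[block] * sum(layers) + 2


for name in RESNET_CONFIGS:
    print(name, resnet_depth(name))
```

For example, ResNet-101 has 3 + 4 + 23 + 3 = 33 bottleneck blocks, so 33 x 3 + 2 = 101.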

Transforms

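from torchvision import transforms

# Training-time augmentation pipeline for ImageNet-style models
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),   # renamed from the deprecated RandomSizedCrop
    transforms.RandomHorizontalFlip(),   # flip horizontally with probability 0.5
    transforms.ToTensor(),               # PIL image -> float tensor in [0, 1], C x H x W
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet per-channel means
                         std=[0.229, 0.224, 0.225]),  # ImageNet per-channel stds
])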
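Conceptually, `Compose` just calls each transform in order on the output of the previous one, and `Normalize` computes output[c] = (input[c] - mean[c]) / std[c] for each channel c, after `ToTensor` has scaled 8-bit pixel values into [0, 1]. The pure-Python sketch below illustrates both ideas on nested lists shaped [C][H][W]; `ComposeSketch` and `normalize_sketch` are hypothetical names, not torchvision's classes.

```python
class ComposeSketch:
    """Apply a list of callables in order, in the spirit of transforms.Compose."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)  # the output of one step feeds the next
        return img


def normalize_sketch(img, mean, std):
    """Per-channel (x - mean[c]) / std[c] on a nested list shaped [C][H][W]."""
    return [
        [[(x - m) / s for x in row] for row in channel]
        for channel, m, s in zip(img, mean, std)
    ]


def scale_to_unit(img):
    """Map 8-bit values in [0, 255] to [0.0, 1.0], like the scaling done by ToTensor."""
    return [[[x / 255 for x in row] for row in channel] for channel in img]


pipeline = ComposeSketch([
    scale_to_unit,
    lambda im: normalize_sketch(im, [0.5] * 3, [0.5] * 3),
])
print(pipeline([[[255, 0]], [[255, 0]], [[255, 0]]]))
```

With mean 0.5 and std 0.5, white pixels map to 1.0 and black pixels to -1.0. Since the pipeline is plain function composition, a custom callable can be dropped into the list alongside the built-in transforms.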