同时搞定TensorFlow、PyTorch (三) :资料前置处理

同时搞定TensorFlow、PyTorch (一):梯度下降。同时搞定TensorFlow、PyTorch (二):模型定义。同时搞定TensorFlow、PyTorch (三) :资料前置处理。同时搞定TensorFlow、PyTorch (四):模型训练。

前言

上一篇谈到神经层及神经网路模型的定义,这一次我们就来研究 TensorFlow/PyTorch 如何进行资料前置处理,并训练模型。

资料前置处理

基本上,TensorFlow、PyTorch 都提供NumPy格式相容的资料型态,不管是影像、文字或语音,依程序是读取资料档后转为 NumPy 阵列,经过前置处理再餵入模型,但实务上不会一次载入所有资料,因为记忆体会爆掉,因此,TensorFlow、PyTorch 都支援 Dataset/DataLoader,一次只读取一批(batch)资料进行训练,完成后在读取下一批资料训练,这样才能节省记忆体,接下来就来看看两个套件的作法。

MNIST转Dataset

TensorFlow:载入MNIST资料,格式为 NumPy 阵列,之后经过前置处理再转为Dataset。

import tensorflow as tfmnist = tf.keras.datasets.mnist# 载入 MNIST 手写阿拉伯数字资料(x_train, y_train),(x_test, y_test) = mnist.load_data()# 特徵缩放,使用常态化(Normalization),公式 = (x - min) / (max - min)x_train_norm, x_test_norm = x_train / 255.0, x_test / 255.0# 转为 Dataset,含 X/Y 资料train_ds = tf.data.Dataset.from_tensor_slices((x_train_norm, y_train))

PyTorch:载入MNIST资料,直接经转为Dataset,可透过transform转为PyTorch张量或进行前置处理。

import torchfrom torch.utils.data import DataLoaderfrom torchvision import transformsfrom torchvision.datasets import MNIST# 下载 MNIST 手写阿拉伯数字 训练资料train_ds = MNIST('', train=True, download=True,                  transform=transforms.ToTensor())

transform进行前置处理的实作如下,PyTorch直接利用transform进行标準化(Standardization)转换,公式可参阅Scikit-learn:

# 资料转换transform = transforms.Compose(    [transforms.ToTensor(),     # 读入图像範围介于[0, 1]之间,将之转换为 [-1, 1]     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))     # ImageNet     # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))    ])    # 载入资料集,如果出现 BrokenPipeError 错误,将 num_workers 改为 0train_ds = torchvision.datasets.CIFAR10(root='./CIFAR10', train=True,                                        download=True, transform=transform)    

读取档案转Dataset

通常会读取目录内所有档案,作为训练或测试资料。
TensorFlow:

train_ds = tf.keras.preprocessing.image_dataset_from_directory(    './data/training_set',    validation_split=0.2,    subset="training",    seed=1337,    image_size=image_size,    batch_size=batch_size,)

PyTorch:

import torchvisiontrain_ds = torchvision.datasets.ImageFolder('./data/training_set', transform=transform)

均以次目录名称作为标注(Label),例如以下结构:
http://img2.58codes.com/2024/20001976snNOTIKtYt.png

如果不是以上结构,PyTorch 也可以自订Dataset类别,只要实作__init__、len、__getitem__三个方法,程式码如下,完整範例可参考开发者传授 PyTorch 秘笈 的src/06_05_Data_Augmentation_MNIST.ipynb:

class CustomImageDataset(torch.utils.data.Dataset):    def __init__(self, img_dir, transform=None, target_transform=None                 , to_gray=False, size=28):        self.img_labels = [file_name for file_name in os.listdir(img_dir)]        self.img_dir = img_dir        self.transform = transform        self.target_transform = target_transform        self.to_gray = to_gray        self.size = size    def __len__(self):        return len(self.img_labels)    def __getitem__(self, idx):        # 组合档案完整路径        img_path = os.path.join(self.img_dir, self.img_labels[idx])        # 读取图档        mode = 'L' if self.to_gray else 'RGB'        image = Image.open(img_path, mode='r').convert(mode)        image = Image.fromarray(1.0-(np.array(image)/255))        # print(image.shape)        # 去除副档名        label = int(self.img_labels[idx].split('.')[0])                # 转换        if self.transform:            image = self.transform(image)        if self.target_transform:            label = self.target_transform(label)                return image, label

TensorFlow 作法可参阅Creating Custom TensorFlow Dataset。

Dataset 转 DataLoader

TensorFlow:直接将Dataset餵入模型训练即可,不需DataLoader。

# 模型训练model.fit(    train_ds, epochs=epochs, validation_data=val_ds)

PyTorch:需转为DataLoader,再餵入模型训练。

data_loader = torch.utils.data.DataLoader(train_ds, batch_size=10,shuffle=False)def train(model, device, train_loader, criterion, optimizer, epoch):    model.train()    loss_list = []        for batch_idx, (data, target) in enumerate(train_loader):        data, target = data.to(device), target.to(device)                optimizer.zero_grad()        output = model(data)        loss = criterion(output, target)        loss.backward()        optimizer.step()                if (batch_idx+1) % 10 == 0:            loss_list.append(loss.item())            batch = (batch_idx+1) * len(data)            data_count = len(train_loader.dataset)            percentage = (100. * (batch_idx+1) / len(train_loader))            print(f'Epoch {epoch}: [{batch:5d} / {data_count}] ' +                  f'({percentage:.0f} %)  Loss: {loss.item():.6f}')    return loss_list    for epoch in range(1, epochs + 1):    loss_list += train(model, device, train_loader, criterion, optimizer, epoch)
TensorFlow完整程式可参阅深度学习 -- 最佳入门迈向 AI 专题实战的src/05_12_Dataset.ipynb、06_05_Data_Augmentation_MNIST.ipynb、06_06_Data_Augmentation_CIFAR.ipynb。PyTorch可参阅开发者传授 PyTorch 秘笈 的src/05_01_Datasets.ipynb、06_05_Data_Augmentation_MNIST.ipynb、06_06_Data_Augmentation_CIFAR.ipynb。

结论

TensorFlow/PyTorch 基本设计概念是一致的,只是有些细节是存在差异的,例如 TensorFlow Dataset 可以使用cache、prefetch 缩短训练时间。

下一篇我们继续比较模型训练的细节。

以下为工商广告:)。
PyTorch:
开发者传授 PyTorch 秘笈
http://img2.58codes.com/2024/20001976MhL9K2rsgO.png
预计 2022/6/20 出版。

TensorFlow:
深度学习 -- 最佳入门迈向 AI 专题实战。
http://img2.58codes.com/2024/20001976ZOxC7BHyN3.jpg


关于作者: 网站小编

码农网专注IT技术教程资源分享平台,学习资源下载网站,58码农网包含计算机技术、网站程序源码下载、编程技术论坛、互联网资源下载等产品服务,提供原创、优质、完整内容的专业码农交流分享平台。

热门文章