Подтвердить что ты не робот

PyTorch: как использовать DataLoaders для пользовательских наборов данных

Как использовать torch.utils.data.Dataset и torch.utils.data.DataLoader для своих собственных данных (а не только torchvision.datasets)?

Есть ли способ использовать встроенный DataLoaders, который они используют в TorchVisionDatasets для использования в любом наборе данных?

4b9b3361

Ответ 1

Да, это возможно. Просто создайте объекты самостоятельно, например.

import torch.utils.data as data_utils

train = data_utils.TensorDataset(features, targets)
train_loader = data_utils.DataLoader(train, batch_size=50, shuffle=True)

где features и targets - тензоры. features должен быть двухмерным, т.е. матрицей, где каждая строка представляет один образец обучения, а targets может быть 1-D или 2-D, в зависимости от того, пытаетесь ли вы предсказать скаляр или вектор.

Надеюсь, что это поможет!


РЕДАКТИРОВАТЬ: ответ на вопрос @sarthak

В принципе да. Если вы создаете объект типа TensorData, то конструктор исследует, имеют ли первые размеры тензора признаков (который фактически называется data_tensor) и целевой тензор (называемый target_tensor) имеют одинаковую длину:

assert data_tensor.size(0) == target_tensor.size(0)

Однако, если вы хотите впоследствии поместить эти данные в нейронную сеть, вам нужно быть осторожным. Хотя уровни свертки работают над такими данными, как ваши, (я думаю), все остальные типы слоев ожидают, что данные будут представлены в матричной форме. Итак, если вы столкнулись с такой проблемой, то простым решением было бы преобразовать ваш 4D-набор данных (заданный как какой-то тензор, например FloatTensor) в матрицу, используя метод view. Для вашего набора данных 5000xnxnx3 это будет выглядеть так:

2d_dataset = 4d_dataset.view(5000, -1)

(Значение -1 сообщает PyTorch автоматически определять длину второго измерения.)

Ответ 2

Вы можете легко сделать это, расширяя класс data.Dataset. Согласно API все, что вам нужно сделать, это реализовать две функции: __getitem__ и __len__.

Затем вы можете обернуть набор данных с помощью DataLoader, как показано в API и в ответе @pho7.

Я думаю, что класс ImageFolder является справочным. Смотрите код здесь.

Ответ 3

В дополнение к user3693922 ответу и принятому ответу, которые соответственно ссылаются на "быстрый" пример документации PyTorch для создания пользовательских загрузчиков данных для пользовательских наборов данных и создания пользовательских загрузчик данных в "простейшем" случае, есть гораздо более подробное специальное официальное руководство по PyTorch о том, как создать собственный загрузчик данных с соответствующей предварительной обработкой: "написание пользовательских наборов данных, загрузчиков данных и преобразований" официальный PyTorch руководство