# 最初の100個のデータを使用
n_samples = 100

# 学習データを準備
X_train = datasets.MNIST(root='./data', train=True, download=True,
                         transform=transforms.Compose([transforms.ToTensor()]))

# 数字0と1だけを抽出
idx = np.append(np.where(X_train.targets == 0)[0][:n_samples], 
                np.where(X_train.targets == 1)[0][:n_samples])
# 入力データ
X_train.data = X_train.data[idx]
# 正解ラベル
X_train.targets = X_train.targets[idx]

# データのロード
train_loader = torch.utils.data.DataLoader(X_train, batch_size=1, shuffle=True)

# 例として画像を表示
n_samples_show = 6 # 表示するサンプル数
data_iter = iter(train_loader)
fig, axes = plt.subplots(nrows=1, ncols=n_samples_show, figsize=(10, 3))
# 6個のサンプルを表示します
while n_samples_show > 0:
    images, targets = data_iter.__next__()
    axes[n_samples_show - 1].imshow(images[0].numpy().squeeze(), cmap='gray')
    axes[n_samples_show - 1].set_xticks([])
    axes[n_samples_show - 1].set_yticks([])
    axes[n_samples_show - 1].set_title("Labeled: {}".format(targets.item()))
    n_samples_show -= 1