Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
A Closer Look at Few-shot Classification. International Conference on
Learning Representations (https://openreview.net/forum?id=HkxLXnAcFQ)
"""
def __init__(self, root, num_classes_per_task=None, meta_train=False,
             meta_val=False, meta_test=False, meta_split=None,
             transform=None, target_transform=None, dataset_transform=None,
             class_augmentations=None, download=False):
    """Build the CUB meta-dataset.

    A ``CUBClassDataset`` is created from ``root`` using the split flags,
    ``transform``, ``class_augmentations`` and ``download``. That class
    dataset is then handed to the parent constructor together with
    ``num_classes_per_task``, ``target_transform`` and
    ``dataset_transform``.
    """
    # Split selection is forwarded unchanged to the per-class dataset.
    split_flags = dict(meta_train=meta_train, meta_val=meta_val,
                       meta_test=meta_test, meta_split=meta_split)
    # Only the image-level transform goes to the class dataset; the label
    # and task-level transforms are given to the parent constructor.
    class_dataset = CUBClassDataset(root, transform=transform,
                                    class_augmentations=class_augmentations,
                                    download=download, **split_flags)
    super(CUB, self).__init__(class_dataset, num_classes_per_task,
                              dataset_transform=dataset_transform,
                              target_transform=target_transform)
# Per-class dataset for CUB-200-2011, built on ``ClassDataset``.
# NOTE(review): this snippet appears truncated -- ``__init__`` visibly ends
# after setting ``self.root``; the remainder of the class is not shown here.
class CUBClassDataset(ClassDataset):
# Sub-directory joined onto ``root`` in ``__init__`` to form ``self.root``.
folder = 'cub'
# Source archive URL and its MD5 checksum (presumably used to verify the
# download -- the check itself is not visible here).
# NOTE(review): plain-HTTP URL; integrity presumably relies on ``tgz_md5``.
download_url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
tgz_md5 = '97eceeb196236b17998738112f37df78'
# Image directory, presumably relative to the extracted archive -- TODO confirm.
image_folder = 'CUB_200_2011/images'
# Per-split file name templates; ``{0}`` is presumably the split name
# (e.g. train/val/test) -- TODO confirm against the format() call sites.
filename = '{0}_data.hdf5'
filename_labels = '{0}_labels.json'
# Constructor (only its first statements are visible in this snippet).
def __init__(self, root, meta_train=False, meta_val=False, meta_test=False,
meta_split=None, transform=None, class_augmentations=None,
download=False):
# Split selection and class augmentations are handled by ``ClassDataset``.
super(CUBClassDataset, self).__init__(meta_train=meta_train,
meta_val=meta_val, meta_test=meta_test, meta_split=meta_split,
class_augmentations=class_augmentations)
# Expand ``~`` in the user-supplied root, then descend into ``folder``.
self.root = os.path.join(os.path.expanduser(root), self.folder)
.. [2] Ravi, S. and Larochelle, H. (2016). Optimization as a Model for
Few-Shot Learning. (https://openreview.net/forum?id=rJY0-Kcll)
"""
def __init__(self, root, num_classes_per_task=None, meta_train=False,
             meta_val=False, meta_test=False, meta_split=None,
             transform=None, target_transform=None, dataset_transform=None,
             class_augmentations=None, download=False):
    """Build the miniImageNet meta-dataset.

    Constructs a ``MiniImagenetClassDataset`` rooted at ``root`` (with the
    split flags, ``transform``, ``class_augmentations`` and ``download``)
    and passes it, along with ``num_classes_per_task``, ``target_transform``
    and ``dataset_transform``, to the parent constructor.
    """
    class_dataset_kwargs = {
        'meta_train': meta_train,
        'meta_val': meta_val,
        'meta_test': meta_test,
        'meta_split': meta_split,
        'transform': transform,
        'class_augmentations': class_augmentations,
        'download': download,
    }
    class_dataset = MiniImagenetClassDataset(root, **class_dataset_kwargs)
    # Label- and task-level transforms are handed to the parent instead.
    super(MiniImagenet, self).__init__(
        class_dataset,
        num_classes_per_task,
        dataset_transform=dataset_transform,
        target_transform=target_transform,
    )
# Per-class dataset for miniImageNet, built on ``ClassDataset``.
# NOTE(review): this snippet appears truncated -- ``__init__`` visibly ends
# after the ``super().__init__`` call; the rest of the class is not shown.
class MiniImagenetClassDataset(ClassDataset):
# Directory name for this dataset's files, presumably under ``root`` -- TODO confirm.
folder = 'miniimagenet'
# Google Drive ID from https://github.com/renmengye/few-shot-ssl-public
gdrive_id = '16V_ZlkW4SsnNDtnGmaBRq2OoPmUOc5mY'
# Name of the downloaded archive and its expected MD5 checksum.
gz_filename = 'mini-imagenet.tar.gz'
gz_md5 = 'b38f1eb4251fb9459ecc8e7febf9b2eb'
# Pickle file name template inside the archive; ``{0}`` is presumably the
# split name -- TODO confirm.
# NOTE(review): if these pickles are loaded with ``pickle``, only trust
# archives whose checksum matches ``gz_md5``.
pkl_filename = 'mini-imagenet-cache-{0}.pkl'
# Per-split output file name templates (``{0}`` presumably the split name).
filename = '{0}_data.hdf5'
filename_labels = '{0}_labels.json'
# Constructor (only its first statement is visible in this snippet).
def __init__(self, root, meta_train=False, meta_val=False, meta_test=False,
meta_split=None, transform=None, class_augmentations=None,
download=False):
# Split selection and class augmentations are handled by ``ClassDataset``.
super(MiniImagenetClassDataset, self).__init__(meta_train=meta_train,
meta_val=meta_val, meta_test=meta_test, meta_split=meta_split,
class_augmentations=class_augmentations)
import numpy as np
import os
import json
import h5py
from PIL import Image
from torchvision.datasets.utils import check_integrity, download_url
from torchmeta.utils.data import Dataset, ClassDataset
# Per-class dataset for CIFAR-100, built on ``ClassDataset``.
# NOTE(review): this snippet is truncated -- the ``__init__`` signature on the
# last line is cut off mid-parameter-list, so the rest of the class is not
# visible here.
class CIFAR100ClassDataset(ClassDataset):
# Directory name for this dataset's files, presumably under ``root`` -- TODO confirm.
folder = 'cifar100'
# Left as ``None`` here; presumably set by subclasses -- TODO confirm.
subfolder = None
# Source archive URL, the folder name it presumably extracts to, and the
# archive's expected MD5 checksum.
download_url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
gz_folder = 'cifar-100-python'
gz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
# Expected MD5 checksums keyed by 'train', 'test' and 'meta' -- presumably
# the files inside ``gz_folder``; verify against the download code.
files_md5 = {
'train': '16019d7e3df5f24257cddd939b257f8d',
'test': 'f0ef6b0ae62326f3e7ffdfab6717acfc',
'meta': '7973b15100ade9c7d40fb424638fde48'
}
# Single data file name (no split placeholder, unlike ``filename_labels``).
filename = 'data.hdf5'
# Label file name template; ``{0}`` is presumably the split name -- TODO confirm.
filename_labels = '{0}_labels.json'
# File name presumably holding the CIFAR-100 fine label names -- TODO confirm.
filename_fine_names = 'fine_names.json'
# Constructor -- signature truncated in this snippet.
def __init__(self, root, meta_train=False, meta_val=False, meta_test=False,
for semi-supervised few-shot classification. International Conference
on Learning Representations. (https://arxiv.org/abs/1803.00676)
"""
def __init__(self, root, num_classes_per_task=None, meta_train=False,
             meta_val=False, meta_test=False, meta_split=None,
             transform=None, target_transform=None, dataset_transform=None,
             class_augmentations=None, download=False):
    """Build the tieredImageNet meta-dataset.

    The underlying ``TieredImagenetClassDataset`` receives ``root``, the
    split flags, ``transform``, ``class_augmentations`` and ``download``.
    The parent constructor receives that class dataset plus
    ``num_classes_per_task``, ``target_transform`` and ``dataset_transform``.
    """
    per_class = TieredImagenetClassDataset(
        root,
        download=download,
        transform=transform,
        class_augmentations=class_augmentations,
        meta_split=meta_split,
        meta_test=meta_test,
        meta_val=meta_val,
        meta_train=meta_train,
    )
    super(TieredImagenet, self).__init__(
        per_class, num_classes_per_task,
        target_transform=target_transform,
        dataset_transform=dataset_transform,
    )
# Per-class dataset for tieredImageNet, built on ``ClassDataset``.
# NOTE(review): this snippet appears truncated -- ``__init__`` visibly ends
# after the ``super().__init__`` call; the rest of the class is not shown.
class TieredImagenetClassDataset(ClassDataset):
# Directory name for this dataset's files, presumably under ``root`` -- TODO confirm.
folder = 'tieredimagenet'
# Google Drive ID from https://github.com/renmengye/few-shot-ssl-public
gdrive_id = '1g1aIDy2Ar_MViF2gDXFYDBTR-HYecV07'
# Downloaded tar archive name, its expected MD5 checksum, and the folder
# name it presumably extracts to.
tar_filename = 'tiered-imagenet.tar'
tar_md5 = 'e07e811b9f29362d159a9edd0d838c62'
tar_folder = 'tiered-imagenet'
# Per-split file name templates (``{0}`` presumably the split name).
filename = '{0}_data.hdf5'
filename_labels = '{0}_labels.json'
# Constructor (only its first statement is visible in this snippet).
def __init__(self, root, meta_train=False, meta_val=False, meta_test=False,
meta_split=None, transform=None, class_augmentations=None,
download=False):
# Split selection and class augmentations are handled by ``ClassDataset``.
super(TieredImagenetClassDataset, self).__init__(meta_train=meta_train,
meta_val=meta_val, meta_test=meta_test, meta_split=meta_split,
class_augmentations=class_augmentations)
Information Processing Systems (pp. 3630-3638) (https://arxiv.org/abs/1606.04080)
"""
def __init__(self, root, num_classes_per_task=None, meta_train=False,
             meta_val=False, meta_test=False, meta_split=None,
             use_vinyals_split=True, transform=None, target_transform=None,
             dataset_transform=None, class_augmentations=None, download=False):
    """Build the Omniglot meta-dataset.

    An ``OmniglotClassDataset`` is built from ``root`` with the split flags,
    ``use_vinyals_split``, ``transform``, ``class_augmentations`` and
    ``download``. It is then passed to the parent constructor along with
    ``num_classes_per_task``, ``target_transform`` and
    ``dataset_transform``.
    """
    # Omniglot-specific split option, grouped with the generic split flags.
    split_options = {'meta_train': meta_train, 'meta_val': meta_val,
                     'meta_test': meta_test, 'meta_split': meta_split,
                     'use_vinyals_split': use_vinyals_split}
    omniglot_classes = OmniglotClassDataset(
        root, transform=transform, class_augmentations=class_augmentations,
        download=download, **split_options)
    # Label- and task-level transforms belong to the parent constructor.
    super(Omniglot, self).__init__(omniglot_classes, num_classes_per_task,
                                   dataset_transform=dataset_transform,
                                   target_transform=target_transform)
# Per-class dataset for Omniglot, built on ``ClassDataset``.
# NOTE(review): this snippet appears truncated -- ``__init__`` visibly ends
# after the ``super().__init__`` call; the rest of the class is not shown.
class OmniglotClassDataset(ClassDataset):
# Directory name for this dataset's files, presumably under ``root`` -- TODO confirm.
folder = 'omniglot'
# URL prefix for the archive downloads; presumably combined with the keys
# of ``zips_md5`` to form each archive URL -- TODO confirm.
download_url_prefix = 'https://github.com/brendenlake/omniglot/raw/master/python'
# Expected MD5 checksums of the 'images_background' and 'images_evaluation'
# archives.
zips_md5 = {
'images_background': '68d2efa1b9178cc56df9314c21c6e718',
'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811'
}
# Single data file name (no split placeholder).
filename = 'data.hdf5'
# Label file name template with two placeholders; their meaning (e.g. a
# split-scheme prefix and the split name) is not visible here -- TODO confirm.
filename_labels = '{0}{1}_labels.json'
# Constructor (only its first statement is visible in this snippet).
def __init__(self, root, meta_train=False, meta_val=False, meta_test=False,
meta_split=None, use_vinyals_split=True, transform=None,
class_augmentations=None, download=False):
# Split selection and class augmentations are handled by ``ClassDataset``;
# ``use_vinyals_split`` is not forwarded here.
super(OmniglotClassDataset, self).__init__(meta_train=meta_train,
meta_val=meta_val, meta_test=meta_test, meta_split=meta_split,
class_augmentations=class_augmentations)