How to use the torchmeta.transforms.ClassSplitter function in torchmeta

To help you get started, we’ve selected a few torchmeta.transforms.ClassSplitter examples, based on popular ways it is used in public projects.
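Before diving into the project examples below, here is a minimal sketch of what ClassSplitter does: wrapped around a meta-dataset, it splits every sampled task into named subsets (by default 'train' and 'test') with a fixed number of examples per class. The folder path and the shot counts are placeholders, not values taken from the projects listed below.

from torchmeta.datasets import Omniglot
from torchmeta.transforms import ClassSplitter, Categorical
from torchmeta.utils.data import BatchMetaDataLoader
from torchvision.transforms import ToTensor

# A 5-way Omniglot meta-dataset; 'data' is a placeholder folder.
dataset = Omniglot('data', num_classes_per_task=5, meta_train=True,
                   transform=ToTensor(), target_transform=Categorical(5),
                   download=True)

# Every sampled task becomes a dict with a 1-shot 'train' split and a
# 15-example-per-class 'test' split.
dataset = ClassSplitter(dataset, shuffle=True,
                        num_train_per_class=1, num_test_per_class=15)
dataset.seed(0)  # reproducible per-task splits

dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=2)
for batch in dataloader:
    train_inputs, train_targets = batch['train']   # support set tensors
    test_inputs, test_targets = batch['test']      # query set tensors
    break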


github tristandeleu/pytorch-meta · torchmeta/toy/helpers.py (view on GitHub)
            'Ignoring the argument `shots`.', stacklevel=2)
        if test_shots is not None:
            shots = kwargs['num_samples_per_task'] - test_shots
            if shots <= 0:
                raise ValueError('The argument `test_shots` ({0}) is greater '
                    'than the number of samples per task ({1}). Either use the '
                    'argument `shots` instead of `num_samples_per_task`, or '
                    'increase the value of `num_samples_per_task`.'.format(
                    test_shots, kwargs['num_samples_per_task']))
        else:
            shots = kwargs['num_samples_per_task'] // 2
    if test_shots is None:
        test_shots = shots

    dataset = Harmonic(num_samples_per_task=shots + test_shots, **kwargs)
    dataset = ClassSplitter(dataset, shuffle=shuffle,
        num_train_per_class=shots, num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
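The excerpt above is cut off before the enclosing signature; in the released package this helper is exposed as torchmeta.toy.helpers.harmonic (an assumed name here, worth checking against your installed version). Under that assumption, a call that produces tasks already split by ClassSplitter looks roughly like this:

from torchmeta.toy.helpers import harmonic            # assumed import path
from torchmeta.utils.data import BatchMetaDataLoader

# 10 support and 10 query points per regression task; numbers are illustrative.
dataset = harmonic(shots=10, test_shots=10, shuffle=True, seed=0)

dataloader = BatchMetaDataLoader(dataset, batch_size=25)
for batch in dataloader:
    train_inputs, train_targets = batch['train']       # support points
    test_inputs, test_targets = batch['test']          # query points
    break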
github tristandeleu/pytorch-meta · torchmeta/toy/helpers.py (view on GitHub)
            'Ignoring the argument `shots`.', stacklevel=2)
        if test_shots is not None:
            shots = kwargs['num_samples_per_task'] - test_shots
            if shots <= 0:
                raise ValueError('The argument `test_shots` ({0}) is greater '
                    'than the number of samples per task ({1}). Either use the '
                    'argument `shots` instead of `num_samples_per_task`, or '
                    'increase the value of `num_samples_per_task`.'.format(
                    test_shots, kwargs['num_samples_per_task']))
        else:
            shots = kwargs['num_samples_per_task'] // 2
    if test_shots is None:
        test_shots = shots

    dataset = Sinusoid(num_samples_per_task=shots + test_shots, **kwargs)
    dataset = ClassSplitter(dataset, shuffle=shuffle,
        num_train_per_class=shots, num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
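For reference, the guard above means that num_samples_per_task=20 with test_shots=5 yields 15 support shots, and without test_shots it falls back to a 10/10 split. Passing shots and test_shots directly avoids the warning entirely; below is a sketch of inspecting a single split task (the import path is assumed, and the indexing behaviour is inferred from how ClassSplitter wraps the dataset):

from torchmeta.toy.helpers import sinusoid             # assumed import path

dataset = sinusoid(shots=5, test_shots=10, seed=0)

task = dataset[0]                  # one sinusoid task, already split by ClassSplitter
support, query = task['train'], task['test']
x0, y0 = support[0]                # a single (input, target) pair from the support split
print(len(support), len(query))    # expected: 5 10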
github tristandeleu/pytorch-meta · torchmeta/datasets/helpers.py (view on GitHub)
    `datasets.cifar100.CIFARFS` : Meta-dataset for the CIFAR-FS dataset.
    """
    if 'num_classes_per_task' in kwargs:
        warnings.warn('Both arguments `ways` and `num_classes_per_task` were '
            'set in the helper function for the number of classes per task. '
            'Ignoring the argument `ways`.', stacklevel=2)
        ways = kwargs['num_classes_per_task']
    if 'transform' not in kwargs:
        kwargs['transform'] = ToTensor()
    if 'target_transform' not in kwargs:
        kwargs['target_transform'] = Categorical(ways)
    if test_shots is None:
        test_shots = shots

    dataset = CIFARFS(folder, num_classes_per_task=ways, **kwargs)
    dataset = ClassSplitter(dataset, shuffle=shuffle,
        num_train_per_class=shots, num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
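The enclosing helper (exposed as torchmeta.datasets.helpers.cifar_fs in the released package; the name is assumed here since the signature is cut off) fills in the default transforms above and applies ClassSplitter. A typical 5-way 1-shot call:

from torchmeta.datasets.helpers import cifar_fs        # assumed import path

# 5-way, 1-shot episodes with 15 query images per class; extra kwargs such as
# meta_train and download are forwarded to the underlying CIFARFS dataset.
dataset = cifar_fs('data', ways=5, shots=1, test_shots=15,
                   meta_train=True, download=True)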
github tristandeleu/pytorch-meta · torchmeta/datasets/helpers.py (view on GitHub)
    `datasets.TieredImagenet` : Meta-dataset for the Tiered-Imagenet dataset.
    """
    if 'num_classes_per_task' in kwargs:
        warnings.warn('Both arguments `ways` and `num_classes_per_task` were '
            'set in the helper function for the number of classes per task. '
            'Ignoring the argument `ways`.', stacklevel=2)
        ways = kwargs['num_classes_per_task']
    if 'transform' not in kwargs:
        kwargs['transform'] = Compose([Resize(84), ToTensor()])
    if 'target_transform' not in kwargs:
        kwargs['target_transform'] = Categorical(ways)
    if test_shots is None:
        test_shots = shots

    dataset = TieredImagenet(folder, num_classes_per_task=ways, **kwargs)
    dataset = ClassSplitter(dataset, shuffle=shuffle,
        num_train_per_class=shots, num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
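Because the helper only supplies transform and target_transform when they are missing from kwargs, custom preprocessing passes straight through while ClassSplitter is still applied; a sketch (helper name tieredimagenet assumed):

from torchmeta.datasets.helpers import tieredimagenet   # assumed import path
from torchvision.transforms import Compose, Resize, RandomHorizontalFlip, ToTensor

# Override the default Resize(84) + ToTensor() pipeline; the helper still splits
# each task into 5 train and 15 test images per class.
dataset = tieredimagenet('data', ways=5, shots=5, test_shots=15, meta_train=True,
                         transform=Compose([Resize(84), RandomHorizontalFlip(), ToTensor()]),
                         download=True)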
github tristandeleu/pytorch-meta · torchmeta/datasets/helpers.py (view on GitHub)
    `datasets.MiniImagenet` : Meta-dataset for the Mini-Imagenet dataset.
    """
    if 'num_classes_per_task' in kwargs:
        warnings.warn('Both arguments `ways` and `num_classes_per_task` were '
            'set in the helper function for the number of classes per task. '
            'Ignoring the argument `ways`.', stacklevel=2)
        ways = kwargs['num_classes_per_task']
    if 'transform' not in kwargs:
        kwargs['transform'] = Compose([Resize(84), ToTensor()])
    if 'target_transform' not in kwargs:
        kwargs['target_transform'] = Categorical(ways)
    if test_shots is None:
        test_shots = shots

    dataset = MiniImagenet(folder, num_classes_per_task=ways, **kwargs)
    dataset = ClassSplitter(dataset, shuffle=shuffle,
        num_train_per_class=shots, num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
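Downstream, the labels in each sampled task have already been remapped to 0..ways-1 by Categorical(ways), which is what makes a per-episode cross-entropy loss well defined. A sketch of the usual consumption pattern (helper name miniimagenet assumed):

from torchmeta.datasets.helpers import miniimagenet     # assumed import path
from torchmeta.utils.data import BatchMetaDataLoader

dataset = miniimagenet('data', ways=5, shots=1, test_shots=15,
                       meta_train=True, download=True)
dataloader = BatchMetaDataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)

for batch in dataloader:
    support_x, support_y = batch['train']   # labels are integers in 0..4
    query_x, query_y = batch['test']
    # adapt on the support set, evaluate on the query set, then step the meta-optimizer
    break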
github tristandeleu/pytorch-meta · torchmeta/datasets/helpers.py (view on GitHub)
    if 'num_classes_per_task' in kwargs:
        warnings.warn('Both arguments `ways` and `num_classes_per_task` were '
            'set in the helper function for the number of classes per task. '
            'Ignoring the argument `ways`.', stacklevel=2)
        ways = kwargs['num_classes_per_task']
    if 'transform' not in kwargs:
        kwargs['transform'] = Compose([Resize(28), ToTensor()])
    if 'target_transform' not in kwargs:
        kwargs['target_transform'] = Categorical(ways)
    if 'class_augmentations' not in kwargs:
        kwargs['class_augmentations'] = [Rotation([90, 180, 270])]
    if test_shots is None:
        test_shots = shots

    dataset = Omniglot(folder, num_classes_per_task=ways, **kwargs)
    dataset = ClassSplitter(dataset, shuffle=shuffle,
        num_train_per_class=shots, num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
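The helper is essentially shorthand for building Omniglot with ClassSplitter attached as its dataset_transform; below is a hand-rolled sketch that mirrors the defaults in the excerpt (exact equivalence with the helper is not guaranteed):

from torchmeta.datasets import Omniglot
from torchmeta.transforms import ClassSplitter, Categorical, Rotation
from torchvision.transforms import Compose, Resize, ToTensor

# Used without a dataset argument, ClassSplitter acts as a task transform.
splitter = ClassSplitter(shuffle=True, num_train_per_class=1, num_test_per_class=15)

dataset = Omniglot('data', num_classes_per_task=20, meta_train=True,
                   transform=Compose([Resize(28), ToTensor()]),
                   target_transform=Categorical(20),
                   class_augmentations=[Rotation([90, 180, 270])],
                   dataset_transform=splitter,
                   download=True)
dataset.seed(0)   # seeds task sampling and the splitter for reproducible episodes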
github tristandeleu/pytorch-meta · torchmeta/datasets/helpers.py (view on GitHub)
    if 'num_classes_per_task' in kwargs:
        warnings.warn('Both arguments `ways` and `num_classes_per_task` were '
            'set in the helper function for the number of classes per task. '
            'Ignoring the argument `ways`.', stacklevel=2)
        ways = kwargs['num_classes_per_task']
    if 'transform' not in kwargs:
        image_size = 84
        kwargs['transform'] = Compose([
            Resize(int(image_size * 1.5)),
            CenterCrop(image_size),
            ToTensor()])
    if 'target_transform' not in kwargs:
        kwargs['target_transform'] = Categorical(ways)
    if test_shots is None:
        test_shots = shots

    dataset = CUB(folder, num_classes_per_task=ways, **kwargs)
    dataset = ClassSplitter(dataset, shuffle=shuffle,
        num_train_per_class=shots, num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
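Finally, ClassSplitter is not limited to a train/test pair: it also accepts a num_samples_per_class mapping, so a task can carry a validation split as well. The keyword and its behaviour are assumptions to verify against your torchmeta version; a sketch mirroring the CUB defaults above:

from torchmeta.datasets import CUB
from torchmeta.transforms import ClassSplitter, Categorical
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor

dataset = CUB('data', num_classes_per_task=5, meta_train=True,
              transform=Compose([Resize(126), CenterCrop(84), ToTensor()]),
              target_transform=Categorical(5), download=True)

# Three named splits per task instead of the default train/test pair.
dataset = ClassSplitter(dataset, shuffle=True,
                        num_samples_per_class={'train': 5, 'validation': 5, 'test': 10})
dataset.seed(0)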