Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_AnnoyTripletGenerator():
    """Smoke-test the batch structure produced by AnnoyTripletGenerator.

    Builds an in-memory Annoy index over the iris features and checks that
    every batch the generator yields has the expected tuple structure and
    feature width. (Previously this test duplicated the KnnTripletGenerator
    test and never constructed an AnnoyTripletGenerator.)
    """
    # Local import keeps this test self-contained; the library code in this
    # file already depends on the ``annoy`` package.
    from annoy import AnnoyIndex

    iris = datasets.load_iris()
    X = iris.data
    batch_size = 32
    k = 4

    # Same metric as the index-loading code path that feeds
    # AnnoyTripletGenerator elsewhere in this file.
    index = AnnoyIndex(X.shape[1], metric='angular')
    for row_id, row in enumerate(X):
        index.add_item(row_id, row)
    index.build(10)

    # search_k=-1 is Annoy's "use the default search effort" value.
    data_generator = AnnoyTripletGenerator(X, index, k=k,
                                           batch_size=batch_size,
                                           search_k=-1)

    # Request one batch beyond the number of full batches so that the final
    # (possibly partial) batch is exercised too.
    for i in range((X.shape[0] // batch_size) + 1):
        batch = data_generator[i]
        # Each batch is a 2-tuple: a 3-element group of inputs, and a second
        # element holding at most batch_size entries.
        assert isinstance(batch, tuple)
        assert len(batch) == 2
        assert len(batch[0]) == 3
        assert len(batch[1]) <= batch_size
        assert batch[0][0].shape[-1] == X.shape[-1]
def test_KnnTripletGenerator():
    """Smoke-test the batch structure produced by KnnTripletGenerator.

    Uses the precomputed neighbour list stored in
    ``tests/data/test_knn_k4.npy`` together with the iris features, and
    checks that every batch has the expected tuple structure and feature
    width.
    """
    neighbour_list = np.load('tests/data/test_knn_k4.npy')
    iris = datasets.load_iris()
    X = iris.data
    batch_size = 32
    data_generator = KnnTripletGenerator(X, neighbour_list,
                                         batch_size=batch_size)

    # Request one batch beyond the number of full batches so that the final
    # (possibly partial) batch is exercised too.
    for i in range((X.shape[0] // batch_size) + 1):
        batch = data_generator[i]
        # Each batch is a 2-tuple: a 3-element group of inputs, and a second
        # element holding at most batch_size entries.
        assert isinstance(batch, tuple)
        assert len(batch) == 2
        assert len(batch[0]) == 3
        assert len(batch[1]) <= batch_size
        assert batch[0][0].shape[-1] == X.shape[-1]
raise Exception('''k value greater than or equal to (num_rows - 1)
(k={}, rows={}). Lower k to a smaller
value.'''.format(k, X.shape[0]))
if batch_size > X.shape[0]:
raise Exception('''batch_size value larger than num_rows in dataset
(batch_size={}, rows={}). Lower batch_size to a
smaller value.'''.format(batch_size, X.shape[0]))
if Y is None:
if precompute:
if verbose > 0:
print('Extracting KNN from index')
neighbour_matrix = extract_knn(X, index_path, k=k,
search_k=search_k, verbose=verbose)
return KnnTripletGenerator(X, neighbour_matrix,
batch_size=batch_size)
else:
index = AnnoyIndex(X.shape[1], metric='angular')
index.load(index_path)
return AnnoyTripletGenerator(X, index, k=k,
batch_size=batch_size,
search_k=search_k)
else:
if precompute:
if verbose > 0:
print('Extracting KNN from index')
neighbour_matrix = extract_knn(X, index_path, k=k,
search_k=search_k, verbose=verbose)
return LabeledKnnTripletGenerator(X, Y, neighbour_matrix,
batch_size=batch_size)