Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def create_model():
    """Build a small LeNet-style classifier for 32x32 RGB images.

    The first convolution declares ``input_shape=[32, 32, 3]``, i.e. a
    channels-last (height, width, channels) layout, so every einops pattern
    below is written channels-last as well (``b h w c``). Writing them as
    ``b c h w`` would make the pooling layers reduce over the width and
    channel axes instead of the two spatial axes.

    Expected shape trace (batch axis omitted), assuming stride-1 convolutions
    without padding -- TODO confirm against the Conv2d in use:
        32x32x3 -> conv5 -> 28x28x6 -> pool2 -> 14x14x6
                -> conv5 -> 10x10x16 -> pool2 -> 5x5x16
                -> flatten -> 400 -> 120 -> 84 -> 10

    Returns:
        A ``Sequential`` model whose last layer is ``Linear(10)`` with no
        activation applied after it.
    """
    return Sequential([
        Conv2d(6, kernel_size=5, input_shape=[32, 32, 3]),
        # 2x2 max-pooling over the two spatial axes; channels stay last.
        Reduce('b (h h2) (w w2) c -> b h w c', 'max', h2=2, w2=2),
        Conv2d(16, kernel_size=5),
        Reduce('b (h h2) (w w2) c -> b h w c', 'max', h2=2, w2=2),
        # Flatten everything except the batch axis for the dense head.
        Rearrange('b h w c -> b (h w c)'),
        Linear(120),
        ReLU(),
        Linear(84),
        ReLU(),
        Linear(10),
    ])
class Reduce(ReduceMixin, Layer):
    """Keras layer that applies an einops reduction pattern to its input.

    ``self.pattern``, ``self.reduction``, ``self.axes_lengths``,
    ``self.recipe()`` and ``self._apply_recipe()`` are not defined here;
    presumably they come from ``ReduceMixin`` -- verify there. This class
    adapts them to the Keras ``Layer`` interface.
    """

    def compute_output_shape(self, input_shape):
        """Return the static output shape produced for ``input_shape``.

        ``None`` entries (unknown dimensions) are replaced with
        ``UnknownSize`` placeholders before the recipe is evaluated, and any
        placeholders remaining in the result are mapped back to ``None``.
        All other dimensions are returned as plain ``int``.
        """
        # Placeholders let the recipe propagate known dims past unknown ones.
        input_shape = tuple(UnknownSize() if d is None else int(d) for d in input_shape)
        # Only the final shape is used; `*_` discards the intermediate values
        # and does not depend on how many of them the recipe returns.
        *_, final_shape = self.recipe().reconstruct_from_shape(input_shape)
        return tuple(None if isinstance(d, UnknownSize) else int(d) for d in final_shape)

    def call(self, inputs):
        """Apply the configured reduction recipe to ``inputs``."""
        return self._apply_recipe(inputs)

    def get_config(self):
        """Return the layer's constructor arguments as a flat dict.

        Contains ``pattern`` and ``reduction`` plus one entry per explicitly
        given axis length (``self.axes_lengths`` is merged in at top level).
        """
        return {'pattern': self.pattern, 'reduction': self.reduction, **self.axes_lengths}
# Maps each layer's class name to the class itself -- presumably passed as
# ``custom_objects`` when deserializing a Keras model that uses these layers.
keras_custom_objects = {layer_cls.__name__: layer_cls for layer_cls in (Rearrange, Reduce)}