Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_train_torch_cnn():
    """Set up a small torch 3D CNN test on molgrid data (excerpt).

    Only the setup portion of this test is visible here. It resolves the
    data paths and seeds the molgrid, torch and numpy RNGs with 0. It then
    declares the ``Net`` model class.

    NOTE(review): ``Net`` has no ``forward`` method here, and nothing in
    view uses ``batch_size`` or ``fname``. Presumably the training loop
    that consumes them lives in the part of the test not shown. Confirm
    against the full file.
    """
    batch_size = 50
    # os.path.join instead of string concatenation: if __file__ has no
    # directory component, dirname() returns '' and '' + '/data' would
    # point at the filesystem root rather than a local 'data' directory.
    datadir = os.path.join(os.path.dirname(__file__), 'data')
    fname = os.path.join(datadir, "small.types")
    # Seed every RNG the test touches so runs are reproducible.
    molgrid.set_random_seed(0)
    torch.manual_seed(0)
    np.random.seed(0)

    class Net(nn.Module):
        """Three Conv3d layers interleaved with MaxPool3d, then a 2-way Linear."""

        def __init__(self, dims):
            """Create the layers for input grids described by ``dims``.

            :param dims: indexable of length >= 4. ``dims[0]`` is used as
                the Conv3d input-channel count, and ``dims[1:4]`` as the
                three spatial extents (presumably (channels, x, y, z) --
                confirm against the grid maker that produces the input).
            """
            super().__init__()
            self.pool0 = nn.MaxPool3d(2)
            self.conv1 = nn.Conv3d(dims[0], 32, kernel_size=3, padding=1)
            self.pool1 = nn.MaxPool3d(2)
            self.conv2 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
            self.pool2 = nn.MaxPool3d(2)
            self.conv3 = nn.Conv3d(64, 128, kernel_size=3, padding=1)
            # Assuming forward applies pool0..pool2 once each (forward is
            # not visible here): each MaxPool3d(2) floors a spatial axis by
            # 2, and the padded 3x3x3 convs preserve it, so every axis ends
            # at d // 8.
            # Parenthesize each factor: `//` and `*` share precedence and
            # associate left-to-right. The unparenthesized form divided the
            # running product and over-counted whenever an axis was not a
            # multiple of 8.
            self.last_layer_size = (
                (dims[1] // 8) * (dims[2] // 8) * (dims[3] // 8) * 128
            )
            self.fc1 = nn.Linear(self.last_layer_size, 2)
# (Removed a truncated duplicate of test_random_transform. It was fully
# shadowed by the complete definition of the same name that follows, so
# it could never run -- pyflakes F811.)
def test_random_transform():
    """Check seeding and rotation-center behavior of molgrid ``Transform``.

    The test builds ``Transform(center, 4.0, True)`` instances around two
    centers (c1 = origin, c2 = (0,0,1)) and ``Transform(center)`` instances
    around the same centers. It then resets the molgrid seed and draws
    ``t3`` again with t1's arguments. The assertions require:

    - t1/t2 quaternions and translations to differ from each other, and
      from nrt1;
    - t3 to equal t1;
    - nrt1 to equal nrt2;
    - each transform's rotation center to be the one it was built with.

    NOTE(review): ``t`` is constructed but not used in the visible lines,
    and there is no rotation-center check for ``nrt2``. This body likely
    continues beyond the excerpt shown. Confirm against the full file.
    """
    from molgrid import Transform
    molgrid.set_random_seed(0)
    c1 = molgrid.float3(0,0,0);
    c2 = molgrid.float3(0,0,1);
    # Presumably (center, random_translate, random_rotation); the exact
    # parameter meaning is not shown here. TODO confirm against molgrid docs.
    t1 = Transform(c1, 4.0, True)
    t2 = Transform(c2, 4.0, True)
    # Center-only transforms; the "nrt" name suggests "non-random
    # transform" -- the asserts below require nrt1 and nrt2 to agree.
    nrt1 = Transform (c1)
    nrt2 = Transform (c2)
    t = Transform()
    molgrid.set_random_seed(0) # reset, should get same sample
    # Same seed and same arguments as t1, so t3 is asserted to reproduce t1.
    t3 = Transform(c1, 4.0, True);
    # neqQ/eqQ are helpers defined elsewhere in the file; presumably they
    # assert quaternion inequality/equality (possibly with a tolerance).
    neqQ(t1.get_quaternion(),t2.get_quaternion());
    neqQ(t1.get_quaternion(),nrt1.get_quaternion());
    eqQ(t1.get_quaternion(),t3.get_quaternion());
    eqQ(nrt1.get_quaternion(),nrt2.get_quaternion());
    # Translations must follow the same differ/match pattern as the
    # quaternions above. `tup` (defined elsewhere) presumably converts a
    # molgrid float3 to a Python tuple so == / != compare components.
    assert tup(t1.get_translation()) != tup(t2.get_translation())
    assert tup(t1.get_translation()) != tup(nrt1.get_translation())
    assert tup(t1.get_translation()) == tup(t3.get_translation())
    assert tup(nrt1.get_translation()) == tup(nrt2.get_translation())
    # The rotation center reported back is the center passed at construction.
    assert tup(c1) == tup(t1.get_rotation_center())
    assert tup(c2) == tup(t2.get_rotation_center())
    assert tup(c1) == tup(nrt1.get_rotation_center())