Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
@override(EvaluatorInterface)
def set_weights(self, weights):
    """Forward each entry of ``weights`` to the matching policy.

    Args:
        weights: Mapping whose keys index ``self.policy_map`` and whose
            values are passed unchanged to that policy's ``set_weights``.
    """
    for policy_id, policy_weights in weights.items():
        # A key that is absent from ``self.policy_map`` is not handled
        # here; the lookup itself will fail.
        target_policy = self.policy_map[policy_id]
        target_policy.set_weights(policy_weights)
@override(ModelV2)
def last_output(self):
    """Return the ``outputs`` attribute of ``self.cur_instance``."""
    current = self.cur_instance
    return current.outputs
@override(TFPolicy)
def extra_compute_action_feed_dict(self):
    """Return a new, empty dict (no extra feed entries are contributed)."""
    no_extra_feeds = {}
    return no_extra_feeds
@override(Policy)
def postprocess_trajectory(self,
                           sample_batch,
                           other_agent_batches=None,
                           episode=None):
    """Run the optional ``postprocess_fn`` hook on ``sample_batch``.

    ``postprocess_fn`` is resolved from the enclosing scope (not visible
    in this block).

    Args:
        sample_batch: Batch to post-process.
        other_agent_batches: Forwarded unchanged to the hook.
        episode: Forwarded unchanged to the hook.

    Returns:
        The hook's result when ``postprocess_fn`` is truthy; otherwise
        ``sample_batch`` itself, untouched.
    """
    if postprocess_fn:
        processed = postprocess_fn(self, sample_batch, other_agent_batches,
                                   episode)
        return processed
    # No hook configured: pass the batch straight through.
    return sample_batch
@override(TFPolicy)
def gradients(self, optimizer, loss):
    """Compute gradients via ``gradients_fn`` if set, else the base class.

    ``gradients_fn`` is resolved from the enclosing scope (not visible in
    this block).

    Args:
        optimizer: Forwarded unchanged to whichever implementation runs.
        loss: Forwarded unchanged to whichever implementation runs.

    Returns:
        Whatever the selected gradient implementation returns.
    """
    # Guard clause: without a custom hook, defer to TFPolicy's default.
    if not gradients_fn:
        return TFPolicy.gradients(self, optimizer, loss)
    return gradients_fn(self, optimizer, loss)
@override(Aggregator)
def should_broadcast(self):
    """Report whether the sent counter has reached the broadcast interval.

    Returns:
        The result of
        ``self.num_sent_since_broadcast >= self.broadcast_interval``.
    """
    sent_count = self.num_sent_since_broadcast
    threshold = self.broadcast_interval
    return sent_count >= threshold
@override(Trainable)
def _restore(self, checkpoint_path):
    """Restore this object's state from a pickled checkpoint file.

    The file at ``checkpoint_path`` is unpickled and the resulting object
    is handed to ``self.__setstate__``.

    Args:
        checkpoint_path: Path of the checkpoint file, opened in binary
            read mode.

    Raises:
        OSError: If the file cannot be opened.
        pickle.UnpicklingError: If the contents cannot be unpickled
            (other exceptions may also surface from ``pickle.load``).
    """
    # Use a context manager so the file handle is closed even when
    # unpickling fails; the previous bare ``open(...)`` leaked it.
    # NOTE(review): ``pickle.load`` can execute arbitrary code embedded in
    # the file - only restore checkpoints that come from a trusted source.
    with open(checkpoint_path, "rb") as checkpoint_file:
        extra_data = pickle.load(checkpoint_file)
    self.__setstate__(extra_data)
@override(TorchModelV2)
def forward(self, input_dict, state, seq_lens):
    """Compute logits from ``input_dict["obs_flat"]`` and cache the value.

    Args:
        input_dict: Mapping that must contain the key ``"obs_flat"``.
        state: Returned unchanged alongside the logits.
        seq_lens: Not used by this implementation.

    Returns:
        Tuple ``(logits, state)``.

    Side effects:
        Stores the value-branch output in ``self._cur_value``.
    """
    flat_obs = input_dict["obs_flat"]
    # Keep the first dimension and collapse everything after it, so the
    # hidden layers receive a 2-D input.
    leading_dim = flat_obs.shape[0]
    hidden_input = flat_obs.reshape(leading_dim, -1)
    hidden_out = self._hidden_layers(hidden_input)
    policy_logits = self._logits(hidden_out)
    # squeeze(1) drops dimension 1 of the value output - presumably the
    # value branch emits a single column per row; confirm against its
    # definition.
    value_out = self._value_branch(hidden_out)
    self._cur_value = value_out.squeeze(1)
    return policy_logits, state
@override(TorchModelV2)
def forward(self, input_dict, hidden_state, seq_lens):
    """Run one recurrent step and return the output plus new hidden state.

    Args:
        input_dict: Mapping that must contain the key ``"obs_flat"``.
        hidden_state: Sequence whose first element is the incoming hidden
            state.
        seq_lens: Not used by this implementation.

    Returns:
        Tuple ``(q, [h])`` where ``q`` is the output of ``self.fc2`` and
        ``h`` is the value returned by ``self.rnn``.
    """
    # Cast the observation to float before the first layer.
    obs_as_float = input_dict["obs_flat"].float()
    encoded = F.relu(self.fc1(obs_as_float))
    # Reshape the incoming hidden state to (-1, rnn_hidden_dim) for the
    # recurrent module.
    prev_hidden = hidden_state[0].reshape(-1, self.rnn_hidden_dim)
    next_hidden = self.rnn(encoded, prev_hidden)
    q_out = self.fc2(next_hidden)
    return q_out, [next_hidden]
@override(ActionDistribution)
def sample(self):
    """Return the result of calling ``sample()`` on ``self.dist``."""
    drawn = self.dist.sample()
    return drawn