Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
"""target function for sub-processing to host a model
Parameters
----------
addr: socket address
sample_buffer_capacity: int
the maximum number of samples (s,r,a,s') to collect in a game round
RLModel: BaseModel
the RL algorithm class
model_args: dict
arguments to RLModel
"""
import magent.utility
model = RLModel(**model_args)
sample_buffer = magent.utility.EpisodesBuffer(capacity=sample_buffer_capacity)
conn = multiprocessing.connection.Client(addr)
while True:
cmd = conn.recv()
if cmd[0] == 'act':
policy = cmd[1]
eps = cmd[2]
array_info = cmd[3]
view, feature, ids = NDArrayPackage(array_info).recv_from(conn)
obs = (view, feature)
acts = model.infer_action(obs, ids, policy=policy, eps=eps)
package = NDArrayPackage(acts)
conn.send(package.info)
# Environment setup: build the grid world from the configured map size and
# set its render directory.
env = magent.GridWorld(load_config(size=args.map_size))
env.set_render_dir("build/render")

# handles[0] is bound as the food handle; all remaining handles are kept as
# the player handles.
handles = env.get_handles()
food_handle, player_handles = handles[0], handles[1:]

# Evaluation observation set: sampled only when args.eval is truthy,
# otherwise left as None.
if args.eval:
    print("sample eval set...")
    # The env is reset and the map regenerated before sampling.
    env.reset()
    generate_map(env, args.map_size, food_handle, player_handles)
    # NOTE(review): the meaning of the positional 0, 2048, 500 is defined by
    # magent.utility.sample_observation -- confirm against its signature.
    eval_obs = magent.utility.sample_observation(env, player_handles, 0, 2048, 500)
else:
    eval_obs = None

# Model list: a single RLModel instance built for the first player handle.
models = [
    RLModel(
        env, player_handles[0], args.name,
        batch_size=512,
        memory_size=2 ** 19,
        target_update=1000,
        train_freq=4,
        eval_obs=eval_obs,
    )
]
# load saved model
save_dir = "save_model"
if args.load_from is not None:
start_from = args.load_from
print("load models...")
for model in models:
model.load(save_dir, start_from)
else: