Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
return
for i in range (0,self.agents):
self.queue_put_stoppable(self.q, sum_r[i])
self.queue_put_stoppable(self.q_dist, dist[i])
q = queue.Queue()
q_dist = queue.Queue()
threads = [Worker(f, q, q_dist,agents=agents) for f in predictors]
# start all workers
for k in threads:
k.start()
time.sleep(0.1) # avoid simulator bugs
stat = StatCounter()
dist_stat = StatCounter()
# show progress bar w/ tqdm
for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
r = q.get()
stat.feed(r)
dist = q_dist.get()
dist_stat.feed(dist)
logger.info("Waiting for all the workers to finish the last run...")
for k in threads:
k.stop()
for k in threads:
k.join()
while q.qsize():
r = q.get()
def reset_stat(self):
    """Reset all statistics counters.

    Rebinds ``self.stats`` to an empty ``defaultdict(list)`` and replaces
    ``self.num_games`` and ``self.num_success`` with fresh ``StatCounter``
    instances, discarding whatever was accumulated before.
    """
    self.stats = defaultdict(list)
    self.num_games = StatCounter()
    self.num_success = StatCounter()
def recursive():
    """Print the combinations found for one fixed hand of cards.

    Builds a per-action card-count mask for a hard-coded hand, keeps the
    actions whose mask is non-empty, and prints the target card counts, the
    valid mask rows and the result of ``get_combinations_recursive``.

    NOTE(review): ``st`` and ``t1`` are assigned but never read in the part of
    the function visible here -- it looks like timing code follows; confirm
    against the full source before removing them.
    """
    import timeit

    env = Pyenv()
    st = StatCounter()
    for i in range(1):
        env.reset()
        env.prepare()
        # print(env.get_handcards())
        cards = env.get_handcards()[:15]
        # The dealt hand is immediately replaced by a fixed test hand.
        cards = ['J', '10', '10', '7', '7', '6']
        # last_cards = ['3', '3']
        # One row per action: reshape to (actions, 15, 4) and sum the last
        # axis to get a count per each of the 15 slots.
        mask = (
            get_mask_onehot60(cards, action_space, None)
            .reshape(len(action_space), 15, 4)
            .sum(-1)
            .astype(np.uint8)
        )
        # Keep only the actions that use at least one card.
        valid = mask.sum(-1) > 0
        cards_target = Card.char2onehot60(cards).reshape(-1, 4).sum(-1).astype(np.uint8)
        t1 = timeit.default_timer()
        print(cards_target)
        print(mask[valid])
        combs = get_combinations_recursive(mask[valid, :], cards_target)
        print(combs)
setattr(self, k, v)
self.agent_name = agent_name
self.exploration = init_exploration
self.num_actions = num_actions
logger.info("Number of Legal actions: {}, {}".format(*self.num_actions))
self.rng = get_rng(self)
self._init_memory_flag = threading.Event() # tell if memory has been initialized
# a queue to receive notifications to populate memory
self._populate_job_queue = queue.Queue(maxsize=5)
self.mem = ReplayMemory(memory_size, state_shape)
# self._current_ob, self._action_space = self.get_state_and_action_spaces()
self._player_scores = StatCounter()
self._current_game_score = StatCounter()
for k, v in locals().items():
if k != 'self':
setattr(self, k, v)
self.exploration = init_exploration
self.num_actions = player.action_space.n
logger.info("Number of Legal actions: {}".format(self.num_actions))
self.rng = get_rng(self)
self._init_memory_flag = threading.Event() # tell if memory has been initialized
# a queue to receive notifications to populate memory
self._populate_job_queue = queue.Queue(maxsize=5)
self.mem = ReplayMemory(memory_size, state_shape, history_len)
self._current_ob = self.player.reset()
self._player_scores = StatCounter()
self._player_distError = StatCounter()
logger.info("Number of Legal actions: {}, {}".format(*self.num_actions))
self.rng = get_rng(self)
self._init_memory_flag = threading.Event() # tell if memory has been initialized
# a queue to receive notifications to populate memory
self._populate_job_queue = queue.Queue(maxsize=5)
self.mem = ReplayMemory(memory_size, state_shape)
self.player.reset()
self.player.prepare()
self._comb_mask = True
self._fine_mask = None
self._current_ob, self._action_space = self.get_state_and_action_spaces()
self._player_scores = StatCounter()
self._current_game_score = StatCounter()
# self.eval_episode = int(self.eval_episode * 0.94)
def _trigger_epoch(self):
    """Evaluate the predictors and record both sides' win rates.

    Calls ``eval_with_funcs`` for ``self.eval_episode`` episodes and measures
    how long it takes. If the evaluation ran for more than ten minutes,
    ``self.eval_episode`` is reduced to 94% of its value (truncated to an
    int). The returned rate is recorded on the trainer's monitors as
    'farmer win rate', and its complement as 'lord win rate'.
    """
    start = time.time()
    farmer_win_rate = eval_with_funcs(
        self.pred_funcs, self.eval_episode, self.get_player_fn, verbose=False)
    elapsed = time.time() - start
    if elapsed > 10 * 60:  # eval takes too long
        self.eval_episode = int(self.eval_episode * 0.94)
    self.trainer.monitors.put_scalar('farmer win rate', farmer_win_rate)
    self.trainer.monitors.put_scalar('lord win rate', 1 - farmer_win_rate)
if __name__ == '__main__':
    # Play 1000 auto-stepped games from the same manual deal and print the
    # fraction of games whose final step value was negative.
    env = Env()
    stat = StatCounter()
    init_cards = np.arange(15)
    # init_cards = np.append(init_cards[::4], init_cards[1::4])
    for _ in range(1000):
        env.reset()
        env.prepare_manual(init_cards)
        r = 0
        # Keep stepping while the second value returned by step_auto() is 0.
        while r == 0:
            _, r, _ = env.step_auto()
        # Feed 1 when the final value is negative, else 0, so the average
        # is the rate printed below.
        stat.feed(int(r < 0))
    print('lord win rate: {}'.format(stat.average))
with self.default_sess():
player = get_player_fn()
while not self.stopped():
try:
val = play_one_episode(player, self.func)
except RuntimeError:
return
self.queue_put_stoppable(self.q, val)
q = queue.Queue()
threads = [Worker(f, q) for f in predictors]
for k in threads:
k.start()
time.sleep(0.1) # avoid simulator bugs
stat = StatCounter()
def fetch():
    """Take one episode result off the queue and record it in ``stat``.

    Blocks on ``q.get()`` until a value is available. When ``verbose`` is
    truthy, also logs the outcome: a positive value is reported as a farmer
    win, anything else as a lord win.
    """
    val = q.get()
    stat.feed(val)
    if verbose:
        if val > 0:
            logger.info("farmer wins")
        else:
            logger.info("lord wins")
for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
fetch()
logger.info("Waiting for all the workers to finish the last run...")
for k in threads:
k.stop()
for k in threads:
with self.default_sess():
player = get_player_fn()
while not self.stopped():
try:
val = play_one_episode(player, self.func)
except RuntimeError:
return
self.queue_put_stoppable(self.q, val)
q = queue.Queue()
threads = [Worker(f, q) for f in predictors]
for k in threads:
k.start()
time.sleep(0.1) # avoid simulator bugs
stat = StatCounter()
def fetch():
    """Take one episode result off the queue and record it in ``stat``.

    Blocks on ``q.get()`` until a value is available. When ``verbose`` is
    truthy, also logs the outcome: a positive value is reported as a farmer
    win, anything else as a lord win.
    """
    val = q.get()
    stat.feed(val)
    if verbose:
        if val > 0:
            logger.info("farmer wins")
        else:
            logger.info("lord wins")
for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
fetch()
logger.info("Waiting for all the workers to finish the last run...")
for k in threads:
k.stop()
for k in threads:
player = get_player_fn()
while not self.stopped():
try:
stats = play_one_episode(player, self.func)
except RuntimeError:
return
scores = [stat.average if stat.count > 0 else -1 for stat in stats]
self.queue_put_stoppable(self.q, scores)
q = queue.Queue()
threads = [Worker(f, q) for f in predictors]
for k in threads:
k.start()
time.sleep(0.1) # avoid simulator bugs
stats = [StatCounter() for _ in range(7)]
def fetch():
    """Take one list of scores off the queue and fold it into ``stats``.

    Blocks on ``q.get()`` until a score list is available. The score at
    position ``i`` is fed to ``stats[i]`` unless it is negative, in which
    case it is skipped. When ``verbose`` is truthy, the running average of
    every counter (0 for a counter that has received nothing yet) is logged.
    """
    scores = q.get()
    for i, score in enumerate(scores):
        # Negative entries are placeholders and must not affect the average.
        if score >= 0:
            stats[i].feed(score)
    if verbose:
        # Only needed for the log message, so compute it on demand.
        accs = [stat.average if stat.count > 0 else 0 for stat in stats]
        logger.info("passive decision accuracy: {}\n"
                    "passive bomb accuracy: {}\n"
                    "passive response accuracy: {}\n"
                    "active decision accuracy: {}\n"
                    "active response accuracy: {}\n"
                    "active sequence accuracy: {}\n"
                    "minor response accuracy: {}\n".format(*accs))