diff --git a/games/referencegame/master.py b/games/referencegame/master.py index ae4cc0351cc9ea8c3b97fdde9033abf8d9a32be1..6a6030176a3d2c46449f791a91d9306e660a1025 100644 --- a/games/referencegame/master.py +++ b/games/referencegame/master.py @@ -3,7 +3,7 @@ from typing import List, Tuple, Dict from backends import Model from clemgame import file_utils from clemgame import metrics -from clemgame.clemgame import GameMaster, GameBenchmark +from clemgame.clemgame import GameMaster, GameBenchmark, GameScorer from clemgame import get_logger from games.referencegame.game import ReferenceGame import re @@ -11,6 +11,9 @@ import math GAME_NAME = "referencegame" +PLAYER_A_PATTERN = r'^Expression:\s*(.+)\n*(.+)*$' +PLAYER_B_PATTERN= r"^Answer:\s*(?!.*\b(?:first|second|third|First|Second|Third)\b.*\b(?:first|second|third)\b).*\b(?:first grid|second grid|first|second|third grid|third|First grid|Second grid|Third grid)\b.*$" + logger = get_logger(__name__) @@ -20,8 +23,6 @@ class ReferenceGameMaster(GameMaster): super().__init__(GAME_NAME, experiment, player_models) self.experiment = experiment self.game = None - self.player_a_pattern = r'^Expression:\s*(.+)\n*(.+)*$' - self.player_b_pattern = r"^Answer:\s*(?!.*\b(?:first|second|third|First|Second|Third)\b.*\b(?:first|second|third)\b).*\b(?:first grid|second grid|first|second|third grid|third|First grid|Second grid|Third grid)\b.*$" self.request_count = 0 self.parsed_request_count = 0 self.violated_request_count = 0 @@ -118,7 +119,7 @@ class ReferenceGameMaster(GameMaster): self.request_count += 1 # check if the Player 2 message matches the rule => grid - if re.match(self.player_b_pattern, player_2_response_text): + if re.match(PLAYER_B_PATTERN, player_2_response_text): self.parsed_request_count += 1 action = {'type': 'parse', 'content': player_2_response_text, @@ -134,6 +135,13 @@ class ReferenceGameMaster(GameMaster): self.violated_request_count += 1 self.aborted_ratio = 1 + +class ReferenceGameScorer(GameScorer): + + def __init__(self, experiment: Dict, game_instance: Dict): + super().__init__(GAME_NAME, experiment, game_instance) + self.target_grid_name = game_instance["target_grid_name"] + def compute_scores(self, episode_interactions: Dict) -> None: success = 0 @@ -186,13 +194,13 @@ class ReferenceGameMaster(GameMaster): episode_request_count += 1 # check if the Player 2 message matches the rule -> start "Answer: ..." - match = re.compile(self.player_b_pattern).match(player_2_message) + match = re.compile(PLAYER_B_PATTERN).match(player_2_message) if match: turn_parsed_request_count += 1 episode_parsed_request_count += 1 # check if the target grid number matches the output from Player 2 - if self.game.target_grid_name.lower() in player_2_message.replace('Answer:', '').lower(): + if self.target_grid_name.lower() in player_2_message.replace('Answer:', '').lower(): success = 1 else: lost_count = 1 @@ -202,7 +210,6 @@ class ReferenceGameMaster(GameMaster): aborted = True break - # log the Player 1 - message length expression_length = len(player_1_message.replace('Expression:', '').strip()) self.log_turn_score(t_index, 'Generated Expression Length', expression_length) @@ -274,16 +281,6 @@ class ReferenceGameMaster(GameMaster): self.log_episode_score(metrics.METRIC_REQUEST_SUCCESS, 0) - - - - - def _get_recorded_turns(self, records: Dict) -> List[int]: - return list(range(len(records["turns"]))) - - - - class ReferenceGameBenchmark(GameBenchmark): def __init__(self): @@ -295,6 +292,8 @@ class ReferenceGameBenchmark(GameBenchmark): def create_game_master(self, experiment: Dict, player_models: List[Model]) -> GameMaster: return ReferenceGameMaster(experiment, player_models) + def create_game_scorer(self, experiment: Dict, game_instance: Dict) -> GameScorer: + return ReferenceGameScorer(experiment, game_instance) def main(): # select one instance