Skip to content
Snippets Groups Projects
Unverified Commit ec4d17d8 authored by Philipp's avatar Philipp Committed by GitHub
Browse files

add cli option to specify instance file name (#52)

parent 8fc8ea05
No related branches found
No related tags found
No related merge requests found
......@@ -26,7 +26,8 @@ def list_games():
stdout_logger.info(" Game: %s -> %s", game.name, game.get_description())
def run(game_name: str, model_specs: List[backends.ModelSpec], gen_args: Dict, experiment_name: str = None):
def run(game_name: str, model_specs: List[backends.ModelSpec], gen_args: Dict,
experiment_name: str = None, instances_name: str = None):
if experiment_name:
logger.info("Only running experiment: %s", experiment_name)
try:
......@@ -35,7 +36,7 @@ def run(game_name: str, model_specs: List[backends.ModelSpec], gen_args: Dict, e
model = backends.get_model_for(model_spec)
model.set_gen_args(**gen_args) # todo make this somehow available in generate method?
player_models.append(model)
benchmark = load_benchmark(game_name)
benchmark = load_benchmark(game_name, instances_name=instances_name)
logger.info("Running benchmark for '%s' (models=%s)", game_name,
player_models if player_models is not None else "see experiment configs")
if experiment_name:
......
......@@ -593,9 +593,10 @@ class GameBenchmark(GameResourceLocator):
"""
raise NotImplementedError()
def setup(self):
# For now, we assume a single instances.json
self.instances = self.load_json("in/instances.json")
def setup(self, instances_name: str = None):
if instances_name is None:
instances_name = "instances"
self.instances = self.load_json(f"in/{instances_name}")
def build_transcripts(self):
results_root = file_utils.results_root()
......@@ -907,10 +908,10 @@ def load_benchmarks(do_setup: bool = True) -> List[GameBenchmark]:
return game_benchmarks
def load_benchmark(game_name: str, do_setup: bool = True) -> GameBenchmark:
def load_benchmark(game_name: str, do_setup: bool = True, instances_name: str = None) -> GameBenchmark:
    """Locate the benchmark for the given game and optionally set it up.

    :param game_name: name of the game whose benchmark to load
    :param do_setup: when True, immediately load the game's instances
    :param instances_name: instances file base name, forwarded to setup();
        None selects the default "instances" file
    :return: the (possibly set-up) GameBenchmark
    """
    benchmark = find_benchmark(game_name)
    if not do_setup:
        return benchmark
    benchmark.setup(instances_name)
    return benchmark
......
......@@ -59,7 +59,8 @@ def main(args: argparse.Namespace):
benchmark.run(args.game,
model_specs=read_model_specs(args.models),
gen_args=read_gen_args(args),
experiment_name=args.experiment_name)
experiment_name=args.experiment_name,
instances_name=args.instances_name)
if args.command_name == "score":
benchmark.score(args.game, experiment_name=args.experiment_name)
if args.command_name == "transcribe":
......@@ -96,6 +97,8 @@ if __name__ == "__main__":
help="Specify the maximum number of tokens to be generated per turn (except for cohere). "
"Be careful with high values which might lead to exceed your API token limits."
"Default: 100.")
# -i/--instances_name selects a different instances file from the game's
# in/ folder. The default "instances" preserves the previous behaviour,
# so the option must NOT be required: required=True together with a
# default would force every invocation to pass the flag explicitly,
# making the default (and backward compatibility) pointless.
run_parser.add_argument("-i", "--instances_name", type=str, default="instances",
                        help="The instances file name (.json suffix will be added automatically). "
                             "Default: instances.")
score_parser = sub_parsers.add_parser("score")
score_parser.add_argument("-e", "--experiment_name", type=str,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment