Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion mastery_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,12 @@ def add_common_args(parser: argparse.ArgumentParser):
default=False,
help="Truncate runs to the end of the shortest simulation",
)
parser.add_argument(
"--extend",
action="store_true",
default=False,
help="Extend shorter runs with their average results",
)
parser.add_argument(
"--crop",
action="store_true",
Expand Down Expand Up @@ -1583,13 +1589,54 @@ def normalize_sims_vs_stone_cost(
)


def extend_sims(
sim_results: list[tuple[Simulation, SimulationRunResult | None]],
) -> Iterator[tuple[Simulation, SimulationRunResult | None]]:
max_time = max(
run_result.wave_results[-1].elapsed_time
for _, run_result in sim_results
if run_result is not None
)
for sim, run_result in sim_results:
if run_result is None:
yield sim, None
continue
wave_time = WAVE_DURATION + WAVE_COOLDOWN
total_rewards = run_result.wave_results[-1].cumulative_rewards
total_events = run_result.wave_results[-1].cumulative_events
duration = run_result.wave_results[-1].elapsed_time
avg_rewards = Rewards(
coins=total_rewards.coins / duration * wave_time,
elite_cells=total_rewards.elite_cells / duration * wave_time,
module_shards=total_rewards.module_shards / duration * wave_time,
reroll_shards=total_rewards.reroll_shards / duration * wave_time,
)
extended_waves = copy.deepcopy(run_result.wave_results)
while extended_waves[-1].elapsed_time < max_time:
last_wave = copy.deepcopy(extended_waves[-1])
last_wave.cumulative_rewards.coins += avg_rewards.coins
last_wave.cumulative_rewards.elite_cells += avg_rewards.elite_cells
last_wave.cumulative_rewards.module_shards += avg_rewards.module_shards
last_wave.cumulative_rewards.reroll_shards += avg_rewards.reroll_shards
last_wave.elapsed_time += wave_time

extended_waves.append(last_wave)

total = reward_value(sim, extended_waves[-1].cumulative_rewards)
yield sim, dataclasses.replace(
run_result, wave_results=extended_waves, total=total
)


def normalize_sims(
args: argparse.Namespace,
sim_results: list[tuple[Simulation, SimulationRunResult | None]],
baseline_sim_name: str,
) -> list[tuple[Simulation, SimulationRunResult | None]]:
if args.truncate:
sim_results = list(truncate_sims_to_shortest(sim_results))
elif args.extend:
sim_results = list(extend_sims(sim_results))
if args.elapsed:
sim_results = list(normalize_sims_vs_elapsed(sim_results))
if args.relative:
Expand Down Expand Up @@ -2041,6 +2088,8 @@ def main():
parser.error("--orb-hits must be between 0.0 and 1.0")
if args.relative and args.difference:
parser.error("--relative and --difference are mutually exclusive")
if args.truncate and args.extend:
parser.error("--truncate and --extend are mutually exclusive")
if not args.relative and args.roi:
parser.error("--roi can only be used with --relative")

Expand All @@ -2061,4 +2110,4 @@ def main():
try:
main()
except BrokenPipeError:
pass
pass