Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions forecasting_tools/data_models/binary_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

if TYPE_CHECKING:
from forecasting_tools.data_models.questions import BinaryQuestion
from forecasting_tools.helpers.metaculus_client import MetaculusClient

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,19 +51,24 @@ def validate_prediction(cls: BinaryReport, v: float) -> float:
raise ValueError("Prediction must be between 0 and 1")
return v

async def publish_report_to_metaculus(self) -> None:
from forecasting_tools.helpers.metaculus_api import MetaculusApi
async def publish_report_to_metaculus(
self, metaculus_client: MetaculusClient | None = None
) -> None:
from forecasting_tools.helpers.metaculus_client import MetaculusClient

metaculus_client = metaculus_client or MetaculusClient()
if self.question.id_of_question is None:
raise ValueError("Question ID is None")
if self.question.id_of_post is None:
raise ValueError(
"Publishing to Metaculus requires a post ID for the question"
)
MetaculusApi.post_binary_question_prediction(
metaculus_client.post_binary_question_prediction(
self.question.id_of_question, self.prediction
)
MetaculusApi.post_question_comment(self.question.id_of_post, self.explanation)
metaculus_client.post_question_comment(
self.question.id_of_post, self.explanation
)

@classmethod
async def aggregate_predictions(
Expand Down
20 changes: 13 additions & 7 deletions forecasting_tools/data_models/conditional_report.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from pydantic import computed_field

from forecasting_tools.data_models.conditional_models import (
Expand All @@ -14,6 +16,9 @@
)
from forecasting_tools.util.misc import clean_indents

if TYPE_CHECKING:
from forecasting_tools.helpers.metaculus_client import MetaculusClient


class ConditionalReport(ForecastReport):
question: ConditionalQuestion
Expand Down Expand Up @@ -128,10 +133,11 @@ def make_readable_prediction(cls, prediction: ConditionalPrediction) -> str:
"""
)

async def publish_report_to_metaculus(self) -> None:
# if not isinstance(self.parent_report.prediction, PredictionAffirmed):
# await self.parent_report.publish_report_to_metaculus()
# if not isinstance(self.child_report.prediction, PredictionAffirmed):
# await self.child_report.publish_report_to_metaculus()
await self.yes_report.publish_report_to_metaculus()
await self.no_report.publish_report_to_metaculus()
async def publish_report_to_metaculus(
self, metaculus_client: MetaculusClient | None = None
) -> None:
from forecasting_tools.helpers.metaculus_client import MetaculusClient

metaculus_client = metaculus_client or MetaculusClient()
await self.yes_report.publish_report_to_metaculus(metaculus_client)
await self.no_report.publish_report_to_metaculus(metaculus_client)
6 changes: 4 additions & 2 deletions forecasting_tools/data_models/forecast_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

if TYPE_CHECKING:
from forecasting_tools.data_models.questions import MetaculusQuestion

from forecasting_tools.helpers.metaculus_client import MetaculusClient

logger = logging.getLogger(__name__)
T = TypeVar("T")
Expand Down Expand Up @@ -128,7 +128,9 @@ def make_readable_prediction(cls, prediction: Any) -> str:
raise NotImplementedError("Subclass must implement this abstract method")

@abstractmethod
async def publish_report_to_metaculus(self) -> None:
async def publish_report_to_metaculus(
self, metaculus_client: MetaculusClient | None = None
) -> None:
raise NotImplementedError("Subclass must implement this abstract method")

def _get_and_validate_section(self, index: int, expected_word: str) -> MarkdownTree:
Expand Down
18 changes: 14 additions & 4 deletions forecasting_tools/data_models/multiple_choice_report.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

from pydantic import BaseModel, Field, model_validator

from forecasting_tools.data_models.forecast_report import ForecastReport
from forecasting_tools.data_models.questions import MultipleChoiceQuestion

if TYPE_CHECKING:
from forecasting_tools.helpers.metaculus_client import MetaculusClient

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -95,8 +99,12 @@ def expected_baseline_score(self) -> float | None:
def community_prediction(self) -> PredictedOptionList | None:
raise NotImplementedError("Not implemented")

async def publish_report_to_metaculus(self) -> None:
from forecasting_tools.helpers.metaculus_api import MetaculusApi
async def publish_report_to_metaculus(
self, metaculus_client: MetaculusClient | None = None
) -> None:
from forecasting_tools.helpers.metaculus_client import MetaculusClient

metaculus_client = metaculus_client or MetaculusClient()

if self.question.id_of_question is None:
raise ValueError("Question ID is None")
Expand All @@ -108,10 +116,12 @@ async def publish_report_to_metaculus(self) -> None:
raise ValueError(
"Publishing to Metaculus requires a post ID for the question"
)
MetaculusApi.post_multiple_choice_question_prediction(
metaculus_client.post_multiple_choice_question_prediction(
self.question.id_of_question, options_with_probabilities
)
MetaculusApi.post_question_comment(self.question.id_of_post, self.explanation)
metaculus_client.post_question_comment(
self.question.id_of_post, self.explanation
)

@classmethod
async def aggregate_predictions(
Expand Down
12 changes: 8 additions & 4 deletions forecasting_tools/data_models/numeric_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
DiscreteQuestion,
NumericQuestion,
)

from forecasting_tools.helpers.metaculus_client import MetaculusClient
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -650,9 +650,13 @@ def make_readable_prediction(cls, prediction: NumericDistribution) -> str:
readable += f"- {percentile.percentile:.2%} chance of value below {formatted_value}\n"
return readable

async def publish_report_to_metaculus(self) -> None:
async def publish_report_to_metaculus(
self, metaculus_client: MetaculusClient | None = None
) -> None:
from forecasting_tools.helpers.metaculus_client import MetaculusClient

metaculus_client = metaculus_client or MetaculusClient()

if self.question.id_of_question is None:
raise ValueError("Publishing to Metaculus requires a question ID")

Expand All @@ -671,10 +675,10 @@ async def publish_report_to_metaculus(self) -> None:
percentile.percentile for percentile in prediction.get_cdf()
]

MetaculusClient().post_numeric_question_prediction(
metaculus_client.post_numeric_question_prediction(
self.question.id_of_question, cdf_probabilities
)
MetaculusClient().post_question_comment(
metaculus_client.post_question_comment(
self.question.id_of_post, self.explanation
)

Expand Down
16 changes: 11 additions & 5 deletions forecasting_tools/forecast_bots/forecast_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
MultipleChoiceQuestion,
NumericQuestion,
)
from forecasting_tools.helpers.metaculus_api import MetaculusApi
from forecasting_tools.helpers.metaculus_client import MetaculusClient
from forecasting_tools.util.misc import clean_indents

T = TypeVar("T")
Expand Down Expand Up @@ -87,6 +87,7 @@ def __init__(
parameters_to_exclude_from_config_dict: list[str] | None = None,
extra_metadata_in_explanation: bool = False,
required_successful_predictions: float = 0.5,
metaculus_client: MetaculusClient | None = None,
) -> None:
assert (
research_reports_per_question > 0
Expand Down Expand Up @@ -117,6 +118,7 @@ def __init__(
self._note_pads: list[Notepad] = []
self._note_pad_lock = asyncio.Lock()
self._llms = llms or self._llm_config_defaults()
self.metaculus_client = metaculus_client or MetaculusClient()

for purpose, llm in self._llm_config_defaults().items():
if purpose not in self._llms:
Expand Down Expand Up @@ -162,7 +164,9 @@ async def forecast_on_tournament(
tournament_id: int | str,
return_exceptions: bool = False,
) -> list[ForecastReport] | list[ForecastReport | BaseException]:
questions = MetaculusApi.get_all_open_questions_from_tournament(tournament_id)
questions = self.metaculus_client.get_all_open_questions_from_tournament(
tournament_id
)
return await self.forecast_questions(questions, return_exceptions)

@overload
Expand Down Expand Up @@ -412,7 +416,9 @@ async def _run_individual_question(
errors=all_errors,
)
if self.publish_reports_to_metaculus:
await report.publish_report_to_metaculus()
await report.publish_report_to_metaculus(
metaculus_client=self.metaculus_client
)
await self._remove_notepad(question)
return report

Expand Down Expand Up @@ -548,8 +554,8 @@ async def _run_forecast_on_conditional(
full_prediction = ConditionalPrediction(
parent=PredictionAffirmed(),
child=PredictionAffirmed(),
prediction_yes=yes_info.prediction_value,
prediction_no=no_info.prediction_value,
prediction_yes=yes_info.prediction_value, # type: ignore
prediction_no=no_info.prediction_value, # type: ignore
)
return ReasonedPrediction(
reasoning=full_reasoning, prediction_value=full_prediction
Expand Down