Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
380 changes: 264 additions & 116 deletions README.ipynb

Large diffs are not rendered by default.

315 changes: 207 additions & 108 deletions README.md

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from forecasting_tools.ai_models.ai_utils.ai_misc import clean_indents
from forecasting_tools.ai_models.gpt4o import Gpt4o
from forecasting_tools.ai_models.deprecated_model_classes.gpt4o import Gpt4o
from forecasting_tools.ai_models.resource_managers.monetary_cost_manager import (
MonetaryCostManager,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from forecasting_tools.ai_models.ai_utils.ai_misc import clean_indents
from forecasting_tools.ai_models.gpt4o import Gpt4o
from forecasting_tools.ai_models.deprecated_model_classes.gpt4o import Gpt4o
from forecasting_tools.forecasting.sub_question_researchers.base_rate_researcher import (
BaseRateResearcher,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@
)
from code_tests.unit_tests.test_ai_models.models_to_test import ModelsToTest
from code_tests.utilities_for_tests import coroutine_testing
from forecasting_tools.ai_models.gpto1preview import GptO1Preview
from forecasting_tools.ai_models.metaculus4o import Gpt4oMetaculusProxy
from forecasting_tools.ai_models.deprecated_model_classes.gpto1preview import (
GptO1Preview,
)
from forecasting_tools.ai_models.deprecated_model_classes.metaculus4o import (
Gpt4oMetaculusProxy,
)
from forecasting_tools.ai_models.model_interfaces.ai_model import AiModel

logger = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
GeneralLlmInstancesToTest,
ModelTest,
)
from forecasting_tools.ai_models.general_llm import GeneralLlm


@pytest.mark.parametrize(
Expand All @@ -19,3 +20,14 @@ def test_general_llm_instances_run(
model_input = test.model_input
response = asyncio.run(model.invoke(model_input))
assert response is not None, "Response is None"


def test_timeout_works() -> None:
model = GeneralLlm(model="gpt-4o", timeout=0.1)
model_input = "Hello, world!"
with pytest.raises(Exception):
asyncio.run(model.invoke(model_input))

model = GeneralLlm(model="gpt-4o-mini", timeout=50)
response = asyncio.run(model.invoke(model_input))
assert response is not None, "Response is None"
11 changes: 6 additions & 5 deletions code_tests/low_cost_or_live_api_tests/test_metaculus_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,17 @@


def test_get_binary_question_type_from_id() -> None:
question_id = DataOrganizer.get_example_post_id_for_question_type(
# Test question w/ <1% probability: https://www.metaculus.com/questions/578/human-extinction-by-2100/
post_id = DataOrganizer.get_example_post_id_for_question_type(
BinaryQuestion
)
question = MetaculusApi.get_question_by_post_id(question_id)
question = MetaculusApi.get_question_by_post_id(post_id)
assert isinstance(question, BinaryQuestion)
assert question_id == question.id_of_post
assert post_id == question.id_of_post
assert question.community_prediction_at_access_time is not None
assert question.community_prediction_at_access_time == pytest.approx(0.01)
assert question.community_prediction_at_access_time <= 0.01
assert question.state == QuestionState.OPEN
assert_basic_question_attributes_not_none(question, question_id)
assert_basic_question_attributes_not_none(question, post_id)


def test_get_numeric_question_type_from_id() -> None:
Expand Down
47 changes: 30 additions & 17 deletions code_tests/low_cost_or_live_api_tests/test_smart_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,49 @@

import pytest

from forecasting_tools.ai_models.resource_managers.monetary_cost_manager import (
MonetaryCostManager,
)
from forecasting_tools.ai_models.general_llm import GeneralLlm
from forecasting_tools.forecasting.helpers.smart_searcher import SmartSearcher

logger = logging.getLogger(__name__)


async def test_ask_question_basic() -> None:
num_searches_to_run = 1
num_sites_per_search = 3
searcher = SmartSearcher(
include_works_cited_list=True,
num_searches_to_run=num_searches_to_run,
num_sites_per_search=num_sites_per_search,
num_searches_to_run=1,
num_sites_per_search=3,
)
question = "What is the recent news on SpaceX?"
report = await searcher.invoke(question)
logger.info(f"Report:\n{report}")
validate_search_report(report)


async def test_ask_question_with_different_llm() -> None:
temperature = 0.7
chosen_model = "gpt-3.5-turbo"
searcher = SmartSearcher(
model=chosen_model,
include_works_cited_list=True,
num_searches_to_run=1,
num_sites_per_search=3,
temperature=temperature,
)

assert isinstance(searcher.llm, GeneralLlm)
assert searcher.llm.model == chosen_model
assert searcher.llm.litellm_kwargs["temperature"] == pytest.approx(
temperature
)
assert searcher.llm.litellm_kwargs["model"] == chosen_model

question = "What is the recent news on SpaceX?"
report = await searcher.invoke(question)
logger.info(f"Report:\n{report}")
validate_search_report(report)


def validate_search_report(report: str) -> None:
assert report, "Result should not be empty"
assert isinstance(report, str), "Result should be a string"

Expand Down Expand Up @@ -60,13 +83,3 @@ async def test_ask_question_empty_prompt() -> None:
searcher = SmartSearcher()
with pytest.raises(ValueError):
await searcher.invoke("")


@pytest.mark.skip("Run this when needed as it's purely a qualitative test")
async def test_screenshot_question_2() -> None:
with MonetaryCostManager() as cost_manager:
searcher = SmartSearcher(num_sites_to_deep_dive=2)
question = "Please tell me about the recent trends in the Federal Funds Effective Rate."
result = await searcher.invoke(question)
logger.info(f"Result: {result}")
logger.info(f"Cost: {cost_manager.current_usage}")
28 changes: 21 additions & 7 deletions code_tests/unit_tests/test_ai_models/models_to_test.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
from litellm import model_cost

from forecasting_tools.ai_models.ai_utils.openai_utils import VisionMessageData
from forecasting_tools.ai_models.claude35sonnet import Claude35Sonnet
from forecasting_tools.ai_models.deepseek_r1 import DeepSeekR1
from forecasting_tools.ai_models.deprecated_model_classes.claude35sonnet import (
Claude35Sonnet,
)
from forecasting_tools.ai_models.deprecated_model_classes.deepseek_r1 import (
DeepSeekR1,
)
from forecasting_tools.ai_models.deprecated_model_classes.gpt4o import Gpt4o
from forecasting_tools.ai_models.deprecated_model_classes.gpt4ovision import (
Gpt4oVision,
)
from forecasting_tools.ai_models.deprecated_model_classes.gpto1 import GptO1
from forecasting_tools.ai_models.deprecated_model_classes.metaculus4o import (
Gpt4oMetaculusProxy,
)
from forecasting_tools.ai_models.deprecated_model_classes.perplexity import (
Perplexity,
)
from forecasting_tools.ai_models.exa_searcher import ExaSearcher
from forecasting_tools.ai_models.general_llm import GeneralLlm, ModelInputType
from forecasting_tools.ai_models.gpt4o import Gpt4o
from forecasting_tools.ai_models.gpt4ovision import Gpt4oVision
from forecasting_tools.ai_models.gpto1 import GptO1
from forecasting_tools.ai_models.metaculus4o import Gpt4oMetaculusProxy
from forecasting_tools.ai_models.model_interfaces.ai_model import AiModel
from forecasting_tools.ai_models.model_interfaces.incurs_cost import IncursCost
from forecasting_tools.ai_models.model_interfaces.outputs_text import (
Expand All @@ -29,7 +40,6 @@
from forecasting_tools.ai_models.model_interfaces.tokens_incur_cost import (
TokensIncurCost,
)
from forecasting_tools.ai_models.perplexity import Perplexity


class ModelsToTest:
Expand Down Expand Up @@ -135,6 +145,10 @@ def _all_tests(self) -> list[ModelTest]:
GeneralLlm(model="deepseek/deepseek-reasoner"),
self._get_cheap_user_message(),
),
ModelTest(
GeneralLlm(model="openrouter/openai/gpt-4o"),
self._get_cheap_user_message(),
),
]

def all_tests_with_names(self) -> list[tuple[str, ModelTest]]:
Expand Down
2 changes: 1 addition & 1 deletion code_tests/unit_tests/test_ai_models/test_outputs_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from forecasting_tools.ai_models.ai_utils.response_types import (
TextTokenCostResponse,
)
from forecasting_tools.ai_models.gpt4o import Gpt4o
from forecasting_tools.ai_models.deprecated_model_classes.gpt4o import Gpt4o
from forecasting_tools.ai_models.model_interfaces.ai_model import AiModel
from forecasting_tools.ai_models.model_interfaces.outputs_text import (
OutputsText,
Expand Down
22 changes: 0 additions & 22 deletions code_tests/unit_tests/test_ai_models/test_time_limited_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,6 @@
TIME_LIMITED_ERROR_MESSAGE = "Model must be TimeLimited"


@pytest.mark.skip(
"This test takes too long to run between threads for some reason (it passes but doesn't check for timeout within 10 sec). Not important enough to fix right now"
)
@pytest.mark.parametrize("subclass", ModelsToTest.TIME_LIMITED_LIST)
def test_ai_model_successfully_times_out(
mocker: Mock, subclass: type[AiModel]
) -> None:
if not issubclass(subclass, TimeLimitedModel):
raise ValueError(TIME_LIMITED_ERROR_MESSAGE)

subclass.TIMEOUT_TIME = 10

AiModelMockManager.mock_ai_model_direct_call_with_long_wait(
mocker, subclass
)
model = subclass()
model_input = model._get_cheap_input_for_invoke()

with pytest.raises(asyncio.exceptions.TimeoutError):
asyncio.run(model.invoke(model_input))


@pytest.mark.parametrize("subclass", ModelsToTest.TIME_LIMITED_LIST)
def test_ai_model_has_at_least_minimum_timeout(
mocker: Mock, subclass: type[AiModel]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from forecasting_tools.forecasting.forecast_bots.template_bot import (
TemplateBot,
from forecasting_tools.forecasting.helpers.prediction_extractor import (
PredictionExtractor,
)
from forecasting_tools.forecasting.questions_and_reports.numeric_report import (
Percentile,
Expand Down Expand Up @@ -202,8 +202,7 @@ def test_numeric_parsing(
expected_percentiles: list[Percentile],
question: NumericQuestion,
) -> None:
bot = TemplateBot()
numeric_distribution = bot._extract_forecast_from_numeric_rationale(
numeric_distribution = PredictionExtractor.extract_numeric_distribution_from_list_of_percentile_number_and_probability(
gpt_response, question
)
for declared_percentile, expected_percentile in zip(
Expand Down
26 changes: 19 additions & 7 deletions forecasting_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
from forecasting_tools.ai_models.ai_utils.ai_misc import (
clean_indents as clean_indents,
)
from forecasting_tools.ai_models.claude35sonnet import (
from forecasting_tools.ai_models.deprecated_model_classes.claude35sonnet import (
Claude35Sonnet as Claude35Sonnet,
)
from forecasting_tools.ai_models.deepseek_r1 import DeepSeekR1 as DeepSeekR1
from forecasting_tools.ai_models.exa_searcher import ExaSearcher as ExaSearcher
from forecasting_tools.ai_models.gpt4o import Gpt4o as Gpt4o
from forecasting_tools.ai_models.gpt4ovision import Gpt4oVision as Gpt4oVision
from forecasting_tools.ai_models.metaculus4o import (
from forecasting_tools.ai_models.deprecated_model_classes.deepseek_r1 import (
DeepSeekR1 as DeepSeekR1,
)
from forecasting_tools.ai_models.deprecated_model_classes.gpt4o import (
Gpt4o as Gpt4o,
)
from forecasting_tools.ai_models.deprecated_model_classes.gpt4ovision import (
Gpt4oVision as Gpt4oVision,
)
from forecasting_tools.ai_models.deprecated_model_classes.metaculus4o import (
Gpt4oMetaculusProxy as Gpt4oMetaculusProxy,
)
from forecasting_tools.ai_models.perplexity import Perplexity as Perplexity
from forecasting_tools.ai_models.deprecated_model_classes.perplexity import (
Perplexity as Perplexity,
)
from forecasting_tools.ai_models.exa_searcher import ExaSearcher as ExaSearcher
from forecasting_tools.ai_models.general_llm import GeneralLlm as GeneralLlm
from forecasting_tools.ai_models.resource_managers.monetary_cost_manager import (
MonetaryCostManager as MonetaryCostManager,
)
Expand All @@ -36,6 +45,9 @@
from forecasting_tools.forecasting.helpers.metaculus_api import (
MetaculusApi as MetaculusApi,
)
from forecasting_tools.forecasting.helpers.prediction_extractor import (
PredictionExtractor as PredictionExtractor,
)
from forecasting_tools.forecasting.helpers.smart_searcher import (
SmartSearcher as SmartSearcher,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from typing import Final

import typing_extensions

from forecasting_tools.ai_models.model_interfaces.combined_llm_archetype import (
CombinedLlmArchetype,
)


@typing_extensions.deprecated(
"LLM calls will slowly be moved to the GeneralLlm class", category=None
)
class Claude35Sonnet(CombinedLlmArchetype):
# See Anthropic Limit on the account dashboard for most up-to-date limit
# Latest as of Nov 6 2024 is claude-2-5-sonnet-20241022
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import typing_extensions

from forecasting_tools.ai_models.model_interfaces.combined_llm_archetype import (
CombinedLlmArchetype,
)


@typing_extensions.deprecated(
"LLM calls will slowly be moved to the GeneralLlm class", category=None
)
class DeepSeekR1(CombinedLlmArchetype):
MODEL_NAME = "deepseek/deepseek-reasoner"
REQUESTS_PER_PERIOD_LIMIT: int = 8_000
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import logging
from typing import Final

import typing_extensions

from forecasting_tools.ai_models.model_interfaces.combined_llm_archetype import (
CombinedLlmArchetype,
)

logger = logging.getLogger(__name__)


@typing_extensions.deprecated(
"LLM calls will slowly be moved to the GeneralLlm class", category=None
)
class Gpt4o(CombinedLlmArchetype):
# See OpenAI Limit on the account dashboard for most up-to-date limit
MODEL_NAME: Final[str] = "gpt-4o"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Final

import typing_extensions

from forecasting_tools.ai_models.ai_utils.openai_utils import VisionMessageData
from forecasting_tools.ai_models.model_interfaces.combined_llm_archetype import (
CombinedLlmArchetype,
Expand All @@ -11,6 +13,9 @@ class Gpt4VisionInput(VisionMessageData):
pass


@typing_extensions.deprecated(
"LLM calls will slowly be moved to the GeneralLlm class", category=None
)
class Gpt4oVision(CombinedLlmArchetype):
MODEL_NAME: Final[str] = "gpt-4o"
REQUESTS_PER_PERIOD_LIMIT: Final[int] = (
Expand Down
12 changes: 12 additions & 0 deletions forecasting_tools/ai_models/deprecated_model_classes/gpto1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import typing_extensions

from forecasting_tools.ai_models.deprecated_model_classes.gpto1preview import (
GptO1Preview,
)


@typing_extensions.deprecated(
"LLM calls will slowly be moved to the GeneralLlm class", category=None
)
class GptO1(GptO1Preview):
MODEL_NAME: str = "o1"
Loading