Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 53 additions & 16 deletions questions/serializers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from questions.serializers.aggregate_forecasts import serialize_question_aggregations
from questions.services.multiple_choice_handlers import get_all_options_from_history
from questions.types import QuestionMovement
from questions.types import OptionsHistoryType, QuestionMovement
from users.models import User
from utils.the_math.formulas import (
get_scaled_quartiles_from_cdf,
Expand Down Expand Up @@ -400,7 +400,7 @@ class ForecastWriteSerializer(serializers.ModelSerializer):

probability_yes = serializers.FloatField(allow_null=True, required=False)
probability_yes_per_category = serializers.DictField(
child=serializers.FloatField(), allow_null=True, required=False
child=serializers.FloatField(allow_null=True), allow_null=True, required=False
)
continuous_cdf = serializers.ListField(
child=serializers.FloatField(),
Expand Down Expand Up @@ -441,21 +441,47 @@ def binary_validation(self, probability_yes):
)
return probability_yes

def multiple_choice_validation(self, probability_yes_per_category, options):
def multiple_choice_validation(
self,
probability_yes_per_category: dict[str, float | None],
current_options: list[str],
options_history: OptionsHistoryType | None,
):
if probability_yes_per_category is None:
raise serializers.ValidationError(
"probability_yes_per_category is required"
)
if not isinstance(probability_yes_per_category, dict):
raise serializers.ValidationError("Forecast must be a dictionary")
if set(probability_yes_per_category.keys()) != set(options):
raise serializers.ValidationError("Forecast must include all options")
values = [float(probability_yes_per_category[option]) for option in options]
if not all([0.001 <= v <= 0.999 for v in values]) or not np.isclose(
sum(values), 1
):
if not set(current_options).issubset(set(probability_yes_per_category.keys())):
raise serializers.ValidationError(
f"Forecast must reflect current options: {current_options}"
)
all_options = get_all_options_from_history(options_history)
if not set(probability_yes_per_category.keys()).issubset(set(all_options)):
raise serializers.ValidationError(
"Forecast contains probabilities for unknown options"
)

values: list[float | None] = []
for option in all_options:
value = probability_yes_per_category.get(option, None)
if option in current_options:
if (value is None) or (not (0.001 <= value <= 0.999)):
raise serializers.ValidationError(
"Probabilities for current options must be between 0.001 and 0.999"
)
elif value is not None:
raise serializers.ValidationError(
f"Probability for inactivate option '{option}' must be null or absent"
)
values.append(value)
if not np.isclose(sum(filter(None, values)), 1):
raise serializers.ValidationError(
"All probabilities must be between 0.001 and 0.999 and sum to 1.0"
"Forecast values must sum to 1.0. "
f"Received {probability_yes_per_category} which is interpreted as "
f"values: {values} representing {all_options} "
f"with current options {current_options}"
)
return values

Expand Down Expand Up @@ -562,7 +588,7 @@ def validate(self, data):
"provided for multiple choice questions"
)
data["probability_yes_per_category"] = self.multiple_choice_validation(
probability_yes_per_category, question.options
probability_yes_per_category, question.options, question.options_history
)
else: # Continuous question
if probability_yes or probability_yes_per_category:
Expand Down Expand Up @@ -631,6 +657,21 @@ def serialize_question(
archived_scores = question.user_archived_scores
user_forecasts = question.request_user_forecasts
last_forecast = user_forecasts[-1] if user_forecasts else None
# if the user has a pre-registered forecast,
# replace the current forecast and anything after it
if question.type == Question.QuestionType.MULTIPLE_CHOICE:
# Right now, Multiple Choice is the only type that can have pre-registered
# forecasts
if last_forecast and last_forecast.start_time > timezone.now():
user_forecasts = [
f for f in user_forecasts if f.start_time < timezone.now()
]
if user_forecasts:
last_forecast.start_time = user_forecasts[-1].start_time
user_forecasts[-1] = last_forecast
else:
last_forecast.start_time = timezone.now()
user_forecasts = [last_forecast]
if (
last_forecast
and last_forecast.end_time
Expand All @@ -645,11 +686,7 @@ def serialize_question(
many=True,
).data,
"latest": (
MyForecastSerializer(
user_forecasts[-1],
).data
if user_forecasts
else None
MyForecastSerializer(last_forecast).data if last_forecast else None
),
"score_data": dict(),
}
Expand Down
64 changes: 56 additions & 8 deletions questions/services/forecasts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from collections import defaultdict
from datetime import timedelta
from typing import cast, Iterable
from datetime import datetime, timedelta, timezone as dt_timezone
from typing import cast, Iterable, Literal

import sentry_sdk
from django.db import transaction
Expand All @@ -13,6 +13,7 @@
from posts.models import PostUserSnapshot, PostSubscription
from posts.services.subscriptions import create_subscription_cp_change
from posts.tasks import run_on_post_forecast
from questions.services.multiple_choice_handlers import get_all_options_from_history
from scoring.models import Score
from users.models import User
from utils.cache import cache_per_object
Expand All @@ -34,28 +35,75 @@

def create_forecast(
*,
question: Question = None,
user: User = None,
continuous_cdf: list[float] = None,
probability_yes: float = None,
probability_yes_per_category: list[float] = None,
distribution_input=None,
question: Question,
user: User,
continuous_cdf: list[float] | None = None,
probability_yes: float | None = None,
probability_yes_per_category: list[float | None] | None = None,
distribution_input: dict | None = None,
end_time: datetime | None = None,
source: Forecast.SourceChoices | Literal[""] | None = None,
**kwargs,
):
now = timezone.now()
post = question.get_post()
source = source or ""

# delete all future-dated predictions, as this one will override them
Forecast.objects.filter(question=question, author=user, start_time__gt=now).delete()

# if the forecast to be created is for a multiple choice question during a grace
# period, we need to augment the forecast accordingly (possibly preregister)
if question.type == Question.QuestionType.MULTIPLE_CHOICE:
if not probability_yes_per_category:
raise ValueError("probability_yes_per_category required for MC questions")
options_history = question.options_history
if options_history and len(options_history) > 1:
period_end = datetime.fromisoformat(options_history[-1][0]).replace(
tzinfo=dt_timezone.utc
)
if period_end > now:
all_options = get_all_options_from_history(question.options_history)
prior_options = options_history[-2][1]
if end_time is None or end_time > period_end:
# create a pre-registration for the given forecast
Forecast.objects.create(
question=question,
author=user,
start_time=period_end,
end_time=end_time,
probability_yes_per_category=probability_yes_per_category,
post=post,
source=Forecast.SourceChoices.AUTOMATIC,
**kwargs,
)
end_time = period_end

prior_pmf: list[float | None] = [None] * len(all_options)
for i, (option, value) in enumerate(
zip(all_options, probability_yes_per_category)
):
if value is None:
continue
if option in prior_options:
prior_pmf[i] = (prior_pmf[i] or 0.0) + value
else:
prior_pmf[-1] = (prior_pmf[-1] or 0.0) + value
probability_yes_per_category = prior_pmf

forecast = Forecast.objects.create(
question=question,
author=user,
start_time=now,
end_time=end_time,
continuous_cdf=continuous_cdf,
probability_yes=probability_yes,
probability_yes_per_category=probability_yes_per_category,
distribution_input=(
distribution_input if question.type in QUESTION_CONTINUOUS_TYPES else None
),
post=post,
source=source,
**kwargs,
)
# tidy up all forecasts
Expand Down
2 changes: 1 addition & 1 deletion questions/services/multiple_choice_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def multiple_choice_add_options(
for forecast in user_forecasts:
pmf = forecast.probability_yes_per_category
forecast.probability_yes_per_category = (
pmf[:-1] + [0.0] * len(options_to_add) + [pmf[-1]]
pmf[:-1] + [None] * len(options_to_add) + [pmf[-1]]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move to mc/3800

)
if forecast.start_time < grace_period_end and (
forecast.end_time is None or forecast.end_time > grace_period_end
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_questions/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def question_multiple_choice():
return create_question(
question_type=Question.QuestionType.MULTIPLE_CHOICE,
options=["a", "b", "c", "d"],
options_history=[("0001-01-01T00:00:00", ["a", "b", "c", "d"])],
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,41 @@ def test_multiple_choice_reorder_options(
[],
False,
), # initial forecast is invalid
(
["a", "b", "other"],
["b"],
[
Forecast(
start_time=dt(2023, 1, 1),
end_time=dt(2024, 1, 1),
probability_yes_per_category=[0.6, 0.15, 0.25],
),
Forecast(
start_time=dt(2024, 1, 1),
end_time=None,
probability_yes_per_category=[0.2, 0.3, 0.5],
),
],
[
Forecast(
start_time=dt(2023, 1, 1),
end_time=dt(2024, 1, 1),
probability_yes_per_category=[0.6, 0.15, 0.25],
),
Forecast(
start_time=dt(2024, 1, 1),
end_time=dt(2025, 1, 1),
probability_yes_per_category=[0.2, 0.3, 0.5],
),
Forecast(
start_time=dt(2025, 1, 1),
end_time=None,
probability_yes_per_category=[0.2, None, 0.8],
source=Forecast.SourceChoices.AUTOMATIC,
),
],
True,
), # preserve previous forecasts
],
)
def test_multiple_choice_delete_options(
Expand Down Expand Up @@ -327,6 +362,36 @@ def test_multiple_choice_delete_options(
],
True,
), # no effect
(
["a", "b", "other"],
["c"],
dt(2025, 1, 1),
[
Forecast(
start_time=dt(2023, 1, 1),
end_time=dt(2024, 1, 1),
probability_yes_per_category=[0.6, 0.15, 0.25],
),
Forecast(
start_time=dt(2024, 1, 1),
end_time=None,
probability_yes_per_category=[0.2, 0.3, 0.5],
),
],
[
Forecast(
start_time=dt(2023, 1, 1),
end_time=dt(2024, 1, 1),
probability_yes_per_category=[0.6, 0.15, None, 0.25],
),
Forecast(
start_time=dt(2024, 1, 1),
end_time=dt(2025, 1, 1),
probability_yes_per_category=[0.2, 0.3, None, 0.5],
),
],
True,
), # edit all forecasts including old
],
)
def test_multiple_choice_add_options(
Expand Down
Loading