Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hotfix #8

Merged
merged 2 commits into from
Mar 6, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
add default values for allowed metric names
  • Loading branch information
Вакунов Сергей committed Mar 6, 2024
commit 44d243fd118120f38c9aecf4db5cd73899217ad2
14 changes: 11 additions & 3 deletions abacus/auto_ab/params.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Dict, Any, Optional, List
from typing import Dict, Any, Optional, List, Tuple

from pydantic.dataclasses import dataclass
from pydantic import validator, Field
Expand All @@ -20,6 +20,11 @@
class MismatchMetricError(Exception):
"""
Checks if user assigns custom metric functin along with custom metric name
Args:
metric_name:str
Metric name assigned by user via metric_name attr
msg: str
Error message template (default: 'You use default metric function with custom metric name "{}"! Reassign it with "metric" parameter.')
"""
def __init__(
self,
Expand Down Expand Up @@ -95,8 +100,9 @@ class HypothesisParams:
n_buckets (int, Optional): number of buckets.
strata (str, Optional): stratification column.
strata_weights (Dict[str, float], Optional): historical strata weights.
default_metrics: Optional[Tuple[str]] : default metric with autoselected metric functions
"""

alpha: Optional[float] = 0.05
beta: Optional[float] = 0.2
alternative: Optional[str] = "two-sided" # less, greater, two-sided
Expand All @@ -112,15 +118,17 @@ class HypothesisParams:
n_buckets: Optional[int] = 100
strata: Optional[str] = ""
strata_weights: Optional[Dict[str, float]] = Field(default_factory=dict)
default_metrics: Optional[Tuple[str, ...]] = ('mean', 'median')

def __post_init__(self):
if self.metric_name not in ('mean', 'median') \
if self.metric_name not in self.default_metrics \
and self.metric is np.mean:
raise MismatchMetricError(self.metric_name)
if self.metric_name == "mean":
self.metric = np.mean
if self.metric_name == "median":
self.metric = np.median


@validator("alpha", always=True, allow_reuse=True)
@classmethod
Expand Down
Loading