Skip to content

Commit

Permalink
Merge pull request #8 from kolmogorov-lab/hotfix
Browse files Browse the repository at this point in the history
Hotfix
  • Loading branch information
egorshishkovets authored Mar 6, 2024
2 parents c78a9ec + 44d243f commit c70b1fa
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ venv/
ENV/
env.bak/
venv.bak/
abacus_env/

# Spyder project settings
.spyderproject
Expand Down
31 changes: 29 additions & 2 deletions abacus/auto_ab/params.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations
from typing import Dict, Any, Optional
from typing import Dict, Any, Optional, List, Tuple

from pydantic.dataclasses import dataclass
from pydantic import validator, Field
import numpy as np




from abacus.types import (
ColumnNameType,
ColumnNamesType,
Expand All @@ -14,6 +17,24 @@
MetricTransformType,
)

class MismatchMetricError(Exception):
"""
Checks if user assigns custom metric functin along with custom metric name
Args:
metric_name:str
Metric name assigned by user via metric_name attr
msg: str
Error message template (default: 'You use default metric function with custom metric name "{}"! Reassign it with "metric" parameter.')
"""
def __init__(
self,
metric_name:str,
msg:str='You use default metric function with custom metric name "{}"! Reassign it with "metric" parameter.'
):
self.metric_name = metric_name
self.msg = msg.format(self.metric_name)
super().__init__(self.msg)


class ValidationConfig:
validate_assignment = True
Expand Down Expand Up @@ -79,8 +100,9 @@ class HypothesisParams:
n_buckets (int, Optional): number of buckets.
strata (str, Optional): stratification column.
strata_weights (Dict[str, float], Optional): historical strata weights.
default_metrics: Optional[Tuple[str]] : default metric with autoselected metric functions
"""

alpha: Optional[float] = 0.05
beta: Optional[float] = 0.2
alternative: Optional[str] = "two-sided" # less, greater, two-sided
Expand All @@ -96,12 +118,17 @@ class HypothesisParams:
n_buckets: Optional[int] = 100
strata: Optional[str] = ""
strata_weights: Optional[Dict[str, float]] = Field(default_factory=dict)
default_metrics: Optional[Tuple[str, ...]] = ('mean', 'median')

def __post_init__(self):
if self.metric_name not in self.default_metrics \
and self.metric is np.mean:
raise MismatchMetricError(self.metric_name)
if self.metric_name == "mean":
self.metric = np.mean
if self.metric_name == "median":
self.metric = np.median


@validator("alpha", always=True, allow_reuse=True)
@classmethod
Expand Down

0 comments on commit c70b1fa

Please sign in to comment.