Skip to content

Commit

Permalink
feat: boxplot shortcut (sinaptik-ai#308)
Browse files Browse the repository at this point in the history
* add .boxplot() shortcut

* update docs

* lint: fix lint

---------

Co-authored-by: Gabriele Venturi <[email protected]>
  • Loading branch information
fp1acm8 and gventuri authored Jun 23, 2023
1 parent 95ae912 commit a09101f
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 25 deletions.
20 changes: 11 additions & 9 deletions docs/shortcuts.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,6 @@ pandas_ai.plot_bar_chart(df, x = ['a', 'b', 'c'], y = [1, 2, 3])

This shortcut will plot a bar chart of the data frame.

### plot_bar_chart

```python
df = pd.read_csv('data.csv')
pandas_ai.plot_bar_chart(df, x = ['a', 'b', 'c'])
```

This shortcut will plot a bar chart of the data frame.

### plot_histogram

```python
Expand Down Expand Up @@ -112,6 +103,17 @@ pandas_ai.plot_roc_curve(df, y_true = [1, 2, 3], y_pred = [1, 2, 3])

This shortcut will plot a ROC curve of the data frame.

### boxplot

```python
df = pd.read_csv('data.csv')
pandas_ai.boxplot(df, col='A', by='B', style='Highlight outliers with a x')
```

This shortcut plots a box-and-whisker plot using the DataFrame `df`, focusing on the `'A'` column and grouping the data by the `'B'` column.

The optional `style` parameter is a free-text description of the plot customizations you want (for example, how outliers should be marked). It is appended to the prompt sent to the Language Model, so you can refine the look of the plot without writing plotting code yourself.

### rolling_mean

```python
Expand Down
234 changes: 218 additions & 16 deletions pandasai/helpers/shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,29 @@
class Shortcuts(ABC):
@abstractmethod
def run(self, df: pd.DataFrame, prompt: str) -> Union[str, pd.DataFrame]:
    """
    Run method from PandasAI class.

    Abstract hook that every shortcut delegates to; concrete subclasses
    must provide the implementation.

    Args:
        df (pd.DataFrame): The DataFrame containing the data.
        prompt (str): The prompt to be executed.

    Returns:
        Union[str, pd.DataFrame]: The response from the LLM.
    """

def clean_data(self, df: pd.DataFrame) -> pd.DataFrame:
"""Do data cleaning and return the dataframe."""
"""
Do data cleaning and return the dataframe.
Args:
df (pd.DataFrame): The DataFrame containing the data.
Returns:
pd.DataFrame: The cleaned DataFrame.
"""

return self.run(
df,
Expand All @@ -23,7 +40,15 @@ def clean_data(self, df: pd.DataFrame) -> pd.DataFrame:
)

def impute_missing_values(self, df: pd.DataFrame) -> pd.DataFrame:
"""Do missing value imputation and return the dataframe."""
"""
Do missing value imputation and return the dataframe.
Args:
df (pd.DataFrame): The DataFrame containing the data.
Returns:
pd.DataFrame: The DataFrame with imputed missing values.
"""

return self.run(
df,
Expand All @@ -35,7 +60,15 @@ def impute_missing_values(self, df: pd.DataFrame) -> pd.DataFrame:
)

def generate_features(self, df: pd.DataFrame) -> pd.DataFrame:
"""Do feature generation and return the dataframe."""
"""
Do feature generation and return the dataframe.
Args:
df (pd.DataFrame): The DataFrame containing the data.
Returns:
pd.DataFrame: The DataFrame with generated features.
"""

return self.run(
df,
Expand All @@ -47,7 +80,17 @@ def generate_features(self, df: pd.DataFrame) -> pd.DataFrame:
)

def plot_pie_chart(self, df: pd.DataFrame, labels: list, values: list) -> None:
"""Plot a pie chart."""
"""
Plot a pie chart.
Args:
df (pd.DataFrame): The DataFrame containing the data.
labels (list): The labels for the pie chart.
values (list): The values for the pie chart.
Returns:
None
"""

self.run(
df,
Expand All @@ -59,7 +102,17 @@ def plot_pie_chart(self, df: pd.DataFrame, labels: list, values: list) -> None:
)

def plot_bar_chart(self, df: pd.DataFrame, x: list, y: list) -> None:
"""Plot a bar chart."""
"""
Plot a bar chart.
Args:
df (pd.DataFrame): The DataFrame containing the data.
x (list): The x values for the bar chart.
y (list): The y values for the bar chart.
Returns:
None
"""

self.run(
df,
Expand All @@ -71,12 +124,31 @@ def plot_bar_chart(self, df: pd.DataFrame, x: list, y: list) -> None:
)

def plot_histogram(self, df: pd.DataFrame, column: str) -> None:
    """
    Plot a histogram.

    Args:
        df (pd.DataFrame): The DataFrame containing the data.
        column (str): The column to plot the histogram for.

    Returns:
        None
    """

    # The result of run() is intentionally discarded: this shortcut is
    # called for its plotting side effect only.
    self.run(df, f"Plot a histogram of the column {column}.")

def plot_line_chart(self, df: pd.DataFrame, x: list, y: list) -> None:
"""Plot a line chart."""
"""
Plot a line chart.
Args:
df (pd.DataFrame): The DataFrame containing the data.
x (list): The x values for the line chart.
y (list): The y values for the line chart.
Returns:
None
"""

self.run(
df,
Expand All @@ -88,7 +160,17 @@ def plot_line_chart(self, df: pd.DataFrame, x: list, y: list) -> None:
)

def plot_scatter_chart(self, df: pd.DataFrame, x: list, y: list) -> None:
"""Plot a scatter chart."""
"""
Plot a scatter chart.
Args:
df (pd.DataFrame): The DataFrame containing the data.
x (list): The x values for the scatter chart.
y (list): The y values for the scatter chart.
Returns:
None
"""

self.run(
df,
Expand All @@ -100,14 +182,32 @@ def plot_scatter_chart(self, df: pd.DataFrame, x: list, y: list) -> None:
)

def plot_correlation_heatmap(self, df: pd.DataFrame) -> None:
    """
    Plot a correlation heatmap.

    Args:
        df (pd.DataFrame): The DataFrame containing the data.

    Returns:
        None
    """

    # The result of run() is intentionally discarded: this shortcut is
    # called for its plotting side effect only.
    self.run(df, "Plot a correlation heatmap.")

def plot_confusion_matrix(
self, df: pd.DataFrame, y_true: list, y_pred: list
) -> None:
"""Plot a confusion matrix."""
"""
Plot a confusion matrix.
Args:
df (pd.DataFrame): The DataFrame containing the data.
y_true (list): The true values.
y_pred (list): The predicted values.
Returns:
None
"""

self.run(
df,
Expand All @@ -119,7 +219,17 @@ def plot_confusion_matrix(
)

def plot_roc_curve(self, df: pd.DataFrame, y_true: list, y_pred: list) -> None:
"""Plot a ROC curve."""
"""
Plot a ROC curve.
Args:
df (pd.DataFrame): The DataFrame containing the data.
y_true (list): The true values.
y_pred (list): The predicted values.
Returns:
None
"""

self.run(
df,
Expand All @@ -130,8 +240,70 @@ def plot_roc_curve(self, df: pd.DataFrame, y_true: list, y_pred: list) -> None:
""",
)

def boxplot(
    self,
    df: pd.DataFrame,
    col: Union[str, list[str], None] = None,
    by: Union[str, list[str], None] = None,
    style: Union[str, None] = None,
):
    """
    Draw a box plot to show distributions with respect to categories.

    Builds a natural-language prompt from the arguments and forwards it,
    together with ``df``, to ``self.run``.

    Args:
        df (pd.DataFrame): The DataFrame containing the data.
        col (str | list[str] | None): The column(s) of interest
            for the box plot. ``None`` or an empty list leaves the
            columns unspecified. Defaults to None.
        by (str | list[str] | None): The grouping variable(s)
            for the box plot. ``None`` or an empty list means no
            grouping clause is added. Defaults to None.
        style (str | None): The textual description of the desired
            style. Defaults to None.

    Returns:
        The value returned by ``self.run`` for the generated prompt.

    Raises:
        TypeError: If ``col`` or ``by`` is not a string, a list, or None.
    """

    if not isinstance(col, (str, list, type(None))):
        raise TypeError(
            "The 'col' argument must be a string, a list of strings, or None."
        )
    if not isinstance(by, (str, list, type(None))):
        raise TypeError(
            "The 'by' argument must be a string, a list of strings, or None."
        )

    def _quote_and_join(names):
        """Render names as "'a', 'b' and 'c'" (or just "'a'" for one name)."""
        quoted = [f"'{name}'" for name in names]
        if len(quoted) > 1:
            return f"{', '.join(quoted[:-1])} and {quoted[-1]}"
        return quoted[0]

    # Normalise both arguments to lists so strings and lists share one
    # code path; None and [] both collapse to "nothing specified".
    columns = [col] if isinstance(col, str) else list(col or [])
    groups = [by] if isinstance(by, str) else list(by or [])

    prompt = "Plot a box-and-whisker plot"

    if columns:
        noun = "variables" if len(columns) > 1 else "variable"
        prompt += f" for the {noun} {_quote_and_join(columns)}"

    if groups:
        # Format lists name by name instead of embedding the list repr.
        prompt += f" grouped by {_quote_and_join(groups)}"

    if style is not None:
        prompt += f"\nStyle: '''{style}'''"

    # Hand the response back to the caller, as the docstring promises.
    return self.run(df, prompt)

def rolling_mean(self, df: pd.DataFrame, column: str, window: int) -> pd.DataFrame:
"""Calculate the rolling mean."""
"""
Calculate the rolling mean.
Args:
df (pd.DataFrame): The DataFrame containing the data.
column (str): The column to calculate the rolling mean for.
window (int): The window size.
Returns:
pd.DataFrame: The DataFrame containing the rolling mean.
"""

return self.run(
df,
Expand All @@ -142,7 +314,17 @@ def rolling_mean(self, df: pd.DataFrame, column: str, window: int) -> pd.DataFra
def rolling_median(
self, df: pd.DataFrame, column: str, window: int
) -> pd.DataFrame:
"""Calculate the rolling median."""
"""
Calculate the rolling median.
Args:
df (pd.DataFrame): The DataFrame containing the data.
column (str): The column to calculate the rolling median for.
window (int): The window size.
Returns:
pd.DataFrame: The DataFrame containing the rolling median.
"""

return self.run(
df,
Expand All @@ -151,7 +333,17 @@ def rolling_median(
)

def rolling_std(self, df: pd.DataFrame, column: str, window: int) -> pd.DataFrame:
"""Calculate the rolling standard deviation."""
"""
Calculate the rolling standard deviation.
Args:
df (pd.DataFrame): The DataFrame containing the data.
column (str): The column to calculate the rolling standard deviation for.
window (int): The window size.
Returns:
pd.DataFrame: The DataFrame containing the rolling standard deviation.
"""

return self.run(
df,
Expand All @@ -162,7 +354,17 @@ def rolling_std(self, df: pd.DataFrame, column: str, window: int) -> pd.DataFram
def segment_customers(
self, df: pd.DataFrame, features: list, n_clusters: int
) -> pd.DataFrame:
"""Segment customers."""
"""
Segment customers.
Args:
df (pd.DataFrame): The DataFrame containing the data.
features (list): The features to use for the segmentation.
n_clusters (int): The number of clusters.
Returns:
pd.DataFrame: The DataFrame containing the segmentation.
"""

return self.run(
df,
Expand Down

0 comments on commit a09101f

Please sign in to comment.