Skip to content

Commit

Permalink
[td] Fix charts when using Spark (mage-ai#5028)
Browse files Browse the repository at this point in the history
# Description

- Serialize and deserialize almost anything
- Improved display of output in notebook
- Fixed charts on dashboard and in notebook
- Updated linters and formatters

---------

Co-authored-by: mager <[email protected]>
  • Loading branch information
tommydangerous and mager authored May 11, 2024
1 parent 052227c commit 8acafbd
Show file tree
Hide file tree
Showing 80 changed files with 9,273 additions and 4,486 deletions.
32 changes: 17 additions & 15 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ RUN \
curl https://packages.microsoft.com/config/debian/11/prod.list > /etc/apt/sources.list.d/mssql-release.list && \
apt-get -y update && \
ACCEPT_EULA=Y apt-get -y install --no-install-recommends \
# NFS dependencies
nfs-common \
# odbc dependencies
msodbcsql18\
unixodbc-dev \
# R
r-base && \
# NFS dependencies
nfs-common \
# odbc dependencies
msodbcsql18\
unixodbc-dev \
graphviz \
# R
r-base && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*

Expand All @@ -26,6 +27,7 @@ RUN \
R -e "install.packages('pacman', repos='http://cran.us.r-project.org')" && \
R -e "install.packages('renv', repos='http://cran.us.r-project.org')"


## Python Packages
RUN \
pip3 install --no-cache-dir sparkmagic && \
Expand All @@ -42,20 +44,20 @@ RUN \
pip3 install --no-cache-dir "git+https://github.com/mage-ai/dbt-synapse.git#egg=dbt-synapse" && \
pip3 install --no-cache-dir "git+https://github.com/mage-ai/sqlglot#egg=sqlglot" && \
if [ -z "$FEATURE_BRANCH" ] || [ "$FEATURE_BRANCH" = "null" ]; then \
pip3 install --no-cache-dir "git+https://github.com/mage-ai/mage-ai.git#egg=mage-integrations&subdirectory=mage_integrations"; \
pip3 install --no-cache-dir "git+https://github.com/mage-ai/mage-ai.git#egg=mage-integrations&subdirectory=mage_integrations"; \
else \
pip3 install --no-cache-dir "git+https://github.com/mage-ai/mage-ai.git@$FEATURE_BRANCH#egg=mage-integrations&subdirectory=mage_integrations"; \
pip3 install --no-cache-dir "git+https://github.com/mage-ai/mage-ai.git@$FEATURE_BRANCH#egg=mage-integrations&subdirectory=mage_integrations"; \
fi

# Mage
COPY ./mage_ai/server/constants.py /tmp/constants.py
RUN if [ -z "$FEATURE_BRANCH" ] || [ "$FEATURE_BRANCH" = "null" ] ; then \
tag=$(tail -n 1 /tmp/constants.py) && \
VERSION=$(echo "$tag" | tr -d "'") && \
pip3 install --no-cache-dir "mage-ai[all]==$VERSION"; \
else \
pip3 install --no-cache-dir "git+https://github.com/mage-ai/mage-ai.git@$FEATURE_BRANCH#egg=mage-ai[all]"; \
fi
tag=$(tail -n 1 /tmp/constants.py) && \
VERSION=$(echo "$tag" | tr -d "'") && \
pip3 install --no-cache-dir "mage-ai[all]==$VERSION"; \
else \
pip3 install --no-cache-dir "git+https://github.com/mage-ai/mage-ai.git@$FEATURE_BRANCH#egg=mage-ai[all]"; \
fi


## Startup Script
Expand Down
31 changes: 16 additions & 15 deletions dev.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,25 @@ RUN \
curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg && \
NODE_MAJOR=20 && \
echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_$NODE_MAJOR.x nodistro main" | tee /etc/apt/sources.list.d/nodesource.list && \
apt-get -y update && \
ACCEPT_EULA=Y apt-get -y install --no-install-recommends \
# Node
nodejs \
# NFS dependencies
nfs-common \
# odbc dependencies
msodbcsql18 \
unixodbc-dev && \
# R
# r-base=4.2.2.20221110-2 && \
apt-get update -y && \
ACCEPT_EULA=Y apt-get install -y --no-install-recommends \
# Node
nodejs \
# NFS dependencies
nfs-common \
# odbc dependencies
msodbcsql18 \
unixodbc-dev && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*

## R Packages
# RUN \
# R -e "install.packages('pacman', repos='http://cran.us.r-project.org')" && \
# R -e "install.packages('renv', repos='http://cran.us.r-project.org')"
## Chart packages
# Before fixing, ensure you have merged the chart packages installation step with another apt-get install step to adhere to best practices.
# If keeping as a separate RUN statement, ensure you follow the same pattern regarding list cleanup and `-y` switch as done for system packages.
RUN apt-get update -y && \
apt-get install -y --no-install-recommends graphviz && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*

## Node Packages
RUN npm install --global yarn && yarn global add next
Expand Down
16 changes: 10 additions & 6 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ x-server_settings: &server_settings
- ECS_TASK_DEFINITION=$ECS_TASK_DEFINITION
- ENABLE_NEW_RELIC=$ENABLE_NEW_RELIC
- ENABLE_PROMETHEUS=$ENABLE_PROMETHEUS
- OTEL_EXPORTER_OTLP_ENDPOINT=${OTEL_EXPORTER_OTLP_ENDPOINT}
- OTEL_EXPORTER_OTLP_HTTP_ENDPOINT=${OTEL_EXPORTER_OTLP_HTTP_ENDPOINT}
- OTEL_PYTHON_TORNADO_EXCLUDED_URLS=$OTEL_PYTHON_TORNADO_EXCLUDED_URLS
- ENV=dev
- GCP_PROJECT_ID=$GCP_PROJECT_ID
- GCP_REGION=$GCP_REGION
Expand All @@ -35,21 +32,28 @@ x-server_settings: &server_settings
- LDAP_BIND_PASSWORD=$LDAP_BIND_PASSWORD
- LDAP_SERVER=$LDAP_SERVER
- MAGE_BASE_PATH=$MAGE_BASE_PATH
- MAGE_DATA_DIR=$MAGE_DATA_DIR
- MAGE_DATABASE_CONNECTION_URL=$DATABASE_CONNECTION_URL
- MAGE_DATA_DIR=$MAGE_DATA_DIR
- MAGE_PRESENTERS_DIRECTORY=$MAGE_PRESENTERS_DIRECTORY
- MAX_NUMBER_OF_FILE_VERSIONS=$MAX_NUMBER_OF_FILE_VERSIONS
- NEW_RELIC_CONFIG_PATH=$NEW_RELIC_CONFIG_PATH
- OPENAI_API_KEY=$OPENAI_API_KEY
- OTEL_EXPORTER_OTLP_ENDPOINT=${OTEL_EXPORTER_OTLP_ENDPOINT}
- OTEL_EXPORTER_OTLP_HTTP_ENDPOINT=${OTEL_EXPORTER_OTLP_HTTP_ENDPOINT}
- OTEL_PYTHON_TORNADO_EXCLUDED_URLS=$OTEL_PYTHON_TORNADO_EXCLUDED_URLS
- REQUIRE_USER_AUTHENTICATION=$REQUIRE_USER_AUTHENTICATION
- REQUIRE_USER_PERMISSIONS=$REQUIRE_USER_PERMISSIONS
- SERVER_VERBOSITY=$SERVER_VERBOSITY
- SCHEDULER_TRIGGER_INTERVAL=$SCHEDULER_TRIGGER_INTERVAL
- SERVER_LOGGING_TEMPLATE=${SERVER_LOGGING_TEMPLATE:-%(levelname)s:%(name)s:%(message)s}
- SERVER_VERBOSITY=$SERVER_VERBOSITY
- SMTP_EMAIL=$SMTP_EMAIL
- SMTP_PASSWORD=$SMTP_PASSWORD
- path_to_keyfile=$GCP_PATH_TO_CREDENTIALS
- SCHEDULER_TRIGGER_INTERVAL=$SCHEDULER_TRIGGER_INTERVAL
ports:
- 6789:6789
volumes:
- .:/home/src
- ./ml/mlops:/home/src/mlops
- ~/.aws:/root/.aws
- ~/.mage_data:/root/.mage_data
restart: on-failure:5
Expand Down
16 changes: 12 additions & 4 deletions docs/contributing/frontend/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,27 @@ sidebarTitle: "Overview"
description: "Guides on adding features to the front-end client."
---

| | |
| --- | --- |
| Language | TypeScript |
| Framework | React, Next.js |
| | |
| -------------- | ------------------- |
| Language | TypeScript |
| Framework | React, Next.js |
| Code directory | `mage_ai/frontend/` |

## Style guide

Mage follows the [Airbnb JavaScript style guide](https://airbnb.io/javascript/react/).

### Linter

To run the linter locally, execute this script:

```bash
./scripts/test.sh
```

## Setup

```bash
yarn global add eslint-config-next
yarn global add eslint_d
```
Empty file added mage_ai/ai/utils/__init__.py
Empty file.
220 changes: 220 additions & 0 deletions mage_ai/ai/utils/xgboost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
import base64
import io
import json
import os
from typing import Any, Optional, Tuple

from mage_ai.data_preparation.models.variables.constants import (
CONFIG_JSON_FILE,
MEDIA_IMAGE_VISUALIZATION_FILE,
UBJSON_MODEL_FILENAME,
)
from mage_ai.settings.server import MAX_OUTPUT_IMAGE_PREVIEW_SIZE
from mage_ai.shared.environments import is_debug


def is_booster_trained(booster: Any, raise_exception: bool = True) -> bool:
    """Return True when ``booster`` has completed at least one boosting round.

    Args:
        booster: An ``xgboost.Booster`` (typed ``Any`` so xgboost remains an
            optional import).
        raise_exception: When True, raise a descriptive ``Exception`` if the
            booster is untrained; otherwise print the message and return False.

    Returns:
        True if the booster has been trained, False otherwise.

    Raises:
        Exception: If the booster is untrained and ``raise_exception`` is True.
        ImportError: If xgboost is not installed.
    """
    # Import OUTSIDE the try block: previously the import lived inside it, so
    # a missing xgboost raised ImportError and then evaluating the
    # ``except xgb.core.XGBoostError`` clause raised a NameError (``xgb`` was
    # never bound), masking the real failure.
    import xgboost as xgb

    try:
        # ``num_boosted_rounds()`` is >= 1 only after training; calling it on
        # an untrained booster raises XGBoostError.
        return booster.num_boosted_rounds() >= 1
    except xgb.core.XGBoostError as err:
        message = f"XGBoost model is not trained. {err}"

        if raise_exception:
            raise Exception(message) from err

        print(message)

    # Untrained booster and raise_exception is False.
    return False


def load_model(
    model_dir: str,
    model_filename: str = UBJSON_MODEL_FILENAME,
    config_filename: str = CONFIG_JSON_FILE,
    raise_exception: bool = True,
) -> Optional[Any]:
    """Restore an XGBoost booster (weights plus configuration) from a directory.

    Args:
        model_dir: Directory containing the saved model artifacts.
        model_filename: Filename of the serialized model weights.
        config_filename: Filename of the JSON configuration snapshot.
        raise_exception: When True (or when debug mode is on), re-raise any
            failure; otherwise print it and return None.

    Returns:
        The loaded ``xgboost.Booster``, or None when loading fails and
        ``raise_exception`` is False.
    """
    try:
        import xgboost as xgb

        booster = xgb.Booster()
        booster.load_model(os.path.join(model_dir, model_filename))

        # Re-apply the hyperparameter/configuration snapshot that was saved
        # alongside the weights so the booster behaves exactly as persisted.
        with open(os.path.join(model_dir, config_filename), "r") as config_file:
            saved_config = json.load(config_file)
        booster.load_config(json.dumps(saved_config))

        return booster
    except Exception as err:
        if raise_exception or is_debug():
            raise err
        print(f"[ERROR] XGBoost.load_model: {err}")

    return None


def save_model(
    booster: Any,
    model_dir: str,
    model_filename: str = UBJSON_MODEL_FILENAME,
    config_filename: str = CONFIG_JSON_FILE,
    image_filename: str = MEDIA_IMAGE_VISUALIZATION_FILE,
    raise_exception: bool = True,
) -> bool:
    """Persist a trained XGBoost booster's weights, config, and visualization.

    Args:
        booster: A trained ``xgboost.Booster``.
        model_dir: Directory to write artifacts into (created if missing).
        model_filename: Filename for the serialized model weights.
        config_filename: Filename for the JSON configuration snapshot.
        image_filename: Filename for the tree visualization PNG; pass a falsy
            value to skip rendering.
        raise_exception: Forwarded to ``is_booster_trained`` — when True, an
            untrained booster raises instead of returning False.

    Returns:
        True when the model was saved, False when the booster is untrained.
    """
    if not is_booster_trained(booster, raise_exception=raise_exception):
        return False

    os.makedirs(model_dir, exist_ok=True)

    # Save detailed configuration of the model that includes all the
    # hyperparameters and settings.
    model_path = os.path.join(model_dir, model_filename)
    booster.save_model(model_path)

    # Save the structure of the trees (for tree-based models like gradient
    # boosting) along with the basic configuration needed to understand the
    # model structure itself.
    config_path = os.path.join(model_dir, config_filename)
    with open(config_path, "w") as f:
        f.write(booster.save_config())

    if image_filename:
        image_path = os.path.join(model_dir, image_filename)
        try:
            # Visualization is best-effort; a rendering failure must not
            # block saving the model itself.
            create_tree_visualization(booster, image_path=image_path)
        except Exception as err:
            # Bug fix: this log line previously claimed "XGBoost.load_model".
            print(f"[ERROR] XGBoost.save_model: {err}")

    return True


def create_tree_visualization(
    model: Any,
    image_path: Optional[str] = None,
    max_render_size: int = MAX_OUTPUT_IMAGE_PREVIEW_SIZE,
    max_trees: int = 12,
    num_trees: int = 0,
) -> Tuple[Optional[str], bool]:
    """Render a Graphviz visualization of the booster's trees.

    When ``image_path`` is given, the PNG is written to disk and
    ``(image_path, True)`` is returned. Otherwise the PNG is produced in
    memory and returned as a base64 string, unless it exceeds
    ``max_render_size`` bytes, in which case an explanatory message and
    ``False`` are returned. On any exception, ``(str(err), False)``.
    """
    try:
        import xgboost as xgb
        from PIL import Image

        # Raise Pillow's decompression-bomb safety limit to ~524 million
        # pixels so large tree renderings can be opened below.
        Image.MAX_IMAGE_PIXELS = 1024 * 1024 * 500

        # NOTE(review): ``xgb.to_graphviz``'s ``num_trees`` argument is the
        # *index* of the single tree to render, not a count, yet the branch
        # logic below mixes count-like comparisons (``n_trees > max_trees``)
        # with index-like ones (``num_trees < n_trees``); the final ``elif``
        # is only reachable when n_trees == max_trees AND num_trees >= n_trees.
        # Confirm the intended semantics before changing this logic.
        trees_to_render = 0
        n_trees = model.num_boosted_rounds()
        if n_trees > max_trees or n_trees == 0:
            trees_to_render = max_trees
        elif n_trees >= 1 and num_trees < n_trees:
            trees_to_render = num_trees
        elif n_trees == max_trees:
            trees_to_render = 0

        graph = xgb.to_graphviz(
            model, num_trees=trees_to_render, rankdir="TB", format="png"
        )

        if image_path:
            # Remove the '.png' extension when specifying the filename:
            # Graphviz's ``render`` appends the format extension itself.
            base_image_path = (
                image_path.rsplit(".", 1)[0] if "." in image_path else image_path
            )

            # Pass the adjusted filename without the extension.
            graph.render(filename=base_image_path, cleanup=True, format="png")
            # Since 'format' is 'png', Graphviz outputs '<base_image_path>.png'.
            return image_path, True

        # No destination path: render the PNG bytes entirely in memory.
        png_bytes = graph.pipe(format="png")

        # Convert PNG bytes to a PIL Image object.
        image = Image.open(io.BytesIO(png_bytes))

        # Re-encode through PIL and base64 so the result can be embedded
        # directly in the browser (e.g. as a data URI).
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        buffered.seek(0)  # Reset pointer to the beginning of the buffer
        # Guard the in-browser preview against very large payloads.
        file_size_bytes = buffered.getbuffer().nbytes
        if file_size_bytes > max_render_size:
            message = (
                "XGBoost tree visualization created an image that exceeds "
                f"{max_render_size} bytes (actual size is {file_size_bytes} bytes). "
                "No preview will be shown in the browser. "
                "To increase the preview limit, set the environment variable "
                "MAX_OUTPUT_IMAGE_PREVIEW_SIZE to a larger byte size value."
            )
            return message, False

        return img_str, True
    except Exception as err:
        print(f"[ERROR] XGBoost.create_tree_visualization: {err}")
        return str(err), False


def render_tree_visualization(
    image_dir: str,
    image_filename: str = MEDIA_IMAGE_VISUALIZATION_FILE,
    max_render_size: int = MAX_OUTPUT_IMAGE_PREVIEW_SIZE,
) -> Tuple[Optional[str], bool]:
    """Load a previously rendered tree visualization PNG as a base64 string.

    Args:
        image_dir: Directory containing the visualization image.
        image_filename: Filename of the PNG inside ``image_dir``.
        max_render_size: Maximum file size (bytes) allowed for the preview.

    Returns:
        ``(base64_string, True)`` on success; ``(message_or_error, False)``
        when the file is too large or cannot be read.
    """
    try:
        target_path = os.path.join(image_dir, image_filename)

        # Check the size up front so oversized files are never read into
        # memory just to be rejected.
        file_size_bytes = os.path.getsize(target_path)
        if file_size_bytes > max_render_size:
            message = (
                "XGBoost tree visualization created an image that exceeds "
                f"{max_render_size} bytes (actual size is {file_size_bytes} bytes). "
                "No preview will be shown in the browser. "
                "To increase the preview limit, set the environment variable "
                "MAX_OUTPUT_IMAGE_PREVIEW_SIZE to a larger byte size value."
            )
            return message, False

        with open(target_path, "rb") as image_file:
            raw_bytes = image_file.read()
        return base64.b64encode(raw_bytes).decode("utf-8"), True
    except Exception as err:
        print(f"[ERROR] XGBoost.render_tree_visualization: {err}")
        return str(err), False


def create_tree_plot(model: Any, image_path: str, num_trees: int = 0) -> str:
    """Render one tree from ``model`` to ``image_path`` via matplotlib.

    Args:
        model: A trained XGBoost model accepted by ``xgb.plot_tree``.
        image_path: Destination file for the rendered figure.
        num_trees: Index of the tree to draw (``plot_tree`` treats this as a
            tree index, not a count).

    Returns:
        ``image_path`` on success; otherwise the stringified error.
    """
    try:
        import matplotlib.pyplot as plt
        import xgboost as xgb

        # Drop any figures left over from earlier calls so memory does not
        # accumulate across repeated renders.
        plt.close("all")

        plt.figure(dpi=300)
        # Bug fix: honor the ``num_trees`` parameter — it was previously
        # hard-coded to tree index 5, ignoring the caller's choice.
        xgb.plot_tree(model, num_trees=num_trees)
        plt.tight_layout()

        plt.savefig(image_path, dpi=300, bbox_inches="tight")

        plt.close()

        return image_path
    except Exception as err:
        print(f"[ERROR] XGBoost.create_tree_plot: {err}")
        return str(err)
Loading

0 comments on commit 8acafbd

Please sign in to comment.