Skip to content

Commit

Permalink
Clean up

Signed-off-by: Steven Zimmerman <[email protected]>
Browse files (browse the repository at this point in the history)
  • Loading branch information
SZim92 committed Aug 11, 2024
1 parent 3e51fdf commit 8baa2b7
Showing 1 changed file with 6 additions and 21 deletions.
27 changes: 6 additions & 21 deletions service/attention.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -38,26 +38,18 @@
>>> print(output.shape) # Output: torch.Size([64, 128, 512])
"""

import os
import logging
import warnings
import torch
from functools import cache
from typing import Tuple, Optional
from pydantic import BaseSettings, Field, validator

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables
from dotenv import load_dotenv
load_dotenv()

# Define memory threshold constants
SDPA_SLICE_TRIGGER_RATE = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 6))
ATTENTION_SLICE_RATE = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))

# Define Settings class using Pydantic
class Settings(BaseSettings):
"""
Configuration settings for attention slicing, loaded from environment variables.
Expand All @@ -69,10 +61,8 @@ class Settings(BaseSettings):
Raises:
ValueError: If environment variables are not set to positive values.
"""
sdpa_slice_trigger_rate: float = Field(default=SDPA_SLICE_TRIGGER_RATE,
env="IPEX_SDPA_SLICE_TRIGGER_RATE")
attention_slice_rate: float = Field(default=ATTENTION_SLICE_RATE,
env="IPEX_ATTENTION_SLICE_RATE")
sdpa_slice_trigger_rate: float = Field(default=6.0, env="IPEX_SDPA_SLICE_TRIGGER_RATE")
attention_slice_rate: float = Field(default=4.0, env="IPEX_ATTENTION_SLICE_RATE")

@validator('sdpa_slice_trigger_rate', 'attention_slice_rate')
def validate_positive(cls, v: float) -> float:
Expand All @@ -84,12 +74,6 @@ def validate_positive(cls, v: float) -> float:
# Instantiate settings
settings = Settings()

# Validate environment variables
if settings.sdpa_slice_trigger_rate <= 0 or settings.attention_slice_rate <= 0:
raise EnvironmentError("Environment variables for slicing are not set properly. "
"Please configure IPEX_SDPA_SLICE_TRIGGER_RATE and "
"IPEX_ATTENTION_SLICE_RATE to positive values.")

def validate_tensor_shape(tensor: torch.Tensor, expected_dim: int) -> None:
"""Helper function to validate tensor dimensions."""
if tensor.dim() != expected_dim:
Expand Down Expand Up @@ -261,7 +245,7 @@ def torch_bmm_32_bit(input: torch.Tensor, mat2: torch.Tensor, *, out: Optional[t
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
mat2[start_idx:end_idx, :, start_idx_3:end_idx_3],
)
)
else:
# Slice tensors and perform bmm.
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
Expand All @@ -272,13 +256,14 @@ def torch_bmm_32_bit(input: torch.Tensor, mat2: torch.Tensor, *, out: Optional[t
# Slice tensors and perform bmm.
hidden_states[start_idx:end_idx] = original_torch_bmm(
input[start_idx:end_idx],
mat2[start_idx:end_idx],
mat2[start_idx:end_idx],
)
if out is not None:
out.copy_(hidden_states)
return out
torch.xpu.synchronize(input.device)
return hidden_states
# No slicing needed
logger.info("No slicing required for torch_bmm_32_bit. Using original bmm.")
return original_torch_bmm(input, mat2, out=out)

Expand Down

0 comments on commit 8baa2b7

Please sign in to comment.