Skip to content

Commit

Permalink
Re-organize GenAI package init (copy of D64770837) (pytorch#3268)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#3268

X-link: facebookresearch/FBGEMM#368

- Re-organize GenAI package init (copy of D64770837)

Reviewed By: jwfromm

Differential Revision: D64795669

fbshipit-source-id: 825e14e4dbc050403f47f4c6aa7b87664b35f38d
  • Loading branch information
q10 authored and facebook-github-bot committed Oct 23, 2024
1 parent c195f87 commit 97d32bc
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 46 deletions.
6 changes: 3 additions & 3 deletions fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ set_source_files_properties(${experimental_gen_ai_cpp_source_files}
PROPERTIES INCLUDE_DIRECTORIES
"${fbgemm_sources_include_directories}")

set(experimental_gen_ai_python_source_files
gen_ai/__init__.py)

file(GLOB_RECURSE experimental_gen_ai_python_source_files
RELATIVE gen_ai
*.py)

################################################################################
# FBGEMM_GPU HIP Code Generation
Expand Down
51 changes: 8 additions & 43 deletions fbgemm_gpu/experimental/gen_ai/gen_ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,46 +7,11 @@

# pyre-strict

import os

import torch

try:
# pyre-ignore[21]
# @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
from fbgemm_gpu import open_source

# pyre-ignore[21]
# @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
from fbgemm_gpu.docs.version import __version__ # noqa: F401
except Exception:
open_source: bool = False

# pyre-ignore[16]
if open_source:
torch.ops.load_library(
os.path.join(os.path.dirname(__file__), "fbgemm_gpu_experimental_gen_ai_py.so")
)
torch.classes.load_library(
os.path.join(os.path.dirname(__file__), "fbgemm_gpu_experimental_gen_ai_py.so")
)
else:
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:attention_ops"
)
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:comm_ops"
)
from fbgemm_gpu.experimental.gen_ai import comm_ops # noqa: F401

torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:gemm_ops"
)
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:quantize_ops"
)
from fbgemm_gpu.experimental.gen_ai import quantize_ops # noqa: F401

torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:kv_cache_ops"
)
# Load custom operator libraries and register shape functions.
from . import ( # noqa: F401
attention_ops,
comm_ops,
gemm_ops,
kv_cache_ops,
quantize_ops,
)
14 changes: 14 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/gen_ai/attention_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from fbgemm_gpu.experimental.gen_ai.utils.loader import load_custom_library

# Load all custom attention operators.
load_custom_library(
"//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:attention_ops"
)
5 changes: 5 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/gen_ai/comm_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import torch

from fbgemm_gpu.experimental.gen_ai.utils.loader import load_custom_library

"""
This file contains manual shape registrations for communication custom operators.
These are needed for custom operators to be compatible with torch.compile.
Expand All @@ -20,6 +22,9 @@
in python.
"""

# Load all custom operators.
load_custom_library("//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:comm_ops")


@torch.library.register_fake("fbgemm::nccl_allreduce")
def nccl_allreduce_abstract(
Expand Down
12 changes: 12 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/gen_ai/gemm_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from fbgemm_gpu.experimental.gen_ai.utils.loader import load_custom_library

# Load all custom gemm operators.
load_custom_library("//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:gemm_ops")
12 changes: 12 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/gen_ai/kv_cache_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from fbgemm_gpu.experimental.gen_ai.utils.loader import load_custom_library

# Load all custom kv cache operators.
load_custom_library("//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:kv_cache_ops")
5 changes: 5 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/gen_ai/quantize_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import torch

from fbgemm_gpu.experimental.gen_ai.utils.loader import load_custom_library

"""
This file contains manual shape registrations for quantize custom operators.
These are needed for custom operators to be compatible with torch.compile.
Expand All @@ -20,6 +22,9 @@
in python.
"""

# Load all custom operators.
load_custom_library("//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:quantize_ops")


@torch.library.register_fake("fbgemm::f8f8bf16_blockwise")
def f8f8bf16_blockwise_abstract(
Expand Down
6 changes: 6 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/gen_ai/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
48 changes: 48 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/gen_ai/utils/loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import os

import torch


def load_custom_library(lib_name: str) -> None:
"""
Load a custom library implemented in C++. This
helper function handles loading libraries both in
fbcode and OSS.
"""
try:
# pyre-ignore[21]
# @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
from fbgemm_gpu import open_source

# pyre-ignore[21]
# @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
from fbgemm_gpu.docs.version import __version__ # noqa: F401
except Exception:
open_source: bool = False

# In open source, all custom ops are packaged into a single library
# that we load.
# pyre-ignore[16]
if open_source:
torch.ops.load_library(
os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"fbgemm_gpu_experimental_gen_ai_py.so",
)
)
torch.classes.load_library(
os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"fbgemm_gpu_experimental_gen_ai_py.so",
)
)
else:
torch.ops.load_library(lib_name)

0 comments on commit 97d32bc

Please sign in to comment.