Skip to content

Commit

Permalink
imports for GPU packages from compat (#119)
Browse files Browse the repository at this point in the history
* imports for GPU packages from compat

* get horovod to pass import problems

---------

Co-authored-by: Karl Higley <[email protected]>
  • Loading branch information
jperez999 and karlhigley authored Apr 4, 2023
1 parent fa716be commit 0bfd68f
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 24 deletions.
18 changes: 3 additions & 15 deletions merlin/dataloader/loader_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,10 @@
import warnings
from typing import List, Optional

import numpy as np
import pandas as pd

try:
import cupy
except ImportError:
cupy = None

from merlin.core.compat import cupy
from merlin.core.compat import HAS_GPU, cudf, cupy
from merlin.core.compat import numpy as np
from merlin.core.compat import pandas as pd
from merlin.core.dispatch import (
HAS_GPU,
annotate,
concat,
generate_local_seed,
Expand All @@ -45,11 +38,6 @@
from merlin.schema import Schema, Tags
from merlin.table import TensorTable

try:
import cudf
except ImportError:
cudf = None


def _num_steps(num_samples, step_size):
return math.ceil(num_samples / step_size)
Expand Down
11 changes: 7 additions & 4 deletions merlin/dataloader/utils/tf/tf_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import logging
import os

import cupy

from merlin.core.compat import cupy, numpy
from merlin.io import Dataset

# we can control how much memory to give tensorflow with this environment variable
Expand All @@ -21,6 +20,7 @@
import tensorflow as tf # noqa: E402 isort:skip
import horovod.tensorflow as hvd # noqa: E402 isort:skip

xp = cupy or numpy

LOG = logging.getLogger("multi")

Expand All @@ -46,7 +46,7 @@
hvd.init()

# Seed with system randomness (or a static seed)
cupy.random.seed(None)
xp.random.seed(None)


def seed_fn():
Expand All @@ -60,7 +60,10 @@ def seed_fn():
max_rand = max_int // hvd.size()

# Generate a seed fragment on each worker
seed_fragment = cupy.random.randint(0, max_rand).get()
if cupy:
seed_fragment = xp.random.randint(0, max_rand).get()
else:
seed_fragment = xp.random.randint(0, max_rand)

# Aggregate seed fragments from all Horovod workers
seed_tensor = tf.constant(seed_fragment)
Expand Down
9 changes: 4 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
import random

import dask
import numpy as np
import pandas as pd
from npy_append_array import NpyAppendArray

try:
import cudf
from merlin.core.compat import cudf
from merlin.core.compat import numpy as np

if cudf:
try:
import cudf.testing._utils

Expand All @@ -33,8 +33,7 @@
import cudf.tests.utils

assert_eq = cudf.tests.utils.assert_eq
except ImportError:
cudf = None
else:

def assert_eq(a, b, *args, **kwargs):
if isinstance(a, pd.DataFrame):
Expand Down

0 comments on commit 0bfd68f

Please sign in to comment.