diff --git a/python/dgl/frame.py b/python/dgl/frame.py index b6ddf55e3dad..615592eb7306 100644 --- a/python/dgl/frame.py +++ b/python/dgl/frame.py @@ -115,7 +115,10 @@ def data(self): # copy index to the same context of storage. # Copy index is usually cheaper than copy data if F.context(self.storage) != F.context(self.index): - self.index = F.copy_to(self.index, F.context(self.storage)) + kwargs = {} + if self.device is not None: + kwargs = self.device[1] + self.index = F.copy_to(self.index, F.context(self.storage), **kwargs) self.storage = F.gather_row(self.storage, self.index) self.index = None @@ -148,8 +151,6 @@ def to(self, device, **kwargs): # pylint: disable=invalid-name """ col = self.clone() col.device = (device, kwargs) - if self.index is not None: - col.index = F.copy_to(self.index, device) return col def __getitem__(self, rowids): @@ -253,6 +254,12 @@ def subcolumn(self, rowids): if self.index is None: return Column(self.storage, self.scheme, rowids, self.device) else: + if F.context(self.index) != F.context(rowids): + # make sure index and row ids are on the same context + kwargs = {} + if self.device is not None: + kwargs = self.device[1] + rowids = F.copy_to(rowids, F.context(self.index), **kwargs) return Column(self.storage, self.scheme, F.gather_row(self.index, rowids), self.device) @staticmethod