Skip to content

Commit

Permalink
Restore get_remote_dataset signature (#243)
Browse files Browse the repository at this point in the history
  • Loading branch information
andreaazzini authored Oct 26, 2021
1 parent 72859a6 commit 432b563
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions darwin/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def list_remote_datasets(self, team: Optional[str] = None) -> Iterator[RemoteDat
client=self,
)

def get_remote_dataset(self, a_dataset_identifier: Union[str, DatasetIdentifier]) -> RemoteDataset:
def get_remote_dataset(self, dataset_identifier: Union[str, DatasetIdentifier]) -> RemoteDataset:
"""
Get a remote dataset based on the parameter passed.
Expand All @@ -306,37 +306,39 @@ def get_remote_dataset(self, a_dataset_identifier: Union[str, DatasetIdentifier]
RemoteDataset
Initialized dataset
"""
dataset_identifier: DatasetIdentifier = DatasetIdentifier.parse(a_dataset_identifier)
parsed_dataset_identifier: DatasetIdentifier = DatasetIdentifier.parse(dataset_identifier)

if not dataset_identifier.team_slug:
dataset_identifier.team_slug = self.default_team
if not parsed_dataset_identifier.team_slug:
parsed_dataset_identifier.team_slug = self.default_team

try:
matching_datasets: List[RemoteDataset] = [
dataset
for dataset in self.list_remote_datasets(team=dataset_identifier.team_slug)
if dataset.slug == dataset_identifier.dataset_slug
for dataset in self.list_remote_datasets(team=parsed_dataset_identifier.team_slug)
if dataset.slug == parsed_dataset_identifier.dataset_slug
]
except Unauthorized:
# There is a chance that we tried to access an open dataset
dataset = self.get(f"{dataset_identifier.team_slug}/{dataset_identifier.dataset_slug}")
dataset = self.get(f"{parsed_dataset_identifier.team_slug}/{parsed_dataset_identifier.dataset_slug}")

# If there isn't a record of this team, create one.
if not self.config.get_team(dataset_identifier.team_slug, raise_on_invalid_team=False):
if not self.config.get_team(parsed_dataset_identifier.team_slug, raise_on_invalid_team=False):
datasets_dir: Path = Path.home() / ".darwin" / "datasets"
self.config.set_team(team=dataset_identifier.team_slug, api_key="", datasets_dir=str(datasets_dir))
self.config.set_team(
team=parsed_dataset_identifier.team_slug, api_key="", datasets_dir=str(datasets_dir)
)

return RemoteDataset(
name=dataset["name"],
slug=dataset["slug"],
team=dataset_identifier.team_slug,
team=parsed_dataset_identifier.team_slug,
dataset_id=dataset["id"],
image_count=dataset["num_images"],
progress=0,
client=self,
)
if not matching_datasets:
raise NotFound(dataset_identifier)
raise NotFound(parsed_dataset_identifier)
return matching_datasets[0]

def create_dataset(self, name: str, team: Optional[str] = None) -> RemoteDataset:
Expand Down

0 comments on commit 432b563

Please sign in to comment.