Skip to content

Commit

Permalink
Merge pull request #19 from tochka-public/feat/add-node-store
Browse files Browse the repository at this point in the history
feat: PE-21 - Added node store
  • Loading branch information
Moootya authored Mar 28, 2024
2 parents 0527102 + d879853 commit 2cf77d5
Showing 9 changed files with 332 additions and 158 deletions.
45 changes: 0 additions & 45 deletions ml_pipeline_engine/context/dag.py
Original file line number Diff line number Diff line change
@@ -36,56 +36,11 @@ def __init__(
cls=self.chart.artifact_store or NoOpArtifactStore, ctx=self
)

self._local_cache = self.get_cache_object()
self._event_managers = [get_instance(cls) for cls in self.chart.event_managers]
self._case_results = {}
self._nodes_in_run = set()
self._active_recurrence_subgraph = {}

@classmethod
def get_cache_object(cls):
return Cache()

async def save_node_result(self, node_id: NodeId, data: t.Any) -> None:
self._local_cache.save(node_id=node_id, data=data)
await self.artifact_store.save(node_id=node_id, data=data)

async def load_node_result(self, node_id: NodeId) -> t.Any:
return self._local_cache.load(node_id=node_id)

async def add_case_result(self, switch_node_id: NodeId, selection: CaseResult) -> None:
self._case_results[switch_node_id] = selection

async def get_case_result(self, switch_node_id: NodeId) -> CaseResult:
return self._case_results[switch_node_id]

async def add_node_in_run(self, node_id: NodeId) -> None:
self._nodes_in_run.add(node_id)

async def is_node_in_run(self, node_id: NodeId) -> bool:
return node_id in self._nodes_in_run

async def exists_node_result(self, node_id: NodeId) -> bool:
return self._local_cache.exists(node_id)

async def is_active_recurrence_subgraph(self, source: NodeId, dest: NodeId) -> bool:
return (source, dest) in self._active_recurrence_subgraph

async def set_active_recurrence_subgraph(self, source: NodeId, dest: NodeId) -> None:
self._active_recurrence_subgraph[(source, dest)] = 1

async def remove_recurrence_subgraph(self, source: NodeId, dest: NodeId) -> None:
self._active_recurrence_subgraph.pop((source, dest))

def delete_node_results(self, node_ids: t.Iterable[NodeId]) -> None:
"""
Удаление всей информации, связанной с узлами
"""

for node_id in node_ids:
self._local_cache.remove(node_id)
self._nodes_in_run.discard(node_id)

@property
def model_name(self) -> ModelName:
return self.chart.model_name
3 changes: 2 additions & 1 deletion ml_pipeline_engine/dag/dag.py
Original file line number Diff line number Diff line change
@@ -48,6 +48,7 @@ async def run(self, ctx: PipelineContextLike) -> NodeResultT:
self.node_map.keys(),
),
dag=self,
ctx=ctx,
)

return await run_manager.run(ctx)
return await run_manager.run()
Loading

0 comments on commit 2cf77d5

Please sign in to comment.