Allows overriding CPE configurations on NVD records (#502)
* wip: add input writer to nvd provider

Signed-off-by: Will Murphy <[email protected]>

* feat: get new nvd_writer from input and generate reader from store

Signed-off-by: Christopher Phillips <[email protected]>

* wip: move input db writing to be during update

This way, we don't have to loop over the entire input table to write an
output table.

Signed-off-by: Will Murphy <[email protected]>

* chore: lint fix

Signed-off-by: Will Murphy <[email protected]>

* add download overrides

Signed-off-by: Alex Goodman <[email protected]>

* update for latest overrides layout

Signed-off-by: Weston Steimel <[email protected]>

* bump nvd provider to version 2

This forces any version 2 preload state to be used if the current
workspace state is v1. This is needed because the nvd-input db will not
exist in a v1 workspace, which would otherwise result in a full NVD API pull.

Signed-off-by: Weston Steimel <[email protected]>

* fix static analysis

Signed-off-by: Alex Goodman <[email protected]>

* add and fix tests

Signed-off-by: Alex Goodman <[email protected]>

---------

Signed-off-by: Will Murphy <[email protected]>
Signed-off-by: Christopher Phillips <[email protected]>
Signed-off-by: Alex Goodman <[email protected]>
Signed-off-by: Weston Steimel <[email protected]>
Co-authored-by: Will Murphy <[email protected]>
Co-authored-by: Christopher Phillips <[email protected]>
Co-authored-by: Alex Goodman <[email protected]>
4 people authored Mar 11, 2024
1 parent fc1a2e0 commit f4dbceb
Showing 11 changed files with 458 additions and 24 deletions.
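
Taken together, the commits above restructure the NVD provider's data flow. Below is a minimal sketch of the resulting behavior (illustrative Python, not vunnel code: download_nvd_input, override_cves, read_input, id_for, and apply_override are hypothetical stand-ins for the real methods in src/vunnel/providers/nvd/manager.py shown further down):

def get_with_overrides(manager, last_updated):
    processed_cves = set()
    # records are written to a persistent "nvd-input" SQLite DB as they are
    # downloaded, so no second pass over the input table is needed
    for record_id, record in manager.download_nvd_input(last_updated):
        processed_cves.add(record["cve"]["id"].lower())
        yield record_id, record  # overrides already applied during download
    # overrides targeting CVEs outside this run's download window are applied
    # by re-reading the pristine records from the input DB
    for cve in {c.lower() for c in manager.override_cves()} - processed_cves:
        original = manager.read_input(cve)  # hypothetical helper
        if original:
            yield manager.id_for(cve), manager.apply_override(cve, original)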
1 change: 0 additions & 1 deletion poetry.lock


17 changes: 17 additions & 0 deletions src/vunnel/providers/nvd/__init__.py
@@ -21,6 +21,8 @@ class Config:
)
request_timeout: int = 125
api_key: Optional[str] = "env:NVD_API_KEY" # noqa: UP007
overrides_url: str = "https://github.com/anchore/nvd-data-overrides/archive/refs/heads/main.tar.gz"
overrides_enabled: bool = False

def __post_init__(self) -> None:
if self.api_key and self.api_key.startswith("env:"):
@@ -36,6 +38,8 @@ def __str__(self) -> str:


class Provider(provider.Provider):
__version__ = 2

def __init__(self, root: str, config: Config | None = None):
if not config:
config = Config()
@@ -50,12 +54,25 @@ def __init__(self, root: str, config: Config | None = None):
"(otherwise incremental updates will fail)",
)

if self.config.runtime.result_store != result.StoreStrategy.SQLITE:
raise ValueError(
f"only 'SQLITE' is supported for 'runtime.result_store' but got '{self.config.runtime.result_store}'",
)

if self.config.overrides_enabled and not self.config.overrides_url:
raise ValueError(
"if 'overrides_enabled' is set then 'overrides_url' must be set",
)

self.schema = schema.NVDSchema()
self.manager = Manager(
workspace=self.workspace,
schema=self.schema,
download_timeout=self.config.request_timeout,
api_key=self.config.api_key,
logger=self.logger,
overrides_enabled=self.config.overrides_enabled,
overrides_url=self.config.overrides_url,
)

@classmethod
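
For reference, a minimal sketch of turning the new settings on from Python (the workspace root path is invented; this assumes the default runtime config already uses the SQLite result store, which the ValueError guard above requires):

from vunnel.providers.nvd import Config, Provider

config = Config(
    overrides_enabled=True,  # off by default
    # the default shown in the diff above, repeated here for clarity
    overrides_url="https://github.com/anchore/nvd-data-overrides/archive/refs/heads/main.tar.gz",
)
provider = Provider(root="./data/nvd", config=config)  # root path is illustrative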
140 changes: 127 additions & 13 deletions src/vunnel/providers/nvd/manager.py
@@ -5,7 +5,9 @@
import os
from typing import TYPE_CHECKING, Any

from .api import NvdAPI
from vunnel import result, schema
from vunnel.providers.nvd.api import NvdAPI
from vunnel.providers.nvd.overrides import NVDOverrides

if TYPE_CHECKING:
from collections.abc import Generator
@@ -14,12 +16,17 @@


class Manager:
def __init__(
__nvd_input_db__ = "nvd-input.db"

def __init__( # noqa: PLR0913
self,
workspace: Workspace,
schema: schema.Schema,
overrides_url: str,
logger: logging.Logger | None = None,
download_timeout: int = 125,
api_key: str | None = None,
overrides_enabled: bool = False,
) -> None:
self.workspace = workspace

@@ -28,19 +35,87 @@ def __init__(
self.logger = logger

self.api = NvdAPI(api_key=api_key, logger=logger, timeout=download_timeout)

self.overrides = NVDOverrides(
enabled=overrides_enabled,
url=overrides_url,
workspace=workspace,
logger=logger,
download_timeout=download_timeout,
)

self.urls = [self.api._cve_api_url_] # noqa: SLF001
self.schema = schema

def get(
self,
last_updated: datetime.datetime | None,
skip_if_exists: bool = False,
) -> Generator[tuple[str, dict[str, Any]], Any, None]:
if skip_if_exists and self._can_update_incrementally(last_updated):
yield from self._download_updates(last_updated) # type: ignore # noqa: PGH003
else:
yield from self._download_all()
self.overrides.download()

cves_processed = set()
for record_id, record in self._download_nvd_input(last_updated, skip_if_exists):
cves_processed.add(id_to_cve(record_id).lower())  # lowered to match the override CVE set below
yield record_id, record

if self.overrides.enabled:
self.urls.append(self.overrides.url)
self.logger.debug("applying NVD data overrides...")

override_cves = {cve.lower() for cve in self.overrides.cves()}
override_remaining_cves = override_cves - cves_processed
with self._sqlite_reader() as reader:
for cve in override_remaining_cves:

original_record = reader.read(cve_to_id(cve))
if not original_record:
self.logger.warning(f"override for {cve} not found in original data")
continue

original_record = original_record["item"]
if not original_record:
self.logger.warning(f"missing original data for {cve}")
continue

yield cve_to_id(cve), self._apply_override(cve, original_record)

self.logger.debug(f"applied overrides for {len(override_remaining_cves)} CVEs")

self.logger.debug("overrides are not enabled, skipping...")

def _download_nvd_input(
self,
last_updated: datetime.datetime | None,
skip_if_exists: bool = False,
) -> Generator[tuple[str, dict[str, Any]], Any, None]:
with self._nvd_input_writer() as writer:
if skip_if_exists and self._can_update_incrementally(last_updated):
yield from self._download_updates(last_updated, writer) # type: ignore # noqa: PGH003
else:
yield from self._download_all(writer)

def _nvd_input_writer(self) -> result.Writer:
return result.Writer(
workspace=self.workspace,
result_state_policy=result.ResultStatePolicy.KEEP,
logger=self.logger,
store_strategy=result.StoreStrategy.SQLITE,
write_location=self._input_nvd_path,
)

def _sqlite_reader(self) -> result.SQLiteReader:
return result.SQLiteReader(sqlite_db_path=self._input_nvd_path)

@property
def _input_nvd_path(self) -> str:
return os.path.join(self.workspace.input_path, self.__nvd_input_db__)

def _can_update_incrementally(self, last_updated: datetime.datetime | None) -> bool:
input_db_path = os.path.join(self.workspace.input_path, self.__nvd_input_db__)
if not os.path.exists(input_db_path):
return False

if not last_updated:
return False

@@ -55,15 +130,19 @@ def _can_update_incrementally(self, last_updated: datetime.datetime | None) -> bool:

return True

def _download_all(self) -> Generator[tuple[str, dict[str, Any]], Any, None]:
def _download_all(self, writer: result.Writer) -> Generator[tuple[str, dict[str, Any]], Any, None]:
self.logger.info("downloading all CVEs")

# TODO: should we delete all existing state in this case first?

for response in self.api.cve():
yield from self._unwrap_records(response)
yield from self._unwrap_records(response, writer)

def _download_updates(self, last_updated: datetime.datetime) -> Generator[tuple[str, dict[str, Any]], Any, None]:
def _download_updates(
self,
last_updated: datetime.datetime,
writer: result.Writer,
) -> Generator[tuple[str, dict[str, Any]], Any, None]:
self.logger.debug(f"downloading CVEs changed since {last_updated.isoformat()}")

# get the list of CVEs that have been updated since the last sync
@@ -74,10 +153,45 @@ def _download_updates(self, last_updated: datetime.datetime) -> Generator[tuple[str, dict[str, Any]], Any, None]:
if total_results:
self.logger.debug(f"discovered {total_results} updated CVEs")

yield from self._unwrap_records(response)
yield from self._unwrap_records(response, writer)

def _unwrap_records(self, response: dict[str, Any]) -> Generator[tuple[str, dict[str, Any]], Any, None]:
def _unwrap_records(
self,
response: dict[str, Any],
writer: result.Writer,
) -> Generator[tuple[str, dict[str, Any]], Any, None]:
for vuln in response["vulnerabilities"]:
cve_id = vuln["cve"]["id"]
year = cve_id.split("-")[1]
yield os.path.join(year, cve_id), vuln
record_id = cve_to_id(cve_id)

# keep input for future overrides
writer.write(record_id, self.schema, vuln)

# apply overrides to output
yield record_id, self._apply_override(cve_id=cve_id, record=vuln)

def _apply_override(self, cve_id: str, record: dict[str, Any]) -> dict[str, Any]:
override = self.overrides.cve(cve_id)
if override:
self.logger.debug(f"applying override for {cve_id}")
# ignore empty overrides
if override is None or "cve" not in override:
return record
# explicitly only support CPE configurations for now and always override the
# original record configurations. Can figure out more complicated scenarios
# later if needed
if "configurations" not in override["cve"]:
return record

record["cve"]["configurations"] = override["cve"]["configurations"]

return record


def cve_to_id(cve: str) -> str:
year = cve.split("-")[1]
return os.path.join(year, cve)


def id_to_cve(cve_id: str) -> str:
return cve_id.split("/")[1]
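
To make the override semantics concrete, here is a standalone sketch of what _apply_override and the ID helpers above do (the CVE ID, CPE string, and record shape are invented; real records follow the NVD 2.0 API format):

from typing import Any

def apply_override(record: dict[str, Any], override: dict[str, Any] | None) -> dict[str, Any]:
    # mirror _apply_override: ignore empty or non-CVE overrides
    if override is None or "cve" not in override:
        return record
    # only CPE "configurations" are supported; they replace the originals wholesale
    if "configurations" not in override["cve"]:
        return record
    record["cve"]["configurations"] = override["cve"]["configurations"]
    return record

record = {"cve": {"id": "CVE-2099-0001", "configurations": [{"nodes": []}]}}
override = {
    "cve": {
        "configurations": [
            {
                "nodes": [
                    {
                        "operator": "OR",
                        "cpeMatch": [
                            {"criteria": "cpe:2.3:a:example:widget:*:*:*:*:*:*:*:*", "vulnerable": True},
                        ],
                    },
                ],
            },
        ],
    },
}

assert apply_override(record, override)["cve"]["configurations"] == override["cve"]["configurations"]
# record IDs pair year and CVE: cve_to_id("CVE-2099-0001") == "2099/CVE-2099-0001"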
109 changes: 109 additions & 0 deletions src/vunnel/providers/nvd/overrides.py
@@ -0,0 +1,109 @@
from __future__ import annotations

import glob
import logging
import os
import tarfile
from typing import TYPE_CHECKING, Any

from orjson import loads

from vunnel.utils import http

if TYPE_CHECKING:
from vunnel.workspace import Workspace


class NVDOverrides:
__file_name__ = "nvd-overrides.tar.gz"
__extract_name__ = "nvd-overrides"

def __init__( # noqa: PLR0913
self,
enabled: bool,
url: str,
workspace: Workspace,
logger: logging.Logger | None = None,
download_timeout: int = 125,
):
self.enabled = enabled
self.__url__ = url
self.workspace = workspace
self.download_timeout = download_timeout
if not logger:
logger = logging.getLogger(self.__class__.__name__)
self.logger = logger
self.__filepaths_by_cve__: dict[str, str] | None = None

@property
def url(self) -> str:
return self.__url__

def download(self) -> None:
if not self.enabled:
self.logger.debug("overrides are not enabled, skipping download...")
return

req = http.get(self.__url__, self.logger, stream=True, timeout=self.download_timeout)

file_path = os.path.join(self.workspace.input_path, self.__file_name__)
with open(file_path, "wb") as fp:
for chunk in req.iter_content():
fp.write(chunk)

untar_file(file_path, self._extract_path)

@property
def _extract_path(self) -> str:
return os.path.join(self.workspace.input_path, self.__extract_name__)

def _build_files_by_cve(self) -> dict[str, Any]:
filepaths_by_cve__: dict[str, str] = {}
for path in glob.glob(os.path.join(self._extract_path, "**/data/**/", "CVE-*.json"), recursive=True):
cve_id = os.path.basename(path).removesuffix(".json").upper()
filepaths_by_cve__[cve_id] = path

return filepaths_by_cve__

def cve(self, cve_id: str) -> dict[str, Any] | None:
if not self.enabled:
return None

if self.__filepaths_by_cve__ is None:
self.__filepaths_by_cve__ = self._build_files_by_cve()

# TODO: implement in-memory index
path = self.__filepaths_by_cve__.get(cve_id.upper())
if path and os.path.exists(path):
with open(path) as f:
return loads(f.read())
return None

def cves(self) -> list[str]:
if not self.enabled:
return []

if self.__filepaths_by_cve__ is None:
self.__filepaths_by_cve__ = self._build_files_by_cve()

return list(self.__filepaths_by_cve__.keys())


def untar_file(file_path: str, extract_path: str) -> None:
with tarfile.open(file_path, "r:gz") as tar:

def filter_path_traversal(tarinfo: tarfile.TarInfo, path: str) -> tarfile.TarInfo | None:
# we do not expect any relative file paths that would result in the clean
# path being different from the original path
# e.g.
# expected: results/results.db
# unexpected: results/../../../../etc/passwd
# we filter (drop) any such entries

if tarinfo.name != os.path.normpath(tarinfo.name):
return None
return tarinfo

# note: we have a filter that drops any entries that would result in a path traversal
# which is what S202 is referring to (linter isn't smart enough to understand this)
tar.extractall(path=extract_path, filter=filter_path_traversal) # noqa: S202
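
The filter above can be exercised end to end; a self-contained check (file names invented for illustration) shows that entries whose normalized path differs from the stored path are dropped:

import io
import os
import tarfile

def make_tar_gz(names: list[str]) -> bytes:
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w:gz") as tar:
        for name in names:
            data = b"{}"
            info = tarfile.TarInfo(name=name)
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))
    return buf.getvalue()

payload = make_tar_gz(["data/CVE-2099-0001.json", "data/../../evil.json"])
with tarfile.open(fileobj=io.BytesIO(payload), mode="r:gz") as tar:
    for member in tar:
        kept = member.name == os.path.normpath(member.name)
        print(f"{member.name!r}: {'kept' if kept else 'dropped'}")
# prints:
# 'data/CVE-2099-0001.json': kept
# 'data/../../evil.json': dropped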