Skip to content

Commit

Permalink
Remove pybtex warning when importing PyBaMM (pybamm-team#4383)
Browse files Browse the repository at this point in the history
* Update error handling

* Minor updates

* Update src/pybamm/citations.py

* Update citations.py

* Update tests
  • Loading branch information
kratman authored Aug 27, 2024
1 parent 265dcd6 commit 3586377
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 76 deletions.
125 changes: 55 additions & 70 deletions src/pybamm/citations.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
#
# Bibliographical information for PyBaMM
# Inspired by firedrake/PETSc citation workflow
# https://firedrakeproject.org/citing.html
#
import pybamm
import os
import warnings
Expand All @@ -25,29 +20,32 @@ class Citations:
>>> pybamm.print_citations("citations.txt")
"""

def __init__(self):
# Set of citation keys that have been registered
self._papers_to_cite = set()
_module_import_error = False
# Set of citation keys that have been registered
_papers_to_cite: set
# Set of unknown citations to parse with pybtex
_unknown_citations: set
# Dict mapping citation tags for use when registering citations
_citation_tags: dict

def __init__(self):
self._check_for_bibtex()
# Dict mapping citations keys to BibTex entries
self._all_citations: dict[str, str] = dict()

# Set of unknown citations to parse with pybtex
self._unknown_citations = set()

# Dict mapping citation tags for use when registering citations
self._citation_tags = dict()

self.read_citations()
self._reset()

def _check_for_bibtex(self):
try:
import_optional_dependency("pybtex")
except ModuleNotFoundError:
self._module_import_error = True

def _reset(self):
"""Reset citations to default only (only for testing purposes)"""
# Initialize empty papers to cite
self._papers_to_cite = set()
# Initialize empty set of unknown citations
self._unknown_citations = set()
# Initialize empty citation tags
self._citation_tags = dict()
# Register the PyBaMM paper and the NumPy paper
self.register("Sulzer2021")
Expand All @@ -66,24 +64,18 @@ def read_citations(self):
"""Reads the citations in `pybamm.CITATIONS.bib`. Other works can be cited
by passing a BibTeX citation to :meth:`register`.
"""
try:
if not self._module_import_error:
parse_file = import_optional_dependency("pybtex.database", "parse_file")
citations_file = os.path.join(pybamm.__path__[0], "CITATIONS.bib")
bib_data = parse_file(citations_file, bib_format="bibtex")
for key, entry in bib_data.entries.items():
self._add_citation(key, entry)
except ModuleNotFoundError: # pragma: no cover
pybamm.logger.warning(
"Citations could not be read because the 'pybtex' library is not installed. "
"Install 'pybamm[cite]' to enable citation reading."
)

def _add_citation(self, key, entry):
"""Adds `entry` to `self._all_citations` under `key`, warning the user if a
previous entry is overwritten
"""

try:
if not self._module_import_error:
Entry = import_optional_dependency("pybtex.database", "Entry")
# Check input types are correct
if not isinstance(key, str) or not isinstance(entry, Entry):
Expand All @@ -96,11 +88,6 @@ def _add_citation(self, key, entry):

# Add to database
self._all_citations[key] = new_citation
except ModuleNotFoundError: # pragma: no cover
pybamm.logger.warning(
f"Could not add citation for '{key}' because the 'pybtex' library is not installed. "
"Install 'pybamm[cite]' to enable adding citations."
)

def _add_citation_tag(self, key, entry):
"""Adds a tag for a citation key in the dict, which represents the name of the
Expand Down Expand Up @@ -154,7 +141,7 @@ def _parse_citation(self, key):
key: str
A BibTeX formatted citation
"""
try:
if not self._module_import_error:
PybtexError = import_optional_dependency("pybtex.scanner", "PybtexError")
parse_string = import_optional_dependency("pybtex.database", "parse_string")
try:
Expand All @@ -165,21 +152,13 @@ def _parse_citation(self, key):

# Add and register all citations
for key, entry in bib_data.entries.items():
# Add to _all_citations dictionary
self._add_citation(key, entry)
# Add to _papers_to_cite set
self._papers_to_cite.add(key)
return
return
except PybtexError as error:
# Unable to parse / unknown key
raise KeyError(
f"Not a bibtex citation or known citation: {key}"
) from error
except ModuleNotFoundError: # pragma: no cover
pybamm.logger.warning(
f"Could not parse citation for '{key}' because the 'pybtex' library is not installed. "
"Install 'pybamm[cite]' to enable citation parsing."
)

def _tag_citations(self):
"""Prints the citation tags for the citations that have been registered
Expand All @@ -193,7 +172,7 @@ def _tag_citations(self):
def print(self, filename=None, output_format="text", verbose=False):
"""Print all citations that were used for running simulations. The verbose
option is provided to print tags for citations in the output such that it can
can be seen where the citations were registered due to the use of PyBaMM models
be seen where the citations were registered due to the use of PyBaMM models
and solvers in the code.
.. note::
Expand Down Expand Up @@ -230,7 +209,7 @@ def print(self, filename=None, output_format="text", verbose=False):
"""
# Parse citations that were not known keys at registration, but do not
# fail if they cannot be parsed
try:
if not self._module_import_error:
pybtex = import_optional_dependency("pybtex")
try:
for key in self._unknown_citations:
Expand All @@ -244,26 +223,36 @@ def print(self, filename=None, output_format="text", verbose=False):
# delete the invalid citation from the set
self._unknown_citations.remove(key)

if output_format == "text":
citations = pybtex.format_from_strings(
self._cited, style="plain", output_backend="plaintext"
)
elif output_format == "bibtex":
citations = "\n".join(self._cited)
else:
raise pybamm.OptionError(
f"Output format {output_format} not recognised."
"It should be 'text' or 'bibtex'."
)
cite_list = self.format_citations(output_format, pybtex)
self.write_citations(cite_list, filename, verbose)
else:
self.print_import_warning()

if filename is None:
print(citations)
if verbose:
self._tag_citations() # pragma: no cover
else:
with open(filename, "w") as f:
f.write(citations)
except ModuleNotFoundError: # pragma: no cover
def write_citations(self, cite_list, filename, verbose):
if filename is None:
print(cite_list)
if verbose:
self._tag_citations() # pragma: no cover
else:
with open(filename, "w") as f:
f.write(cite_list)

def format_citations(self, output_format, pybtex):
if output_format == "text":
cite_list = pybtex.format_from_strings(
self._cited, style="plain", output_backend="plaintext"
)
elif output_format == "bibtex":
cite_list = "\n".join(self._cited)
else:
raise pybamm.OptionError(
f"Output format {output_format} not recognised."
"It should be 'text' or 'bibtex'."
)
return cite_list

def print_import_warning(self):
if self._module_import_error:
pybamm.logger.warning(
"Could not print citations because the 'pybtex' library is not installed. "
"Please, install 'pybamm[cite]' to print citations."
Expand All @@ -272,15 +261,11 @@ def print(self, filename=None, output_format="text", verbose=False):

def print_citations(filename=None, output_format="text", verbose=False):
"""See :meth:`Citations.print`"""
if verbose: # pragma: no cover
if filename is not None: # pragma: no cover
raise Exception(
"Verbose output is available only for the terminal and not for printing to files",
)
else:
citations.print(filename, output_format, verbose=True)
else:
pybamm.citations.print(filename, output_format)
if verbose and filename is not None: # pragma: no cover
raise Exception(
"Verbose output is available only for the terminal and not for printing to files",
)
pybamm.citations.print(filename, output_format, verbose)


citations = Citations()
3 changes: 0 additions & 3 deletions tests/unit/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
#
# Tests the citations class.
#
import pytest

import pybamm
Expand Down
12 changes: 9 additions & 3 deletions tests/unit/test_citations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
#
# Tests the citations class.
#
import pytest
import pybamm
import os
Expand Down Expand Up @@ -111,6 +108,15 @@ def test_input_validation(self):
with pytest.raises(TypeError):
pybamm.citations._add_citation(1001, Entry("misc"))

def test_pybtex_warning(self, caplog):
class CiteWithWarning(pybamm.Citations):
def __init__(self):
super().__init__()
self._module_import_error = True

CiteWithWarning().print_import_warning()
assert "Could not print citations" in caplog.text

def test_andersson_2019(self):
citations = pybamm.citations
citations._reset()
Expand Down

0 comments on commit 3586377

Please sign in to comment.