Skip to content

Commit

Permalink
fix CR and hier, twice corrected (SMTorg#647)
Browse files Browse the repository at this point in the history
* fix cr and update doc

* update tests

* update assert
  • Loading branch information
Paul-Saves authored Sep 18, 2024
1 parent 65aad0d commit b2f9b40
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 29 deletions.
6 changes: 4 additions & 2 deletions doc/_src_docs/applications/Mixed_Hier_surr.rst

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions doc/_src_docs/applications/Mixed_Hier_surr.rstx
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
.. _Mixed Integer and hierarchical Surrogates:

Mixed integer surrogate
=======================
Mixed Integer and hierarchical Surrogates
=========================================

Mixed Integer Surrogates
------------------------
To use a surrogate with mixed integer constraints, the user instantiates a ``MixedIntegerSurrogateModel`` with the given surrogate.
The ``MixedIntegerSurrogateModel`` implements the ``SurrogateModel`` interface and decorates the given surrogate while respecting integer and categorical types.
They are various surrogate models implemented that are described below.
Expand Down
16 changes: 8 additions & 8 deletions smt/applications/tests/test_mixed_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,9 +663,8 @@ def test_hierarchical_variables_Goldstein(self):

y_sv = sm.predict_variances(Xt)[:, 0]
var_RMSE = np.linalg.norm(y_sv) / len(Yt)
self.assertTrue(pred_RMSE < 1e-7)
print("Pred_RMSE", pred_RMSE)
self.assertTrue(var_RMSE < 1e-7)
self.assertLess(pred_RMSE, 1e-7)
self.assertLess(var_RMSE, 1e-7)
self.assertTrue(
np.linalg.norm(
sm.predict_values(
Expand Down Expand Up @@ -1040,6 +1039,8 @@ def test_hierarchical_design_space_example_all_categorical_decreed(self):
for mixint_kernel in [
MixIntKernelType.CONT_RELAX,
MixIntKernelType.GOWER,
MixIntKernelType.COMPOUND_SYMMETRY,
MixIntKernelType.EXP_HOMO_HSPHERE,
MixIntKernelType.HOMO_HSPHERE,
]:
sm = MixedIntegerKrigingModel(
Expand All @@ -1057,10 +1058,10 @@ def test_hierarchical_design_space_example_all_categorical_decreed(self):
sm.train()
y_s = sm.predict_values(Xt)[:, 0]
_pred_RMSE = np.linalg.norm(y_s - Yt) / len(Yt)

self.assertLess(_pred_RMSE, 1e-7)
y_sv = sm.predict_variances(Xt)[:, 0]
_var_RMSE = np.linalg.norm(y_sv) / len(Yt)

self.assertLess(_var_RMSE, 1e-7)
np.testing.assert_almost_equal(
sm.predict_values(
np.array(
Expand Down Expand Up @@ -1323,9 +1324,8 @@ def test_hierarchical_variables_NN(self):

y_sv = sm.predict_variances(Xt)[:, 0]
var_RMSE = np.linalg.norm(y_sv) / len(Yt)
self.assertTrue(pred_RMSE < 1e-7)
print("Pred_RMSE", pred_RMSE)
self.assertTrue(var_RMSE < 1e-7)
self.assertLess(pred_RMSE, 1e-7)
self.assertLess(var_RMSE, 1e-7)
np.testing.assert_almost_equal(
sm.predict_values(
np.array(
Expand Down
17 changes: 0 additions & 17 deletions smt/surrogate_models/krg_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,14 +463,6 @@ def _new_train(self):
listcatdecreed = self.design_space.is_conditionally_acting[
self.cat_features
]
if np.any(listcatdecreed):
D = self._correct_distances_cat_decreed(
D,
is_acting,
listcatdecreed,
self.ij,
mixint_type=MixIntKernelType.CONT_RELAX,
)

# Center and scale X_cont and y
(
Expand Down Expand Up @@ -1485,15 +1477,6 @@ def _predict_init(self, x, is_acting):
listcatdecreed = self.design_space.is_conditionally_acting[
self.cat_features
]
if np.any(listcatdecreed):
dx = self._correct_distances_cat_decreed(
dx,
is_acting,
listcatdecreed,
ij,
is_acting_y=self.is_acting_train,
mixint_type=MixIntKernelType.CONT_RELAX,
)

Lij, _ = cross_levels(
X=x, ij=ij, design_space=self.design_space, y=self.X_train
Expand Down

0 comments on commit b2f9b40

Please sign in to comment.