Skip to content

Commit

Permalink
Fix sklearn regressor test?
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Aug 22, 2019
1 parent c074416 commit fb7f49e
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions keras/wrappers/scikit_learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,10 @@ def predict(self, x, **kwargs):
Predictions.
"""
kwargs = self.filter_sk_params(Sequential.predict, kwargs)
preds = self.model.predict(x, **kwargs)
return np.squeeze(preds, axis=len(preds.shape) - 1)
preds = np.array(self.model.predict(x, **kwargs))
if preds.shape[-1] == 1:
return np.squeeze(preds, axis=-1)
return preds

def score(self, x, y, **kwargs):
"""Returns the mean loss on the given test data and labels.
Expand Down

0 comments on commit fb7f49e

Please sign in to comment.