[python-package] Refit does not work for regression (Python API) #6737

Open
bhvieira opened this issue Dec 6, 2024 · 6 comments · May be fixed by #6753
bhvieira commented Dec 6, 2024

Description

Calling Booster.refit() on a regression model trained with the Python API raises a ValueError.

Reproducible example

import lightgbm as lgb
import numpy as np

# Generate synthetic data
np.random.seed(42)
X_train = np.random.rand(100, 5)  # Training features
y_train = X_train @ np.array([1.5, -2.0, 1.0, 0.5, -1.0]) + np.random.normal(scale=0.1, size=100)  # Training targets

# Prepare data for LightGBM
train_data = lgb.Dataset(X_train, label=y_train)

# Define training parameters
params = {
    "objective": "regression",
    "verbosity": -1,
}

# Train the model
print("Training initial model...")
model = lgb.train(
    params,
    train_data,
    num_boost_round=1,
)

# Generate new data for refitting
X_new = np.random.rand(60, 5)  # New training features
y_new = X_new @ np.array([1.5, -2.0, 1.0, 0.5, -1.0]) + np.random.normal(scale=0.1, size=60)  # New training targets

# Refit the model
print("Refitting the model with new data...")
model_refit = model.refit(X_new, y_new)
# ValueError: not enough values to unpack (expected 2, got 1)

Environment info

LightGBM version or commit hash: 4.3.0 (from the Poetry lock entry below)

Command(s) you used to install LightGBM

poetry add lightgbm

name = "lightgbm"
version = "4.3.0"
description = "LightGBM Python Package"
category = "main"
optional = false
python-versions = ">=3.6"

Additional Comments

jameslamb changed the title from "Refit does not work for regression (Python API)" to "[python-package] Refit does not work for regression (Python API)" on Dec 6, 2024
jameslamb (Collaborator) commented Dec 6, 2024

Thanks for using LightGBM and for the clear report. This definitely looks like a bug!

Here's a smaller reproducible example, and logs with traceback showing the location of the error.

import lightgbm as lgb
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=10_000, n_features=10)

dtrain = lgb.Dataset(X, label=y)

model = lgb.train(
    params={
        "objective": "regression",
        "verbosity": -1
    },
    train_set=dtrain,
    num_boost_round=1,
)

model_refit = model.refit(X, y)

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/jlamb/miniforge3/envs/lgb-dev/lib/python3.11/site-packages/lightgbm/basic.py", line 4866, in refit
    nrow, ncol = leaf_preds.shape
    ^^^^^^^^^^
ValueError: not enough values to unpack (expected 2, got 1)

Output of `conda env export`:
name: lgb-dev
channels:
  - conda-forge
dependencies:
  - aiobotocore=2.15.1=pyhd8ed1ab_0
  - aiohappyeyeballs=2.4.2=pyhd8ed1ab_0
  - aiohttp=3.10.8=py311h460d6c5_0
  - aioitertools=0.12.0=pyhd8ed1ab_0
  - aiosignal=1.3.1=pyhd8ed1ab_0
  - alabaster=0.7.16=pyhd8ed1ab_0
  - annotated-types=0.7.0=pyhd8ed1ab_0
  - atk-1.0=2.38.0=hd03087b_2
  - attrs=24.2.0=pyh71513ae_0
  - aws-c-auth=0.7.22=hec39e38_2
  - aws-c-cal=0.6.14=h5db4892_1
  - aws-c-common=0.9.19=h99b78c6_0
  - aws-c-compression=0.2.18=h5db4892_6
  - aws-c-event-stream=0.4.2=h5eab607_12
  - aws-c-http=0.8.1=had10953_17
  - aws-c-io=0.14.8=hb5a7b21_5
  - aws-c-mqtt=0.10.4=h78534b8_4
  - aws-c-s3=0.5.9=h1755d02_3
  - aws-c-sdkutils=0.1.16=h5db4892_2
  - aws-checksums=0.1.18=h5db4892_6
  - aws-crt-cpp=0.26.9=h03bff2b_0
  - aws-sdk-cpp=1.11.329=hb37a6d0_3
  - babel=2.14.0=pyhd8ed1ab_0
  - backports.zoneinfo=0.2.1=py311h267d04e_9
  - boto3=1.35.23=pyhd8ed1ab_0
  - botocore=1.35.23=pyge310_1234567_0
  - breathe=4.35.0=pyhd8ed1ab_2
  - brotli=1.1.0=hb547adb_1
  - brotli-bin=1.1.0=hb547adb_1
  - brotli-python=1.1.0=py311ha891d26_1
  - bzip2=1.0.8=h93a5062_5
  - c-ares=1.28.1=h93a5062_0
  - ca-certificates=2024.8.30=hf0a4a13_0
  - cairo=1.18.0=hd1e100b_0
  - certifi=2024.8.30=pyhd8ed1ab_0
  - cffi=1.16.0=py311h4a08483_0
  - cfgv=3.3.1=pyhd8ed1ab_0
  - charset-normalizer=3.3.2=pyhd8ed1ab_0
  - click=8.1.7=unix_pyh707e725_0
  - cloudpickle=3.0.0=pyhd8ed1ab_0
  - cmakelint=1.4.3=pyhd8ed1ab_0
  - colorama=0.4.6=pyhd8ed1ab_0
  - coverage=7.5.0=py311hd23d018_0
  - cpplint=1.6.0=pyhd8ed1ab_0
  - cycler=0.12.1=pyhd8ed1ab_0
  - distlib=0.3.8=pyhd8ed1ab_0
  - docutils=0.20.1=py311h267d04e_3
  - doxygen=1.10.0=h8fbad5d_0
  - dunamai=1.22.0=pyhd8ed1ab_0
  - exceptiongroup=1.2.0=pyhd8ed1ab_2
  - expat=2.6.2=hebf3989_0
  - filelock=3.14.0=pyhd8ed1ab_0
  - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
  - font-ttf-inconsolata=3.000=h77eed37_0
  - font-ttf-source-code-pro=2.038=h77eed37_0
  - font-ttf-ubuntu=0.83=h77eed37_1
  - fontconfig=2.14.2=h82840c6_0
  - fonts-conda-ecosystem=1=0
  - fonts-conda-forge=1=0
  - fonttools=4.53.0=py311hd3f4193_0
  - freetype=2.12.1=hadb7bae_2
  - fribidi=1.0.10=h27ca646_0
  - frozenlist=1.4.1=py311h460d6c5_1
  - fsspec=2024.9.0=pyhff2d567_0
  - gdk-pixbuf=2.42.11=h13c029f_0
  - gettext=0.22.5=h8fbad5d_2
  - gettext-tools=0.22.5=h8fbad5d_2
  - gflags=2.2.2=hc88da5d_1004
  - giflib=5.2.2=h93a5062_0
  - glog=0.7.1=heb240a5_0
  - gmp=6.3.0=hebf3989_1
  - graphite2=1.3.13=hebf3989_1003
  - graphviz=9.0.0=h3face73_1
  - gtk2=2.24.33=h7895bb2_4
  - gts=0.7.6=he42f4ea_4
  - h2=4.1.0=pyhd8ed1ab_0
  - harfbuzz=8.4.0=hbe0f7c0_0
  - hpack=4.0.0=pyh9f0ad1d_0
  - hyperframe=6.0.1=pyhd8ed1ab_0
  - hypothesis=6.115.2=pyha770c72_0
  - icu=73.2=hc8870d7_0
  - identify=2.5.36=pyhd8ed1ab_0
  - imagesize=1.4.1=pyhd8ed1ab_0
  - iniconfig=2.0.0=pyhd8ed1ab_0
  - jinja2=3.1.4=pyhd8ed1ab_0
  - jmespath=1.0.1=pyhd8ed1ab_0
  - joblib=1.4.2=pyhd8ed1ab_0
  - khronos-opencl-icd-loader=2024.10.24=h5505292_1
  - kiwisolver=1.4.5=py311he4fd1f5_1
  - krb5=1.21.3=h237132a_0
  - lcms2=2.16=ha0e7c42_0
  - lerc=4.0.0=h9a09cb3_0
  - libabseil=20240116.2=cxx17_hebf3989_0
  - libarrow=16.1.0=h28dd788_6_cpu
  - libarrow-acero=16.1.0=h00cdb27_6_cpu
  - libarrow-dataset=16.1.0=h00cdb27_6_cpu
  - libarrow-substrait=16.1.0=hc68f6b8_6_cpu
  - libasprintf=0.22.5=h8fbad5d_2
  - libasprintf-devel=0.22.5=h8fbad5d_2
  - libblas=3.9.0=22_osxarm64_openblas
  - libboost=1.84.0=h17eb2be_3
  - libbrotlicommon=1.1.0=hb547adb_1
  - libbrotlidec=1.1.0=hb547adb_1
  - libbrotlienc=1.1.0=hb547adb_1
  - libcblas=3.9.0=22_osxarm64_openblas
  - libcrc32c=1.1.2=hbdafb3b_0
  - libcurl=8.8.0=h7b6f9a7_1
  - libcxx=18.1.8=h3ed4263_6
  - libdeflate=1.20=h93a5062_0
  - libedit=3.1.20191231=hc8eb9b7_2
  - libev=4.33=h93a5062_2
  - libevent=2.1.12=h2757513_1
  - libexpat=2.6.2=hebf3989_0
  - libffi=3.4.2=h3422bc3_5
  - libgd=2.3.3=hfdf3952_9
  - libgettextpo=0.22.5=h8fbad5d_2
  - libgettextpo-devel=0.22.5=h8fbad5d_2
  - libgfortran=5.0.0=13_2_0_hd922786_3
  - libgfortran5=13.2.0=hf226fd6_3
  - libglib=2.80.0=hfc324ee_6
  - libgoogle-cloud=2.24.0=hfe08963_0
  - libgoogle-cloud-storage=2.24.0=h3fa5b87_0
  - libgrpc=1.62.2=h9c18a4f_0
  - libiconv=1.17=h0d3ecfb_2
  - libintl=0.22.5=h8fbad5d_2
  - libintl-devel=0.22.5=h8fbad5d_2
  - libjpeg-turbo=3.0.0=hb547adb_1
  - liblapack=3.9.0=22_osxarm64_openblas
  - liblightgbm=4.5.0=cpu_h7ba702d_3
  - libllvm14=14.0.6=hd1a9a77_4
  - libnghttp2=1.58.0=ha4dd798_1
  - libopenblas=0.3.27=openmp_h6c19121_0
  - libparquet=16.1.0=hcf52c46_6_cpu
  - libpng=1.6.43=h091b4b1_0
  - libprotobuf=4.25.3=hbfab5d5_0
  - libre2-11=2023.09.01=h7b2c953_2
  - librsvg=2.58.0=hb3d354b_1
  - libsqlite=3.45.3=h091b4b1_0
  - libssh2=1.11.0=h7a5bd25_0
  - libthrift=0.19.0=h026a170_1
  - libtiff=4.6.0=h07db509_3
  - libutf8proc=2.8.0=h1a8c8d9_0
  - libwebp=1.3.2=hf30222e_1
  - libwebp-base=1.3.2=h93a5062_1
  - libxcb=1.15=hf346824_0
  - libxml2=2.12.6=h0d0cfa8_2
  - libzlib=1.2.13=h53f4e23_5
  - lightgbm=4.5.0=cpu_py_3
  - llvm-openmp=18.1.8=hde57baf_0
  - llvmlite=0.42.0=py311hf5d242d_1
  - lz4-c=1.9.4=hb7217d7_0
  - markdown-it-py=3.0.0=pyhd8ed1ab_0
  - markupsafe=2.1.5=py311h460d6c5_1
  - matplotlib=3.8.4=py311ha1ab1f8_2
  - matplotlib-base=3.8.4=py311h000fb6e_2
  - mdurl=0.1.2=pyhd8ed1ab_0
  - multidict=6.1.0=py311h426a4a9_0
  - munkres=1.1.4=pyh9f0ad1d_0
  - mypy=1.13.0=py311hae2e1ce_0
  - mypy_extensions=1.0.0=pyha770c72_0
  - ncurses=6.4.20240210=h078ce10_0
  - nodeenv=1.8.0=pyhd8ed1ab_0
  - numba=0.59.1=py311h00351ea_0
  - numpy=1.26.4=py311h7125741_0
  - opencl-headers=2024.10.24=h286801f_0
  - openjpeg=2.5.2=h9f1df11_0
  - openssl=3.4.0=h39f12f2_0
  - orc=2.0.1=h47ade37_1
  - packaging=24.0=pyhd8ed1ab_0
  - pandas=2.2.2=py311h4b4568b_1
  - pandoc=3.1.13=hce30654_0
  - pango=1.52.2=hb067d4f_0
  - pcre2=10.43=h26f9a81_0
  - pillow=10.3.0=py311h0b5d0a1_0
  - pip=24.0=pyhd8ed1ab_0
  - pixman=0.43.4=hebf3989_0
  - platformdirs=4.2.1=pyhd8ed1ab_0
  - pluggy=1.5.0=pyhd8ed1ab_0
  - pre-commit=3.7.0=pyha770c72_0
  - psutil=5.9.8=py311h05b510d_0
  - pthread-stubs=0.4=h27ca646_1001
  - pyarrow=16.1.0=py311hf3b2ce4_1
  - pyarrow-core=16.1.0=py311hbc16ef1_1_cpu
  - pycparser=2.22=pyhd8ed1ab_0
  - pygments=2.18.0=pyhd8ed1ab_0
  - pyparsing=3.1.2=pyhd8ed1ab_0
  - pysocks=1.7.1=pyha2e5f31_6
  - pytest=8.2.2=pyhd8ed1ab_0
  - pytest-cov=5.0.0=pyhd8ed1ab_0
  - python=3.11.9=h932a869_0_cpython
  - python-dateutil=2.9.0=pyhd8ed1ab_0
  - python-graphviz=0.20.3=pyh717bed2_0
  - python-tzdata=2024.1=pyhd8ed1ab_0
  - python_abi=3.11=4_cp311
  - pytz=2024.1=pyhd8ed1ab_0
  - pyyaml=6.0.1=py311heffc1b2_1
  - re2=2023.09.01=h4cba328_2
  - readline=8.2=h92ec313_1
  - requests=2.32.3=pyhd8ed1ab_0
  - rstcheck=6.2.4=pyhd8ed1ab_0
  - rstcheck-core=1.2.1=pyhd8ed1ab_0
  - ruff=0.4.7=py311hd374d79_0
  - s3fs=2024.9.0=pyhd8ed1ab_0
  - s3transfer=0.10.2=pyhd8ed1ab_0
  - scipy=1.14.1=py311h2929bc6_0
  - setuptools=69.5.1=pyhd8ed1ab_0
  - shellcheck=0.10.0=hecfb573_0
  - shellingham=1.5.4=pyhd8ed1ab_0
  - six=1.16.0=pyh6c4a22f_0
  - snappy=1.2.1=hd02b534_0
  - snowballstemmer=2.2.0=pyhd8ed1ab_0
  - sortedcontainers=2.4.0=pyhd8ed1ab_0
  - sphinx=7.1.2=pyhd8ed1ab_0
  - sphinx_rtd_theme=2.0.0=pyha770c72_0
  - sphinxcontrib-applehelp=2.0.0=pyhd8ed1ab_0
  - sphinxcontrib-devhelp=2.0.0=pyhd8ed1ab_0
  - sphinxcontrib-htmlhelp=2.1.0=pyhd8ed1ab_0
  - sphinxcontrib-jquery=4.1=pyhd8ed1ab_0
  - sphinxcontrib-jsmath=1.0.1=pyhd8ed1ab_0
  - sphinxcontrib-qthelp=2.0.0=pyhd8ed1ab_0
  - sphinxcontrib-serializinghtml=1.1.10=pyhd8ed1ab_0
  - threadpoolctl=3.5.0=pyhc1e730c_0
  - tk=8.6.13=h5083fa2_1
  - toml=0.10.2=pyhd8ed1ab_0
  - tomli=2.0.1=pyhd8ed1ab_0
  - tornado=6.4.1=py311hd3f4193_0
  - typer=0.12.5=pyhd8ed1ab_0
  - typer-slim=0.12.5=pyhd8ed1ab_0
  - typer-slim-standard=0.12.5=hd8ed1ab_0
  - typing-extensions=4.11.0=hd8ed1ab_0
  - typing_extensions=4.11.0=pyha770c72_0
  - tzdata=2024a=h0c530f3_0
  - ukkonen=1.0.1=py311he4fd1f5_4
  - virtualenv=20.26.0=pyhd8ed1ab_0
  - wheel=0.43.0=pyhd8ed1ab_1
  - wrapt=1.16.0=py311h460d6c5_1
  - xorg-libxau=1.0.11=hb547adb_0
  - xorg-libxdmcp=1.1.3=h27ca646_0
  - xz=5.2.6=h57fd34a_0
  - yaml=0.2.5=h3422bc3_2
  - yarl=1.13.1=py311h460d6c5_0
  - zlib=1.2.13=h53f4e23_5
  - zstandard=0.23.0=py311ha60cc69_1
  - zstd=1.5.6=hb46c0d2_0
  - pip:
      - altgraph==0.17.4
      - backports-tarfile==1.2.0
      - build==1.2.1
      - check-wheel-contents==0.6.0
      - contourpy==1.3.0.dev1
      - datatable==1.1.0
      - delocate==0.11.0
      - idna==3.8
      - importlib-metadata==8.4.0
      - jaraco-classes==3.4.0
      - jaraco-context==6.0.1
      - jaraco-functools==4.0.2
      - keyring==25.3.0
      - macholib==1.16.3
      - more-itertools==10.4.0
      - mpmath==1.3.0
      - networkx==3.3
      - nh3==0.2.18
      - pkginfo==1.10.0
      - pydantic==2.8.2
      - pydantic-core==2.20.1
      - pydistcheck==0.7.1
      - pyproject-hooks==1.0.0
      - readme-renderer==44.0
      - requests-toolbelt==1.0.0
      - rfc3986==2.0.0
      - rich==13.8.0
      - scikit-learn==1.6.0rc1
      - sympy==1.13.2
      - torch==2.4.1
      - twine==5.1.1
      - urllib3==2.2.2
      - wheel-filename==1.4.1
      - xgboost==2.2.0.dev0
      - zipp==3.20.1
prefix: /Users/jlamb/miniforge3/envs/lgb-dev

That's coming from this line in Booster.refit() (lightgbm/basic.py):

nrow, ncol = leaf_preds.shape
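A quick way to see why that line fails: unpacking `.shape` into two names only works when the array is 2-D. The sketch below uses a hand-made NumPy array as a stand-in for `leaf_preds` (an assumption for illustration, not actual LightGBM output) to reproduce the same error message, and shows that a `(nrow, 1)` array unpacks cleanly.

```python
import numpy as np

# Hypothetical stand-in for leaf_preds: a 1-D array with one leaf index per row,
# which is what a single-tree regression model can produce.
leaf_preds = np.zeros(10, dtype=np.int32)

message = ""
try:
    nrow, ncol = leaf_preds.shape  # shape is (10,), a 1-tuple
except ValueError as err:
    message = str(err)

print(message)  # not enough values to unpack (expected 2, got 1)

# Reshaping to one column per tree makes the unpacking succeed.
nrow, ncol = leaf_preds.reshape(leaf_preds.shape[0], -1).shape
print(nrow, ncol)  # 10 1
```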

I'm not sure how we haven't caught this before. This project already has tests of Booster.refit() on regression models:

est2 = est.refit(x, label=y)

jameslamb added the bug label and removed the question label on Dec 6, 2024
jameslamb (Collaborator) commented:

@bhvieira can you share the output of running pip freeze (or the Poetry equivalent) here? I'm curious to see what version of numpy you have.

bhvieira (Author) commented Dec 6, 2024

@jameslamb I have the whole Poetry lock file, but it's quite long for my current project. Here is the numpy entry:

[[package]]
name = "numpy"
version = "1.26.4"
description = "Fundamental package for array computing in Python"
category = "main"
optional = false
python-versions = ">=3.9"

RektPunk (Contributor) commented:

I think the issue might be with pred_leaf. With only one tree and a single-output objective such as regression, the leaf-index prediction has exactly one value per row, so preds.size == nrow and the condition

if not is_sparse and preds.size != nrow:
    if preds.size % nrow == 0:
        preds = preds.reshape(nrow, -1)

isn't satisfied. preds therefore stays 1-D, which breaks the shape unpacking in refit().
I believe modifying the if condition as follows

if not is_sparse and preds.size != nrow or pred_leaf:

could fix it. Would it be okay if I give it a try?
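To make the shape argument concrete, here is a minimal NumPy-only sketch. The functions `reshape_preds_current` and `reshape_preds_proposed` are hypothetical stand-ins that mirror only the condition quoted above, not LightGBM's actual prediction code, and they assume leaf indices arrive as a flat array of `nrow * num_trees` values.

```python
import numpy as np


def reshape_preds_current(preds, nrow, is_sparse):
    """Hypothetical mirror of the quoted condition (before the proposed fix)."""
    if not is_sparse and preds.size != nrow:
        if preds.size % nrow == 0:
            preds = preds.reshape(nrow, -1)
    return preds


def reshape_preds_proposed(preds, nrow, is_sparse, pred_leaf):
    """Hypothetical mirror of the proposed condition, with explicit parentheses."""
    if not is_sparse and (preds.size != nrow or pred_leaf):
        if preds.size % nrow == 0:
            preds = preds.reshape(nrow, -1)
    return preds


nrow = 100
one_tree = np.zeros(nrow, dtype=np.int32)       # 1 tree: one leaf index per row
two_trees = np.zeros(2 * nrow, dtype=np.int32)  # 2 trees: two leaf indices per row

# Current condition: one tree skips the reshape and stays 1-D.
print(reshape_preds_current(one_tree, nrow, is_sparse=False).shape)   # (100,)
print(reshape_preds_current(two_trees, nrow, is_sparse=False).shape)  # (100, 2)

# Proposed condition: leaf predictions are always reshaped to 2-D.
print(reshape_preds_proposed(one_tree, nrow, is_sparse=False, pred_leaf=True).shape)  # (100, 1)
```

This also suggests why existing refit tests may not have caught it: with two or more trees, `preds.size != nrow` already holds. Note that Python's `and` binds more tightly than `or`, so the condition as written parses as `(not is_sparse and preds.size != nrow) or pred_leaf`. The sketch adds explicit parentheses so the `is_sparse` guard still applies when `pred_leaf` is true; whether the actual PR does the same is not shown in this thread.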

jameslamb (Collaborator) commented:

Thanks for that... sure!

When you put up a pull request, please include a minimal unit test that would fail without your fix, so we don't reintroduce this bug in the future.

RektPunk (Contributor) commented:

Thanks @jameslamb. I’ve created the PR and also added a simple test :).
