Skip to content

Commit

Permalink
cast column names on create predictor
Browse files Browse the repository at this point in the history
  • Loading branch information
StpMax committed Feb 2, 2022
1 parent 5c8d775 commit 4f36633
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions mindsdb/api/mysql/mysql_proxy/mysql_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,30 @@ def answer_create_predictor(self, statement):
data_store.delete_datasource(ds_name)
raise

def get_column_in_case(columns, name):
candidates = []
name_lower = name.lower()
for column in columns:
if column.lower() == name_lower:
candidates.append(column)
if len(candidates) != 1:
raise Exception(f'Cant get appropriate cast column case. Columns: {columns}, column: {name}, candidates: {candidates}')
return candidates[0]

for i, p in enumerate(predict):
predict[i] = get_column_in_case(ds_column_names, p)

# Cast all column names to same case
if isinstance(kwargs.get('timeseries_settings'), dict):
order_by = kwargs['timeseries_settings'].get('order_by')
if order_by is not None:
for i, col in enumerate(order_by):
kwargs['timeseries_settings']['order_by'][i] = get_column_in_case(ds_column_names, col)
group_by = kwargs['timeseries_settings'].get('group_by')
if group_by is not None:
for i, col in enumerate(group_by):
kwargs['timeseries_settings']['group_by'][i] = get_column_in_case(ds_column_names, col)

model_interface.learn(predictor_name, ds, predict, ds_data['id'], kwargs=kwargs, delete_ds_on_fail=True)

self.packet(OkPacket).send()
Expand Down

0 comments on commit 4f36633

Please sign in to comment.