KAN MNIST

Similar to my other repo, but using a KAN (Kolmogorov-Arnold Network) instead of an MLP, keeping as many of the same parameters as possible.


Trouble:

Currently running into an error from `torch.reshape` in the train cell:

```text
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[12], line 20
     17 def test_acc():
     18     return torch.mean((torch.round(model(dataset['test_input'])[:,0]) == dataset['test_label']).float())
---> 20 results = model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc))
     21 results['train_acc'][-1], results['test_acc'][-1]

File c:\Users\Moo\AppData\Local\Programs\Python\Python311\Lib\site-packages\kan\KAN.py:913, in KAN.train(self, dataset, opt, steps, log, lamb, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff, update_grid, grid_update_num, loss_fn, lr, stop_grid_update_step, batch, small_mag_threshold, small_reg_factor, metrics, sglr_avoid, save_fig, in_vars, out_vars, beta, save_fig_freq, img_folder, device)
    910 test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
    912 if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid:
--> 913     self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
    916 if opt == "LBFGS":
    917     optimizer.step(closure)

File c:\Users\Moo\AppData\Local\Programs\Python\Python311\Lib\site-packages\kan\KAN.py:242, in KAN.update_grid_from_samples(self, x)
    219 '''
    220 update grid from samples
    221 
   (...)
    239 tensor([0.0128, 1.0064, 2.0000, 2.9937, 3.9873, 4.9809])
    240 '''
    241 for l in range(self.depth):
--> 242     self.forward(x)
    243     self.act_fun[l].update_grid_from_samples(self.acts[l])
...
     75 if func in _device_constructors() and kwargs.get('device') is None:
     76     kwargs['device'] = self.device
---> 77 return func(*args, **kwargs)
```
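The traceback is cut off before the actual `RuntimeError` message, so the exact cause isn't visible here. One frequent source of reshape errors in a KAN forward pass is an input that isn't shaped `(batch, features)` with `features` equal to the first entry of the model's width (784 for flattened 28x28 MNIST images). Below is a minimal sketch, assuming the images arrive as an `(N, 28, 28)` or `(N, 1, 28, 28)` array or tensor of raw 0-255 pixels. `make_kan_dataset`, `expected_features` and `scale` are hypothetical names, not part of pykan; the four dictionary keys are the ones the train cell reads.

```python
import numpy as np


def make_kan_dataset(train_images, train_labels, test_images, test_labels,
                     expected_features=None, scale=255.0):
    """Pack image batches into the dict layout read by the train cell.

    Hypothetical helper (not part of pykan). Each image batch of shape
    (N, ...) is flattened to (N, features) and divided by `scale`.
    Only .shape, .reshape and division are used, so it behaves the same
    on NumPy arrays and torch tensors.
    """
    def flatten(images):
        # One row per sample: (N, 28, 28) or (N, 1, 28, 28) -> (N, 784), as floats.
        flat = images.reshape(images.shape[0], -1) / scale
        if expected_features is not None and flat.shape[1] != expected_features:
            # Fail early with a readable message instead of an opaque error inside forward().
            raise ValueError(
                f"got {flat.shape[1]} features per sample, "
                f"but the first KAN layer expects {expected_features}"
            )
        return flat

    return {
        "train_input": flatten(train_images),
        "train_label": train_labels,
        "test_input": flatten(test_images),
        "test_label": test_labels,
    }


# Demo with random stand-in data shaped like MNIST (not the real dataset).
rng = np.random.default_rng(0)
train_images = rng.integers(0, 256, size=(100, 28, 28), dtype=np.uint8)
test_images = rng.integers(0, 256, size=(20, 1, 28, 28), dtype=np.uint8)
train_labels = rng.integers(0, 10, size=100)
test_labels = rng.integers(0, 10, size=20)

dataset = make_kan_dataset(train_images, train_labels, test_images, test_labels,
                           expected_features=784)
print(dataset["train_input"].shape, dataset["test_input"].shape)  # (100, 784) (20, 784)
```

With torchvision, for example, the MNIST dataset's `.data` attribute is already a uint8 image tensor of shape `(N, 28, 28)`, so it could be passed in as-is together with `.targets`; the KAN's width would then need to start with 784.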
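Separately from the reshape error, the existing `test_acc` rounds output column 0 and compares it with the label, which treats the model as having a single regression output. If the KAN instead ends in 10 outputs (one score per digit), the usual accuracy is the argmax over each row. A minimal sketch, where `argmax_accuracy` is a hypothetical helper; it returns a 0-d array/tensor (like the `torch.mean(...)` in the existing metrics) rather than a Python float.

```python
import numpy as np


def argmax_accuracy(outputs, labels):
    """Fraction of samples whose highest-scoring output column equals the label.

    outputs: (N, num_classes) scores; labels: (N,) integer class ids.
    Uses only methods shared by NumPy arrays and torch tensors.
    """
    preds = outputs.argmax(1)  # index of the largest score in each row
    return (preds == labels).sum() / len(labels)


scores = np.array([[0.1, 2.0, -1.0],
                   [3.0, 0.5, 0.2],
                   [0.0, 0.1, 0.9],
                   [1.0, 0.0, 0.0]])
labels = np.array([1, 0, 2, 2])
acc = argmax_accuracy(scores, labels)
print(acc)  # 0.75 (the last row predicts class 0, not 2)
```

In the notebook this could be wrapped as `def test_acc(): return argmax_accuracy(model(dataset['test_input']), dataset['test_label'])`. Since `KAN.train` takes a `loss_fn` argument (it appears in the signature in the traceback), passing `loss_fn=torch.nn.CrossEntropyLoss()` is one option for training the 10 outputs as class scores; that loss expects integer (long) class labels of shape `(N,)`.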