Skip to content

Commit

Permalink
Set default step for GLM and l2 penalization to a better value
Browse files Browse the repository at this point in the history
Useful for comparison with scikit.
  • Loading branch information
Mbompr committed Feb 23, 2018
1 parent 7af80fb commit f0888de
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions tick/base/learner/learner_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from warnings import warn

import numpy as np
from tick.base import Base

from tick.base import Base
from tick.base_model import ModelLipschitz
from .learner_optim import LearnerOptim

Expand Down Expand Up @@ -167,11 +167,20 @@ def fit(self, X: object, y: np.array):
if self.step is None and self.solver in self._solvers_with_step:
if self.solver in self._solvers_with_linesearch:
self._solver_obj.linesearch = True
elif self.solver == 'svrg':
elif self.solver == 'svrg' or self.solver == 'saga':
L = self._model_obj.get_lip_max()
if self.penalty == 'l2':
L += 1. / self.C
mun = min(2 * self._model_obj.n_samples / self.C, L)
self.step = 1. / (2 * L + mun)
else:
self.step = 1. / L

if isinstance(self._model_obj, ModelLipschitz):
self.step = 1. / self._model_obj.get_lip_max()
else:
warn('SVRG step needs to be tuned manually', RuntimeWarning)
warn('SVRG and SAGA steps needs to be tuned manually',
RuntimeWarning)
self.step = 1.
elif self.solver == 'sgd':
warn('SGD step needs to be tuned manually', RuntimeWarning)
Expand Down

0 comments on commit f0888de

Please sign in to comment.