SHOGUN  v3.2.0
SGDQN.h
Go to the documentation of this file.
1 #ifndef _SGDQN_H___
2 #define _SGDQN_H___
3 
4 /*
5  SVM with Quasi-Newton stochastic gradient
6  Copyright (C) 2009- Antoine Bordes
7 
8  This program is free software; you can redistribute it and/or
9  modify it under the terms of the GNU Lesser General Public
10  License as published by the Free Software Foundation; either
11  version 2.1 of the License, or (at your option) any later version.
12 
13  This program is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16  GNU General Public License for more details.
17 
18  You should have received a copy of the GNU Lesser General Public
19  License along with this library; if not, write to the Free Software
20  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
21 
22  Shogun adjustments (w) 2011 Siddharth Kherada
23 */
24 
25 #include <shogun/lib/common.h>
28 #include <shogun/labels/Labels.h>
30 
31 namespace shogun
32 {
/**
 * CSGDQN -- a linear machine (it derives from CLinearMachine) that the file
 * header describes as an "SVM with Quasi-Newton stochastic gradient".
 *
 * The object stores two cost constants (C1, C2), an epoch count, a pointer
 * to a CLossFunction, and three further private scalars (t, skip, count)
 * whose roles are not evident from this header alone.
 *
 * NOTE(review): this is a rendered source listing, not the raw header. The
 * embedded listing numbers skip ranges (e.g. 38-39 and 62-66), and the
 * cross-reference index that follows the listing names
 * MACHINE_PROBLEM_TYPE(PT_BINARY) and get_classifier_type() (SGDQN.h:66),
 * neither of which appears in the lines below. Presumably they sit on the
 * skipped lines -- confirm against the real header before editing.
 *
 * NOTE(review): only common.h and Labels.h are visibly included, yet the
 * class uses CLinearMachine, CDotFeatures and CLossFunction; the headers
 * declaring those are presumably on skipped include lines -- TODO confirm.
 */
34 class CSGDQN : public CLinearMachine
35 {
36  public:
37 
40 
  /** Default constructor (declared here, implemented out of line). */
42  CSGDQN();
43 
  /**
   * Construct from a single cost value.
   *
   * @param C cost value; how it is mapped onto C1/C2 is decided by the
   *          out-of-line implementation and is not visible here
   */
48  CSGDQN(float64_t C);
49 
  /**
   * Construct from a cost value plus training data.
   *
   * @param C        cost value (see the single-argument constructor)
   * @param traindat training features, given as CDotFeatures
   * @param trainlab training labels
   */
56  CSGDQN(
57  float64_t C, CDotFeatures* traindat,
58  CLabels* trainlab);
59 
  /** Virtual destructor (implemented out of line). */
60  virtual ~CSGDQN();
61 
67 
  /**
   * Train the machine.
   *
   * @param data training features; defaults to NULL -- presumably meaning
   *             "use features that were already supplied", TODO confirm
   * @return bool result of training (presumably true on success -- confirm
   *         against the implementation)
   */
76  virtual bool train(CFeatures* data=NULL);
77 
  /**
   * Set both cost constants at once.
   *
   * @param c_neg value stored into C1
   * @param c_pos value stored into C2
   */
84  inline void set_C(float64_t c_neg, float64_t c_pos) { C1=c_neg; C2=c_pos; }
85 
  /** @return current value of C1 (the value set_C() stores from c_neg) */
90  inline float64_t get_C1() { return C1; }
91 
  /** @return current value of C2 (the value set_C() stores from c_pos) */
96  inline float64_t get_C2() { return C2; }
97 
  /**
   * Set the stored epoch count.
   *
   * @param e new value for the epochs member
   */
102  inline void set_epochs(int32_t e) { epochs=e; }
103 
  /** @return current value of the epochs member */
108  inline int32_t get_epochs() { return epochs; }
109 
  /**
   * Helper operating on raw float64_t arrays W, W_1, B and dst.
   * Only the declaration is visible here; the computation itself lives in
   * the implementation file.
   *
   * NOTE(review): dim is presumably the length of each array and dst the
   * output buffer -- confirm against the definition.
   */
111  void compute_ratio(float64_t* W,float64_t* W_1,float64_t* B,float64_t* dst,int32_t dim,float64_t regularizer_lambda,float64_t loss);
112 
  /**
   * Helper operating on raw float64_t arrays Bc and B with scalar
   * parameters c1, c2, v1, v2. Only the declaration is visible here.
   *
   * NOTE(review): the name suggests a weighted combination followed by
   * clamping, with dim as the array length -- confirm against the definition.
   */
114  void combine_and_clip(float64_t* Bc,float64_t* B,int32_t dim,float64_t c1,float64_t c2,float64_t v1,float64_t v2);
115 
  /**
   * Set the loss function used by this machine (implemented out of line).
   *
   * @param loss_func loss function object to use
   */
120  void set_loss_function(CLossFunction* loss_func);
121 
  /**
   * Get the loss function.
   *
   * SG_REF is applied to the stored pointer before it is returned, so the
   * caller presumably receives an extra reference it must release --
   * confirm against the SG_REF definition.
   *
   * @return the stored CLossFunction pointer
   */
126  inline CLossFunction* get_loss_function() { SG_REF(loss); return loss; }
127 
  /** @return the constant name string "SGDQN" */
129  virtual const char* get_name() const { return "SGDQN"; }
130 
131  protected:
  /** Calibration step available to subclasses; implemented out of line. */
133  void calibrate();
134 
135  private:
  /**
   * Private initialisation helper; implemented out of line.
   * NOTE(review): presumably called from the constructors -- confirm.
   */
136  void init();
137 
138  private:
  /* t: purpose not evident from this header. */
139  float64_t t;
  /* C1: cost constant, written from c_neg by set_C(). */
140  float64_t C1;
  /* C2: cost constant, written from c_pos by set_C(). */
141  float64_t C2;
  /* epochs: value managed through set_epochs()/get_epochs(). */
142  int32_t epochs;
  /* skip, count: purpose not evident from this header. */
143  int32_t skip;
144  int32_t count;
145 
  /* loss: loss function object handed out by get_loss_function(). */
146  CLossFunction* loss;
147 };
148 }
149 #endif
EMachineType
Definition: Machine.h:33
Class CLossFunction is the base class of all loss functions.
Definition: LossFunction.h:53
void set_epochs(int32_t e)
Definition: SGDQN.h:102
int32_t get_epochs()
Definition: SGDQN.h:108
void compute_ratio(float64_t *W, float64_t *W_1, float64_t *B, float64_t *dst, int32_t dim, float64_t regularizer_lambda, float64_t loss)
Definition: SGDQN.cpp:69
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:35
void set_loss_function(CLossFunction *loss_func)
Definition: SGDQN.cpp:62
virtual bool train(CFeatures *data=NULL)
Definition: SGDQN.cpp:93
void set_C(float64_t c_neg, float64_t c_pos)
Definition: SGDQN.h:84
void calibrate()
Definition: SGDQN.cpp:206
Features that support dot products among other operations.
Definition: DotFeatures.h:41
CLossFunction * get_loss_function()
Definition: SGDQN.h:126
double float64_t
Definition: common.h:48
#define SG_REF(x)
Definition: SGRefObject.h:34
Class LinearMachine is a generic interface for all kinds of linear machines like classifiers.
Definition: LinearMachine.h:61
virtual const char * get_name() const
Definition: SGDQN.h:129
float64_t get_C1()
Definition: SGDQN.h:90
MACHINE_PROBLEM_TYPE(PT_BINARY)
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:16
class SGDQN
Definition: SGDQN.h:34
The class Features is the base class of all feature objects.
Definition: Features.h:62
void combine_and_clip(float64_t *Bc, float64_t *B, int32_t dim, float64_t c1, float64_t c2, float64_t v1, float64_t v2)
Definition: SGDQN.cpp:81
virtual EMachineType get_classifier_type()
Definition: SGDQN.h:66
virtual ~CSGDQN()
Definition: SGDQN.cpp:57
float64_t get_C2()
Definition: SGDQN.h:96

SHOGUN Machine Learning Toolbox - Documentation