org.apache.mahout.classifier.sgd
Class LogisticModelParameters

java.lang.Object
  extended by org.apache.mahout.classifier.sgd.LogisticModelParameters
All Implemented Interfaces:
org.apache.hadoop.io.Writable
Direct Known Subclasses:
AdaptiveLogisticModelParameters

public class LogisticModelParameters
extends Object
implements org.apache.hadoop.io.Writable

Encapsulates everything we need to know about a model and how it reads and vectorizes its input. This encapsulation lets us save a model to a file and restore it coherently. It also keeps the command-line arguments that affect learning together in one place.


Constructor Summary
LogisticModelParameters()

Method Summary
 OnlineLogisticRegression createRegression()
          Creates a logistic regression trainer using the parameters collected here.
 CsvRecordFactory getCsvRecordFactory()
          Returns a CsvRecordFactory compatible with this logistic model.
 double getLambda()
 double getLearningRate()
 int getMaxTargetCategories()
 int getNumFeatures()
 List<String> getTargetCategories()
 String getTargetVariable()
 Map<String,String> getTypeMap()
static LogisticModelParameters loadFrom(File in)
          Reads a model from a file.
static LogisticModelParameters loadFrom(InputStream in)
          Reads a model from a stream.
 void readFields(DataInput in)
 void saveTo(OutputStream out)
          Saves a model to an output stream.
 void setLambda(double lambda)
 void setLearningRate(double learningRate)
 void setMaxTargetCategories(int maxTargetCategories)
          Sets the number of target categories to be considered.
 void setNumFeatures(int numFeatures)
 void setTargetCategories(List<String> targetCategories)
 void setTargetVariable(String targetVariable)
          Sets the target variable.
 void setTypeMap(Iterable<String> predictorList, List<String> typeList)
          Sets the types of the predictors.
 void setTypeMap(Map<String,String> map)
 void setUseBias(boolean useBias)
 boolean useBias()
 void write(DataOutput out)

Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait

Constructor Detail

LogisticModelParameters

public LogisticModelParameters()

Method Detail

getCsvRecordFactory

public CsvRecordFactory getCsvRecordFactory()
Returns a CsvRecordFactory compatible with this logistic model. The factory is obtained through this class so that the list of target categories it collects is available when the model is saved. If the input is not CSV, calling setTargetCategories before calling saveTo is sufficient.

Returns:
The CsvRecordFactory.
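The reason for tying the factory to this class can be illustrated in plain Java. The class below is hypothetical (it is not CsvRecordFactory or any Mahout API): while records are read, it gives each previously unseen target value the next integer code and keeps the values in first-seen order, so that the list could be saved alongside the model.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch, not Mahout's CsvRecordFactory: records target
// categories as they are encountered so the list can be saved with a model.
class TargetCategoryRecorder {
    private final Map<String, Integer> codes = new HashMap<>();
    private final List<String> categories = new ArrayList<>();

    /** Returns the integer code for a target value, assigning the next code on first sight. */
    int codeFor(String targetValue) {
        Integer code = codes.get(targetValue);
        if (code == null) {
            code = categories.size();
            codes.put(targetValue, code);
            categories.add(targetValue);
        }
        return code;
    }

    /** A copy of the categories in first-seen order, i.e. what would be persisted. */
    List<String> getTargetCategories() {
        return new ArrayList<>(categories);
    }
}
```

If the records are not CSV, the equivalent of getTargetCategories() here is exactly what a caller would pass to setTargetCategories before saving.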

createRegression

public OnlineLogisticRegression createRegression()
Creates a logistic regression trainer using the parameters collected here.

Returns:
The newly allocated OnlineLogisticRegression object.
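The parameters collected here (feature count, learning rate, and lambda, used below as a regularization weight) configure an online trainer. As a rough illustration only, the hypothetical class below performs stochastic-gradient updates for a two-category logistic regression using nothing but the JDK. The L2 shrinkage term is a stand-in chosen for simplicity; it is not necessarily the prior that OnlineLogisticRegression applies.

```java
// Hypothetical sketch of online logistic regression training, not Mahout's
// OnlineLogisticRegression. The L2 shrinkage is an illustrative stand-in.
final class LogisticSgdSketch {
    private final double[] beta;        // one weight per feature
    private final double learningRate;  // step size for each update
    private final double lambda;        // regularization weight (L2 here, by assumption)

    LogisticSgdSketch(int numFeatures, double learningRate, double lambda) {
        this.beta = new double[numFeatures];
        this.learningRate = learningRate;
        this.lambda = lambda;
    }

    /** Probability that the label is 1; x must have numFeatures entries. */
    double probability(double[] x) {
        double dot = 0.0;
        for (int i = 0; i < beta.length; i++) {
            dot += beta[i] * x[i];
        }
        return 1.0 / (1.0 + Math.exp(-dot));
    }

    /** One stochastic-gradient step on a single example with label y in {0, 1}. */
    void train(int y, double[] x) {
        // derivative of the log-likelihood with respect to the linear score
        double error = y - probability(x);
        for (int i = 0; i < beta.length; i++) {
            // move toward higher likelihood, then shrink the weight toward zero
            beta[i] += learningRate * (error * x[i] - lambda * beta[i]);
        }
    }
}
```

Repeatedly calling train on labeled examples moves the weights so that probability() separates the two classes; lambda trades fit against weight size.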

saveTo

public void saveTo(OutputStream out)
            throws IOException
Saves a model to an output stream.

Throws:
IOException

loadFrom

public static LogisticModelParameters loadFrom(InputStream in)
                                        throws IOException
Reads a model from a stream.

Throws:
IOException

loadFrom

public static LogisticModelParameters loadFrom(File in)
                                        throws IOException
Reads a model from a file.

Throws:
IOException - If there is an error opening or closing the file.
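Because the File variant can fail while opening or closing the file, a common way to write it is to open the file, delegate to the stream-based loader, and let try-with-resources handle closing. The sketch below shows that pattern with hypothetical names (FileLoading, StreamLoader); it is not Mahout's own implementation.

```java
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;

// Hypothetical helper illustrating a File-based loadFrom built on top of a
// stream-based one. FileLoading and StreamLoader are assumed names.
final class FileLoading {
    @FunctionalInterface
    interface StreamLoader<T> {
        T load(InputStream in) throws IOException;
    }

    static <T> T loadFrom(File file, StreamLoader<T> loader) throws IOException {
        // try-with-resources closes the stream even if loading fails; an
        // IOException from opening, reading, or closing reaches the caller.
        try (InputStream in = new FileInputStream(file)) {
            return loader.load(in);
        }
    }
}
```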

write

public void write(DataOutput out)
           throws IOException
Specified by:
write in interface org.apache.hadoop.io.Writable
Throws:
IOException

readFields

public void readFields(DataInput in)
                throws IOException
Specified by:
readFields in interface org.apache.hadoop.io.Writable
Throws:
IOException
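write and readFields follow the Hadoop Writable contract: fields are written to a DataOutput and must be read back from a DataInput in exactly the same order. The hypothetical class below demonstrates such a round trip with only java.io types. Its fields mirror some properties on this page, but the field order and encoding are assumptions, not Mahout's serialized format.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;

// Hypothetical Writable-style sketch; not Mahout's field order or format.
final class ParamsSketch {
    String targetVariable = "";
    int numFeatures;
    double lambda;
    double learningRate;
    boolean useBias;

    void write(DataOutput out) throws IOException {
        out.writeUTF(targetVariable);
        out.writeInt(numFeatures);
        out.writeDouble(lambda);
        out.writeDouble(learningRate);
        out.writeBoolean(useBias);
    }

    void readFields(DataInput in) throws IOException {
        // read in exactly the order that write() emitted the fields
        targetVariable = in.readUTF();
        numFeatures = in.readInt();
        lambda = in.readDouble();
        learningRate = in.readDouble();
        useBias = in.readBoolean();
    }

    /** Serializes p to bytes and reads the bytes back into a fresh instance. */
    static ParamsSketch roundTrip(ParamsSketch p) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        DataOutputStream out = new DataOutputStream(bytes);
        p.write(out);
        out.flush();
        ParamsSketch copy = new ParamsSketch();
        copy.readFields(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
        return copy;
    }
}
```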

setTypeMap

public void setTypeMap(Iterable<String> predictorList,
                       List<String> typeList)
Sets the types of the predictors. These types are used later when reading CSV data. If you do not read CSV data and instead convert your input to vectors yourself, you do not need to call this method.

Parameters:
predictorList - The list of variable names.
typeList - The list of types in the format preferred by CsvRecordFactory.

setTargetVariable

public void setTargetVariable(String targetVariable)
Sets the name of the target variable. This setting is used only by the CSV record factory; if you do not use that factory, it is irrelevant.

Parameters:
targetVariable - The name of the target variable.

setMaxTargetCategories

public void setMaxTargetCategories(int maxTargetCategories)
Sets the number of target categories to be considered.

Parameters:
maxTargetCategories - The number of target categories.
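The category count fixes how many probabilities a multinomial model produces. Under the common parameterization assumed here (category 0 is a reference whose score is fixed at 0, so only K-1 score rows are learned), the hypothetical helper below turns K-1 linear scores into K probabilities. It illustrates the idea and is not taken from Mahout.

```java
// Hypothetical sketch: K target categories yield K probabilities computed
// from K-1 learned scores, with category 0 as a reference scored at 0.
final class CategorySketch {
    static double[] probabilities(double[] scores) {
        int k = scores.length + 1;
        double max = 0.0;                          // reference category score
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double[] p = new double[k];
        p[0] = Math.exp(0.0 - max);                // subtract max for numerical stability
        double sum = p[0];
        for (int i = 1; i < k; i++) {
            p[i] = Math.exp(scores[i - 1] - max);
            sum += p[i];
        }
        for (int i = 0; i < k; i++) {
            p[i] /= sum;                           // normalize so the result sums to 1
        }
        return p;
    }
}
```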

setNumFeatures

public void setNumFeatures(int numFeatures)
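Assuming numFeatures is the length of a hashed feature vector, the sketch below shows the hashing trick: each named feature is folded into one of numFeatures slots, and collisions simply add. String.hashCode is an illustrative choice, not the hash that Mahout's encoders use, and the class name is hypothetical.

```java
import java.util.Map;

// Hypothetical sketch of the hashing trick with a fixed vector length.
// The hash function is an assumption chosen for illustration only.
final class HashedFeatures {
    static double[] encode(Map<String, Double> features, int numFeatures) {
        if (numFeatures <= 0) {
            throw new IllegalArgumentException("numFeatures must be positive");
        }
        double[] v = new double[numFeatures];
        for (Map.Entry<String, Double> e : features.entrySet()) {
            // floorMod keeps the index non-negative even for negative hash codes
            int index = Math.floorMod(e.getKey().hashCode(), numFeatures);
            v[index] += e.getValue();              // colliding features add together
        }
        return v;
    }
}
```

A larger numFeatures reduces collisions at the cost of a wider model.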

setTargetCategories

public void setTargetCategories(List<String> targetCategories)

getTargetCategories

public List<String> getTargetCategories()

setUseBias

public void setUseBias(boolean useBias)

useBias

public boolean useBias()
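A bias flag usually means adding a constant feature so the model can learn an intercept. The hypothetical helper below prepends 1.0 to a feature array when useBias is true. Placing the constant in slot 0 and growing the array by one are assumptions for illustration; a fixed-size hashed vector would instead reserve an existing slot.

```java
// Hypothetical sketch of a bias (intercept) term; slot 0 is an assumption.
final class BiasSketch {
    static double[] withBias(double[] features, boolean useBias) {
        if (!useBias) {
            return features.clone();                           // unchanged copy
        }
        double[] v = new double[features.length + 1];
        v[0] = 1.0;                                            // constant intercept feature
        System.arraycopy(features, 0, v, 1, features.length);  // shift real features right
        return v;
    }
}
```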

getTargetVariable

public String getTargetVariable()

getTypeMap

public Map<String,String> getTypeMap()

setTypeMap

public void setTypeMap(Map<String,String> map)

getNumFeatures

public int getNumFeatures()

getMaxTargetCategories

public int getMaxTargetCategories()

getLambda

public double getLambda()

setLambda

public void setLambda(double lambda)

getLearningRate

public double getLearningRate()

setLearningRate

public void setLearningRate(double learningRate)


Copyright © 2008–2014 The Apache Software Foundation. All rights reserved.