|
||||||||||
PREV CLASS NEXT CLASS | FRAMES NO FRAMES | |||||||||
SUMMARY: NESTED | FIELD | CONSTR | METHOD | DETAIL: FIELD | CONSTR | METHOD |
java.lang.Object
org.apache.mahout.classifier.mlp.NeuralNetwork
public abstract class NeuralNetwork
NeuralNetwork defines the general operations for a neural-network-based model. Typically, all derivative models, such as Multilayer Perceptron and Autoencoder, consist of neurons and the weights between neurons.
Nested Class Summary | |
---|---|
static class |
NeuralNetwork.TrainingMethod
|
Field Summary | |
---|---|
protected String |
costFunctionName
|
protected int |
finalLayerIdx
|
protected List<Integer> |
layerSizeList
|
protected double |
learningRate
|
protected String |
modelPath
|
protected String |
modelType
|
protected double |
momentumWeight
|
protected List<Matrix> |
prevWeightUpdatesList
|
protected double |
regularizationWeight
|
protected List<String> |
squashingFunctionList
|
protected NeuralNetwork.TrainingMethod |
trainingMethod
|
protected List<Matrix> |
weightMatrixList
|
Constructor Summary | |
---|---|
NeuralNetwork()
The default constructor, which initializes the learning rate, regularization weight, and momentum weight to their default values. |
|
NeuralNetwork(double learningRate,
double momentumWeight,
double regularizationWeight)
Initialize the NeuralNetwork by specifying learning rate, momentum weight and regularization weight. |
|
NeuralNetwork(String modelPath)
Initialize the NeuralNetwork by specifying the location of the model. |
Method Summary | |
---|---|
int |
addLayer(int size,
boolean isFinalLayer,
String squashingFunctionName)
Add a layer of neurons with specified size. |
protected Vector |
forward(int fromLayer,
Vector intermediateOutput)
Forward the calculation for one layer. |
int |
getLayerSize(int layer)
Get the size of a particular layer. |
protected List<Integer> |
getLayerSizeList()
Get the layer size list. |
double |
getLearningRate()
Get the value of learning rate. |
String |
getModelPath()
Get the model path. |
String |
getModelType()
Get the type of the model. |
double |
getMomentumWeight()
Get the momentum weight. |
Vector |
getOutput(Vector instance)
Get the output calculated by the model. |
protected List<Vector> |
getOutputInternal(Vector instance)
Calculate output internally, the intermediate output of each layer will be stored. |
double |
getRegularizationWeight()
Get the weight of the regularization. |
NeuralNetwork.TrainingMethod |
getTrainingMethod()
Get the training method. |
Matrix[] |
getWeightMatrices()
Get all the weight matrices. |
Matrix |
getWeightsByLayer(int layerIdx)
Get the weights between layer layerIdx and layerIdx + 1 |
void |
readFields(DataInput input)
Read the fields of the model from input. |
protected void |
readFromModel()
Read the model meta-data from the specified location. |
NeuralNetwork |
setCostFunction(String costFunction)
Set the cost function for the model. |
NeuralNetwork |
setLearningRate(double learningRate)
Set the degree of aggression during model training. A large learning rate can increase the training speed, but it also decreases the chance that the model converges. |
void |
setModelPath(String modelPath)
Set the model path. |
NeuralNetwork |
setMomentumWeight(double momentumWeight)
Set the momentum weight for the model. |
NeuralNetwork |
setRegularizationWeight(double regularizationWeight)
Set the regularization weight. |
NeuralNetwork |
setTrainingMethod(NeuralNetwork.TrainingMethod method)
Set the training method. |
void |
setWeightMatrices(Matrix[] matrices)
Set the weight matrices. |
void |
setWeightMatrix(int index,
Matrix matrix)
Set the weight matrix for a specified layer. |
Matrix[] |
trainByInstance(Vector trainingInstance)
Get the updated weights using one training instance. |
void |
trainOnline(Vector trainingInstance)
Train the neural network incrementally with given training instance. |
void |
updateWeightMatrices(Matrix[] matrices)
Update the weight matrices with given matrices. |
void |
write(DataOutput output)
Write the fields of the model to output. |
void |
writeModelToFile()
Write the model data to specified location. |
Methods inherited from class java.lang.Object |
---|
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait |
Field Detail |
---|
protected String modelType
protected String modelPath
protected double learningRate
protected double regularizationWeight
protected double momentumWeight
protected String costFunctionName
protected List<Integer> layerSizeList
protected NeuralNetwork.TrainingMethod trainingMethod
protected List<Matrix> weightMatrixList
protected List<Matrix> prevWeightUpdatesList
protected List<String> squashingFunctionList
protected int finalLayerIdx
Constructor Detail |
---|
public NeuralNetwork()
public NeuralNetwork(double learningRate, double momentumWeight, double regularizationWeight)
learningRate
- The learning rate.
momentumWeight
- The momentum weight.
regularizationWeight
- The regularization weight.
public NeuralNetwork(String modelPath)
modelPath
- The location where the model is stored.
Method Detail |
---|
public String getModelType()
public NeuralNetwork setLearningRate(double learningRate)
learningRate
- Learning rate must be a non-negative value. The recommended range is (0, 0.5).
public double getLearningRate()
public NeuralNetwork setRegularizationWeight(double regularizationWeight)
regularizationWeight
- regularization must be in the range [0, 0.1).
public double getRegularizationWeight()
public NeuralNetwork setMomentumWeight(double momentumWeight)
momentumWeight
- momentumWeight must be in range [0, 0.5].
public double getMomentumWeight()
public NeuralNetwork setTrainingMethod(NeuralNetwork.TrainingMethod method)
method
- The training method, currently supports GRADIENT_DESCENT.
public NeuralNetwork.TrainingMethod getTrainingMethod()
public NeuralNetwork setCostFunction(String costFunction)
costFunction
- The name of the cost function. Currently supports "Minus_Squared" and "Cross_Entropy".
public int addLayer(int size, boolean isFinalLayer, String squashingFunctionName)
size
- The size of the layer (bias neuron excluded).
isFinalLayer
- If false, add a bias neuron.
squashingFunctionName
- The squashing function for this layer; for the input layer it is f(x) = x by default.
public int getLayerSize(int layer)
layer
- The index of the layer, starting from 0.
protected List<Integer> getLayerSizeList()
public Matrix getWeightsByLayer(int layerIdx)
layerIdx
- The index of the layer.
Matrix.
public void updateWeightMatrices(Matrix[] matrices)
matrices
- The weight matrices, which must have the same dimensions as the existing matrices.
public void setWeightMatrices(Matrix[] matrices)
matrices
- The weight matrices, which must have the same dimensions as the existing matrices.
public void setWeightMatrix(int index, Matrix matrix)
index
- The index of the matrix, starting from 0 (between layer 0 and 1).
matrix
- The instance of Matrix.
public Matrix[] getWeightMatrices()
public Vector getOutput(Vector instance)
instance
- The feature instance in the form of a Vector; each dimension contains the value of the corresponding feature.
protected List<Vector> getOutputInternal(Vector instance)
instance
- The feature instance in the form of a Vector; each dimension contains the value of the corresponding feature.
protected Vector forward(int fromLayer, Vector intermediateOutput)
fromLayer
- The index of the previous layer.
intermediateOutput
- The intermediate output of the previous layer.
public void trainOnline(Vector trainingInstance)
trainingInstance
- A training instance, including the features and the label(s). Its dimension must equal
the size of the input layer (bias neuron excluded) + the size
of the output layer (a.k.a. the dimension of the labels).
public Matrix[] trainByInstance(Vector trainingInstance)
trainingInstance
- A training instance, including the features and the label(s). Its dimension must equal
the size of the input layer (bias neuron excluded) + the size
of the output layer (a.k.a. the dimension of the labels).
Matrix list.
protected void readFromModel() throws IOException
IOException
public void writeModelToFile() throws IOException
IOException
public void setModelPath(String modelPath)
modelPath
- The path of the model.
public String getModelPath()
public void write(DataOutput output) throws IOException
output
- The output instance.
IOException
public void readFields(DataInput input) throws IOException
input
- The input instance.
IOException
|
||||||||||
PREV CLASS NEXT CLASS | FRAMES NO FRAMES | |||||||||
SUMMARY: NESTED | FIELD | CONSTR | METHOD | DETAIL: FIELD | CONSTR | METHOD |