Class ComputationGraph<T>

java.lang.Object
org.episteme.core.mathematics.ml.neural.ComputationGraph<T>
Type Parameters:
T - the data type of the tensor elements.

public class ComputationGraph<T> extends Object
Represents a neural network as a sequence of layers (a computation graph).

This class orchestrates the forward and backward passes across all layers.

Since:
2.0
Author:
Silvere Martin-Michiellot, Gemini AI (Google DeepMind)
  • Constructor Details

    • ComputationGraph

      public ComputationGraph()
      Creates an empty graph with no layers.
  • Method Details

    • add

      public ComputationGraph<T> add(Layer<T> layer)
      Appends a layer to the end of the network's sequence of layers.
      Parameters:
      layer - the layer to add.
      Returns:
      this graph for chaining.
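      Because add returns the graph itself, calls can be chained. The sketch below illustrates that fluent pattern in isolation; SketchLayer and ChainGraph are hypothetical stand-ins written for this example, not the library's Layer<T> or ComputationGraph<T>.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical minimal layer contract (not the library's Layer<T> API).
interface SketchLayer<T> {
    T forward(T input);
}

// Sketch of a fluent add(...): each call appends a layer and returns this.
class ChainGraph<T> {
    private final List<SketchLayer<T>> layers = new ArrayList<>();

    ChainGraph<T> add(SketchLayer<T> layer) {
        layers.add(layer); // layers keep insertion order
        return this;       // returning this enables chaining
    }

    int size() {
        return layers.size();
    }
}
```

      With the documented class the same pattern would read new ComputationGraph<T>().add(first).add(second), where first and second are placeholder Layer<T> instances.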
    • predict

      public Tensor<T> predict(Tensor<T> input)
      Performs inference (forward pass) on the input.
      Parameters:
      input - the input tensor.
      Returns:
      the output tensor (prediction).
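      Conceptually, a forward pass over a sequence of layers is a left-to-right fold: each layer's output becomes the next layer's input. The following is a minimal sketch of that idea under stated assumptions: double[] stands in for Tensor<T>, and ForwardGraph, affine and relu are hypothetical names chosen for illustration, not part of this API.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;

// Sketch of inference as a left-to-right fold over the layers.
// double[] stands in for Tensor<T>; the layers below are hypothetical examples.
class ForwardGraph {
    private final List<UnaryOperator<double[]>> layers = new ArrayList<>();

    ForwardGraph add(UnaryOperator<double[]> layer) {
        layers.add(layer);
        return this;
    }

    // Feeds the output of each layer into the next, first to last.
    double[] predict(double[] input) {
        double[] x = input;
        for (UnaryOperator<double[]> layer : layers) {
            x = layer.apply(x);
        }
        return x;
    }

    // Element-wise ReLU: max(0, v).
    static double[] relu(double[] x) {
        double[] y = new double[x.length];
        for (int i = 0; i < x.length; i++) y[i] = Math.max(0.0, x[i]);
        return y;
    }

    // Element-wise affine map: a * v + b.
    static UnaryOperator<double[]> affine(double a, double b) {
        return x -> {
            double[] y = new double[x.length];
            for (int i = 0; i < x.length; i++) y[i] = a * x[i] + b;
            return y;
        };
    }
}
```

      Because ReLU does not commute with the affine map, swapping the two add calls changes the result, which is why the order in which layers are added matters.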
    • trainStep

      public void trainStep(Tensor<T> input, Tensor<T> target, Optimizer<T> optimizer)
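      Performs one training step: a forward pass on the input, a backward pass that propagates the error between the prediction and the target back through the layers, and a weight update by the optimizer.
      Parameters:
      input - the input batch.
      target - the target values for the batch.
      optimizer - the optimizer used to update the layer weights.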
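      The three phases of a training step can be sketched on a toy model. This is an illustration, not the library's implementation: double[] replaces Tensor<T>, mean squared error is assumed as the loss, and plain gradient descent with a fixed learning rate stands in for the Optimizer<T> argument. ScalarLinear and TrainGraph are hypothetical names. Note that the backward pass walks the layers in reverse and that every gradient is computed before any weight changes.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical layer with one weight and one bias, applied element-wise to a batch.
class ScalarLinear {
    double w, b;                // parameters
    double gradW, gradB;        // gradients filled in by backward()
    private double[] lastInput; // cached by forward() for use in backward()

    ScalarLinear(double w, double b) { this.w = w; this.b = b; }

    double[] forward(double[] x) {
        lastInput = x.clone();
        double[] y = new double[x.length];
        for (int i = 0; i < x.length; i++) y[i] = w * x[i] + b;
        return y;
    }

    // Takes dL/dy, stores dL/dw and dL/db, returns dL/dx for the previous layer.
    double[] backward(double[] gradOut) {
        gradW = 0;
        gradB = 0;
        double[] gradIn = new double[gradOut.length];
        for (int i = 0; i < gradOut.length; i++) {
            gradW += gradOut[i] * lastInput[i];
            gradB += gradOut[i];
            gradIn[i] = gradOut[i] * w;
        }
        return gradIn;
    }
}

class TrainGraph {
    final List<ScalarLinear> layers = new ArrayList<>();

    double[] predict(double[] x) {
        for (ScalarLinear layer : layers) x = layer.forward(x);
        return x;
    }

    // Forward, MSE gradient, backward in reverse order, then a plain SGD update.
    // MSE and SGD are assumptions of this sketch. Returns the loss before the update.
    double trainStep(double[] input, double[] target, double learningRate) {
        double[] out = predict(input);
        int n = out.length;
        double loss = 0;
        double[] grad = new double[n];
        for (int i = 0; i < n; i++) {
            double diff = out[i] - target[i];
            loss += diff * diff / n;
            grad[i] = 2 * diff / n; // d(MSE)/d(out_i)
        }
        for (int i = layers.size() - 1; i >= 0; i--) {
            grad = layers.get(i).backward(grad);
        }
        for (ScalarLinear layer : layers) { // SGD: p -= lr * dL/dp
            layer.w -= learningRate * layer.gradW;
            layer.b -= learningRate * layer.gradB;
        }
        return loss;
    }
}
```

      Repeatedly calling trainStep on inputs {0, 1, 2, 3} with targets {1, 3, 5, 7} drives a single ScalarLinear layer starting at w = 0, b = 0 toward w = 2, b = 1.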
    • getLayers

      public List<Layer<T>> getLayers()
      Returns the layers of this graph, in the order they were added.
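      The documentation does not say whether the returned list can be modified. One common design, shown here purely as an assumption, is to return an unmodifiable view so that callers can inspect the layers but must go through add to change them. ReadOnlyGraph is a hypothetical class for illustration.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical design sketch: expose layers as a read-only, insertion-ordered view.
class ReadOnlyGraph<L> {
    private final List<L> layers = new ArrayList<>();

    ReadOnlyGraph<L> add(L layer) {
        layers.add(layer);
        return this;
    }

    // Callers can read the layers, but add/remove on the view throw
    // UnsupportedOperationException.
    List<L> getLayers() {
        return Collections.unmodifiableList(layers);
    }
}
```

      Collections.unmodifiableList returns a live view, so layers added later also appear in it; returning List.copyOf(layers) would give a fixed snapshot instead.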