Class GraphNode<T>

java.lang.Object
org.episteme.core.mathematics.ml.neural.autograd.GraphNode<T>
All Implemented Interfaces:
Serializable

public class GraphNode<T> extends Object implements Serializable
Represents a node in the computation graph for Automatic Differentiation.

Nodes wrap Tensors and track the operation history to allow for reverse-mode automatic differentiation (backpropagation).

Since:
2.0
Author:
Silvere Martin-Michiellot, Gemini AI (Google DeepMind)
  • Constructor Details

  • Method Details

    • backward

      public void backward()
      Triggers reverse-mode automatic differentiation. Computes the gradient of this node with respect to all its ancestors.
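The following is a minimal sketch of what reverse-mode differentiation does. It works on scalars instead of Tensor<T> so that it runs without the library. The class ScalarNode and its members are hypothetical and are not part of the GraphNode API. The sketch assumes the standard formulation of backpropagation: seed the output gradient with 1, visit nodes in reverse topological order, and accumulate (+=) each node's contribution into its inputs. Accumulation matters whenever a node is used more than once, as x is below.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/** Hypothetical scalar analogue of GraphNode, for illustration only. */
final class ScalarNode {
    double data;
    double grad = 0.0;
    private final List<ScalarNode> parents = new ArrayList<>();
    // Pushes this node's grad into its parents; a no-op for leaf nodes.
    private Runnable backwardFn = () -> { };

    ScalarNode(double data) { this.data = data; }

    ScalarNode add(ScalarNode other) {
        ScalarNode out = new ScalarNode(this.data + other.data);
        out.parents.add(this);
        out.parents.add(other);
        out.backwardFn = () -> {
            this.grad += out.grad;   // d(a+b)/da = 1
            other.grad += out.grad;  // d(a+b)/db = 1
        };
        return out;
    }

    ScalarNode multiply(ScalarNode other) {
        ScalarNode out = new ScalarNode(this.data * other.data);
        out.parents.add(this);
        out.parents.add(other);
        out.backwardFn = () -> {
            this.grad += other.data * out.grad;  // d(a*b)/da = b
            other.grad += this.data * out.grad;  // d(a*b)/db = a
        };
        return out;
    }

    void backward() {
        // Order the graph so that every node appears after all of its inputs.
        List<ScalarNode> order = new ArrayList<>();
        topoSort(this, new HashSet<>(), order);
        grad = 1.0; // seed: d(this)/d(this) = 1
        // Walk from the output back to the leaves, applying the chain rule.
        for (int i = order.size() - 1; i >= 0; i--) {
            order.get(i).backwardFn.run();
        }
    }

    private static void topoSort(ScalarNode n, Set<ScalarNode> seen, List<ScalarNode> order) {
        if (!seen.add(n)) return;
        for (ScalarNode p : n.parents) topoSort(p, seen, order);
        order.add(n);
    }

    public static void main(String[] args) {
        ScalarNode x = new ScalarNode(3.0);
        ScalarNode y = new ScalarNode(4.0);
        ScalarNode z = x.multiply(y).add(x); // z = x*y + x
        z.backward();
        System.out.println(x.grad + " " + y.grad); // 5.0 3.0 (dz/dx = y + 1, dz/dy = x)
    }
}
```

Because x feeds both the product and the sum, its gradient collects two contributions (4.0 from the product and 1.0 from the sum). This is why gradients are accumulated rather than overwritten.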
    • getData

      public Tensor<T> getData()
    • getGrad

      public Tensor<T> getGrad()
    • setGrad

      public void setGrad(Tensor<T> grad)
    • requiresGrad

      public boolean requiresGrad()
    • add

      public GraphNode<T> add(GraphNode<T> other)
      Element-wise addition.
    • multiply

      public GraphNode<T> multiply(GraphNode<T> other)
      Element-wise multiplication (Hadamard product).
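For an element-wise product c = a * b, the backward rule applies position by position. The gradient reaching a is upstream * b, and the gradient reaching b is upstream * a. The sketch below uses plain double[] arrays as a stand-in for Tensor<T>, and the class and method names are hypothetical. It also checks one entry against a central finite difference, which is a common way to validate a hand-written backward rule.

```java
/** Hypothetical helpers illustrating the element-wise product backward rule. */
final class ElementwiseGrad {
    // c[i] = a[i] * b[i]
    static double[] hadamard(double[] a, double[] b) {
        double[] c = new double[a.length];
        for (int i = 0; i < a.length; i++) c[i] = a[i] * b[i];
        return c;
    }

    // Gradient w.r.t. a of c = a * b (element-wise), given dL/dc: dL/da = dL/dc * b.
    static double[] hadamardGradA(double[] upstream, double[] b) {
        return hadamard(upstream, b);
    }

    // Loss L = sum(a * b), used only for the numeric check.
    static double loss(double[] a, double[] b) {
        double s = 0.0;
        for (double v : hadamard(a, b)) s += v;
        return s;
    }

    public static void main(String[] args) {
        double[] a = {1.0, -2.0, 0.5};
        double[] b = {3.0, 4.0, -1.0};
        double[] ones = {1.0, 1.0, 1.0};             // dL/dc for L = sum(c)
        double[] analytic = hadamardGradA(ones, b);  // equals b here

        // Central difference on a[1].
        double h = 1e-6;
        double[] plus = a.clone();
        plus[1] += h;
        double[] minus = a.clone();
        minus[1] -= h;
        double numeric = (loss(plus, b) - loss(minus, b)) / (2 * h);
        System.out.println(analytic[1] + " vs " + numeric); // both approximately 4.0
    }
}
```

Element-wise addition follows the same pattern with a simpler rule: the upstream gradient passes to both operands unchanged.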
    • matmul

      public GraphNode<T> matmul(GraphNode<T> other)
      Matrix multiplication (supports multidimensional tensors via einsum).
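In the two-dimensional case C = A · B, the standard reverse-mode rules are dL/dA = dL/dC · B^T and dL/dB = A^T · dL/dC. The sketch below shows these rules on double[][] arrays. It does not model the einsum path or batched dimensions mentioned above, and the class MatmulGrad is hypothetical.

```java
import java.util.Arrays;

/** Hypothetical 2-D illustration of the matmul backward rules. */
final class MatmulGrad {
    static double[][] matmul(double[][] a, double[][] b) {
        int n = a.length, k = b.length, m = b[0].length;
        double[][] c = new double[n][m];
        for (int i = 0; i < n; i++) {
            for (int p = 0; p < k; p++) {
                for (int j = 0; j < m; j++) {
                    c[i][j] += a[i][p] * b[p][j];
                }
            }
        }
        return c;
    }

    static double[][] transpose(double[][] x) {
        double[][] t = new double[x[0].length][x.length];
        for (int i = 0; i < x.length; i++) {
            for (int j = 0; j < x[0].length; j++) {
                t[j][i] = x[i][j];
            }
        }
        return t;
    }

    // Given dL/dC, return dL/dA = dL/dC * B^T.
    static double[][] gradA(double[][] dC, double[][] b) {
        return matmul(dC, transpose(b));
    }

    // Given dL/dC, return dL/dB = A^T * dL/dC.
    static double[][] gradB(double[][] a, double[][] dC) {
        return matmul(transpose(a), dC);
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2}, {3, 4}};         // 2x2
        double[][] b = {{5, 6, 7}, {8, 9, 10}};  // 2x3
        double[][] dC = {{1, 1, 1}, {1, 1, 1}};  // upstream grad for L = sum(C)
        System.out.println(Arrays.deepToString(gradA(dC, b))); // [[18.0, 27.0], [18.0, 27.0]]
        System.out.println(Arrays.deepToString(gradB(a, dC))); // [[4.0, 4.0, 4.0], [6.0, 6.0, 6.0]]
    }
}
```

With an all-ones upstream gradient, each row of dL/dA holds the row sums of B, and each row of dL/dB holds the corresponding column sum of A. This gives a quick sanity check on the shapes and transposes.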
    • negate

      public GraphNode<T> negate()
      Element-wise negation.
    • subtract

      public GraphNode<T> subtract(GraphNode<T> other)
      Element-wise subtraction.
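A common way to define subtraction in an autograd system is as the composition a - b = a + negate(b). Under that view, the backward pass of subtract follows from the rules for add and negate. The upstream gradient reaches a unchanged and reaches b with its sign flipped. This page does not say whether GraphNode.subtract is built this way. The sketch below, with the hypothetical class SubtractGrad on plain arrays, only shows the resulting gradients.

```java
import java.util.Arrays;

/** Hypothetical sketch: backward of c = a - b, derived from c = a + negate(b). */
final class SubtractGrad {
    // negate backward: d(-x)/dx = -1, so the upstream gradient is sign-flipped.
    static double[] negateBackward(double[] upstream) {
        double[] g = new double[upstream.length];
        for (int i = 0; i < g.length; i++) g[i] = -upstream[i];
        return g;
    }

    // Returns {dL/da, dL/db} for c = a - b, given dL/dc.
    static double[][] subtractBackward(double[] upstream) {
        double[] gradA = upstream.clone();          // add passes the gradient through
        double[] gradB = negateBackward(upstream);  // negate then flips its sign
        return new double[][] {gradA, gradB};
    }

    public static void main(String[] args) {
        double[][] g = subtractBackward(new double[] {0.5, -2.0});
        System.out.println(Arrays.toString(g[0])); // [0.5, -2.0]
        System.out.println(Arrays.toString(g[1])); // [-0.5, 2.0]
    }
}
```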
    • relu

      public GraphNode<T> relu()
      Element-wise Rectified Linear Unit (ReLU) activation, max(0, x).
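ReLU computes max(0, x) element-wise. Its local derivative is 1 where x > 0 and 0 where x < 0. At x = 0 the function is not differentiable, and many frameworks pass a gradient of 0 there. The sketch below assumes that convention, which may differ from what GraphNode does. The class ReluGrad is hypothetical and uses plain arrays.

```java
import java.util.Arrays;

/** Hypothetical sketch of ReLU forward and backward on flat arrays. */
final class ReluGrad {
    static double[] forward(double[] x) {
        double[] y = new double[x.length];
        for (int i = 0; i < x.length; i++) y[i] = Math.max(0.0, x[i]);
        return y;
    }

    // dL/dx = dL/dy where x > 0, else 0 (assumed subgradient 0 at x == 0).
    static double[] backward(double[] x, double[] upstream) {
        double[] g = new double[x.length];
        for (int i = 0; i < x.length; i++) g[i] = x[i] > 0 ? upstream[i] : 0.0;
        return g;
    }

    public static void main(String[] args) {
        double[] x = {-1.5, 0.0, 2.0};
        System.out.println(Arrays.toString(forward(x)));                          // [0.0, 0.0, 2.0]
        System.out.println(Arrays.toString(backward(x, new double[] {1, 1, 1}))); // [0.0, 0.0, 1.0]
    }
}
```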
    • sigmoid

      public GraphNode<T> sigmoid()
      Element-wise sigmoid activation, 1 / (1 + e^(-x)).
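The sigmoid sigma(x) = 1 / (1 + e^(-x)) has the derivative sigma'(x) = sigma(x) * (1 - sigma(x)). A backward pass can therefore reuse the value computed in the forward pass instead of evaluating the exponential again. The sketch below shows this with the hypothetical class SigmoidGrad on scalars. Whether GraphNode caches the forward output this way is not stated on this page.

```java
/** Hypothetical sketch: sigmoid backward reusing the forward output. */
final class SigmoidGrad {
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // dL/dx = dL/dy * y * (1 - y), where y = sigmoid(x) was saved in the forward pass.
    static double backward(double y, double upstream) {
        return upstream * y * (1.0 - y);
    }

    public static void main(String[] args) {
        double y = sigmoid(0.0);              // 0.5
        System.out.println(backward(y, 1.0)); // 0.25
    }
}
```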
    • log

      public GraphNode<T> log()
      Element-wise natural logarithm.
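The derivative of ln(x) is 1/x, and it is defined only for x > 0. The backward pass therefore divides the upstream gradient by the forward input. The sketch below uses the hypothetical class LogGrad on plain arrays. This page does not document how GraphNode handles zero or negative inputs. In plain Java, Math.log returns -Infinity for 0 and NaN for negative values.

```java
import java.util.Arrays;

/** Hypothetical sketch of natural-log forward and backward on flat arrays. */
final class LogGrad {
    static double[] forward(double[] x) {
        double[] y = new double[x.length];
        // Math.log gives NaN for x < 0 and -Infinity for x == 0.
        for (int i = 0; i < x.length; i++) y[i] = Math.log(x[i]);
        return y;
    }

    // dL/dx = dL/dy / x
    static double[] backward(double[] x, double[] upstream) {
        double[] g = new double[x.length];
        for (int i = 0; i < x.length; i++) g[i] = upstream[i] / x[i];
        return g;
    }

    public static void main(String[] args) {
        double[] x = {1.0, 2.0, 4.0};
        System.out.println(Arrays.toString(backward(x, new double[] {1, 1, 1}))); // [1.0, 0.5, 0.25]
    }
}
```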
    • sum

      public GraphNode<T> sum()
      Sum of all elements in the tensor.
    • scale

      public GraphNode<T> scale(double factor)
      Multiplies every element by the scalar factor.
    • mean

      public GraphNode<T> mean()
      Arithmetic mean of all elements.
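The sum, scale and mean operations have simple backward rules. For sum, the scalar upstream gradient is copied to every element. For scale, the gradient is multiplied by the same factor. The mean of n elements equals the sum scaled by 1/n, so each element receives upstream / n. The sketch below derives the mean rule by chaining the other two. It uses the hypothetical class ReductionGrad. Whether GraphNode.mean is actually implemented through sum and scale is an assumption.

```java
import java.util.Arrays;

/** Hypothetical sketch of backward rules for sum, scale and mean. */
final class ReductionGrad {
    // sum: every element contributes with coefficient 1, so each gets the upstream value.
    static double[] sumBackward(int n, double upstream) {
        double[] g = new double[n];
        Arrays.fill(g, upstream);
        return g;
    }

    // scale: y = factor * x, so dL/dx = factor * dL/dy.
    static double[] scaleBackward(double[] upstream, double factor) {
        double[] g = new double[upstream.length];
        for (int i = 0; i < g.length; i++) g[i] = factor * upstream[i];
        return g;
    }

    // mean(x) = scale(sum(x), 1/n): apply scale backward, then sum backward.
    static double[] meanBackward(int n, double upstream) {
        double afterScale = scaleBackward(new double[] {upstream}, 1.0 / n)[0];
        return sumBackward(n, afterScale);
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(meanBackward(4, 1.0))); // [0.25, 0.25, 0.25, 0.25]
    }
}
```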
    • setData

      public void setData(Tensor<T> data)
    • broadcast

      public GraphNode<T> broadcast(int... shape)
      Broadcasts the node to a new shape.
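Broadcasting copies values along new or size-1 axes. In the backward pass, the gradients flowing into all the copies must be summed back into the original shape. The sketch below covers one simple case: a length-m vector repeated as each row of an (rows x m) matrix, whose backward pass sums the upstream gradient over the rows. The class BroadcastGrad is hypothetical. This page does not state which broadcasting rules GraphNode.broadcast follows.

```java
import java.util.Arrays;

/** Hypothetical sketch: broadcasting a row vector to (rows x cols) and its backward. */
final class BroadcastGrad {
    // Forward: repeat the vector v as each of `rows` rows.
    static double[][] forward(double[] v, int rows) {
        double[][] out = new double[rows][];
        for (int i = 0; i < rows; i++) out[i] = v.clone();
        return out;
    }

    // Backward: every output row received a copy of v, so sum the upstream rows.
    static double[] backward(double[][] upstream) {
        double[] g = new double[upstream[0].length];
        for (double[] row : upstream) {
            for (int j = 0; j < g.length; j++) g[j] += row[j];
        }
        return g;
    }

    public static void main(String[] args) {
        double[][] y = forward(new double[] {1.0, 2.0}, 3);
        System.out.println(Arrays.deepToString(y)); // [[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]]
        double[][] up = {{1, 0}, {2, 1}, {3, 1}};
        System.out.println(Arrays.toString(backward(up))); // [6.0, 2.0]
    }
}
```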
    • toString

      public String toString()
      Overrides:
      toString in class Object