Plot 1D functions

show_TensorFunction1D

show_TensorFunction1D(x, y, y_hat=None, return_fig=False, s=None, c=None, marker=None, cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, verts=<deprecated parameter>, edgecolors=None, plotnonfinite=False, data=None)

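Displays the 1D function y(x), optionally together with predictions y_hat evaluated at the same x. The remaining keyword arguments (s, c, marker, cmap, ...) mirror those of matplotlib's scatter.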

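import torch
# show_TensorFunction1D must already be imported from this library

# Ten samples of y = 2x + 3 as (10, 1) column tensors, plus noisy predictions
x = torch.randn(10, 1)
y = 2*x + 3
y_hat = 2*x + 3 + 0.1*torch.randn(10, 1)

# Plot y(x) alone, then together with y_hat
show_TensorFunction1D(x, y, marker='.')
show_TensorFunction1D(x, y, y_hat=y_hat, marker='.')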
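The following is a minimal sketch of how a function like this might work. It assumes the plot is drawn with matplotlib's scatter, since the extra keyword arguments in the signature match scatter's. The name show_function_1d_sketch is hypothetical, and the code is not the library's actual implementation.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def _to_numpy(t):
    # Torch tensors are detached from the autograd graph and moved to CPU;
    # plain arrays pass straight through. Result is flattened to 1-D.
    if hasattr(t, "detach"):
        t = t.detach().cpu()
    return np.asarray(t).reshape(-1)


def show_function_1d_sketch(x, y, y_hat=None, return_fig=False, **scatter_kwargs):
    """Hypothetical stand-in: scatter y against x, optionally adding y_hat."""
    x, y = _to_numpy(x), _to_numpy(y)
    fig, ax = plt.subplots()
    # Extra keyword arguments (marker, s, c, ...) are forwarded to scatter
    ax.scatter(x, y, label="y", **scatter_kwargs)
    if y_hat is not None:
        ax.scatter(x, _to_numpy(y_hat), label="y_hat", **scatter_kwargs)
        ax.legend()
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    if return_fig:
        return fig
    plt.show()
    plt.close(fig)
```

Flattening with reshape(-1) lets the sketch accept the (10, 1) column tensors used in the examples above. The return_fig=True path hands back the figure instead of showing it, which is convenient for saving or testing.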