def boundary_surface(parameters, x_min, x_max, y_min, y_max, steps=120):
    """Evaluate the model over a regular 2-D grid for contour plotting.

    A ``steps`` x ``steps`` mesh is laid over the rectangle
    [x_min, x_max] x [y_min, y_max]. Every mesh point is passed through
    ``forward_propagation`` in a single call, laid out as a 2 x steps**2
    matrix with one point per column. The first value it returns is then
    folded back onto the mesh shape.

    Args:
        parameters: Model parameters, forwarded unchanged to
            ``forward_propagation``.
        x_min, x_max: Horizontal extent of the grid.
        y_min, y_max: Vertical extent of the grid.
        steps: Number of samples along each axis.

    Returns:
        Tuple ``(xx, yy, zz)`` of arrays shaped ``(steps, steps)``:
        the mesh x-coordinates, the mesh y-coordinates, and the model output
        at each mesh point. The model output must contain exactly
        ``steps**2`` values, otherwise the reshape fails.
    """
    x_values = np.linspace(x_min, x_max, steps)
    y_values = np.linspace(y_min, y_max, steps)
    xx, yy = np.meshgrid(x_values, y_values)

    # (steps**2, 2) matrix of points, transposed so each column is one (x, y) point.
    points = np.column_stack((xx.ravel(), yy.ravel())).T

    output, _ = forward_propagation(points, parameters)
    return xx, yy, output.reshape(xx.shape)
# Plot extent: bounding box of the first two feature columns, padded by 1.0 on every side.
x_min, y_min = X_train_rows[:, :2].min(axis=0) - 1.0
x_max, y_max = X_train_rows[:, :2].max(axis=0) + 1.0

# Surface for the first snapshot; its mesh also supplies the contour axes.
xx, yy, zz0 = boundary_surface(
    snapshots[0]["parameters"],
    x_min=x_min,
    x_max=x_max,
    y_min=y_min,
    y_max=y_max,
)
# meshgrid repeats the x values along each row and the y values down each column.
x_axis, y_axis = xx[0, :], yy[:, 0]
# Styling shared by the initial contour trace and every animation frame.
# Defining it once keeps the copies from drifting apart: the frame traces
# previously repeated every setting except the colorbar title.
contour_style = dict(
    colorscale="RdBu",
    opacity=0.5,
    showscale=True,
    # Fixed colour range, so a given colour means the same value in every frame.
    zmin=0,
    zmax=1,
    colorbar=dict(title="P(class = 1)"),
)
contour = go.Contour(x=x_axis, y=y_axis, z=zz0, **contour_style)

# Training points drawn over the surface, coloured by their label.
# NOTE(review): assumes y_train is 1-D with one label per row of
# X_train_rows — confirm. A (1, m) array would put whole arrays into the
# hover text.
scatter = go.Scatter(
    x=X_train_rows[:, 0],
    y=X_train_rows[:, 1],
    mode="markers",
    marker=dict(
        color=y_train,
        colorscale="Viridis",
        size=8,
        line=dict(color="white", width=1),
    ),
    text=[f"class={label}" for label in y_train],
    hovertemplate="x1=%{x:.2f}<br>x2=%{y:.2f}<br>%{text}<extra></extra>",
    showlegend=False,
)

# One frame per training snapshot, named by its iteration so that the slider
# steps (which pass str(snapshot["iteration"])) can select it.
# Each frame carries a single contour trace; because the frame has no `traces`
# list, Plotly matches its data to the figure traces by position.
frames = []
for snapshot in snapshots:
    _, _, zz = boundary_surface(snapshot["parameters"], x_min, x_max, y_min, y_max)
    frames.append(
        go.Frame(
            name=str(snapshot["iteration"]),
            data=[go.Contour(x=x_axis, y=y_axis, z=zz, **contour_style)],
        )
    )
# One slider step per snapshot. Each step jumps straight to the frame with the
# same name (the iteration number) without any animated transition.
slider_steps = [
    {
        "args": [[frame_name], {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"}],
        "label": frame_name,
        "method": "animate",
    }
    for frame_name in (str(snap["iteration"]) for snap in snapshots)
]
# Assemble the animated figure: contour surface underneath, training points on top.
boundary_fig = go.Figure(data=[contour, scatter], frames=frames)
boundary_fig.update_layout(
    title="How the Decision Boundary Changes During Training",
    template="plotly_white",
    height=650,
    xaxis_title="x1 (standardized)",
    yaxis_title="x2 (standardized)",
    # Play steps through every frame, 300 ms apart, starting from the current frame.
    # Pause issues an immediate zero-length animate to [None], which halts playback.
    updatemenus=[
        dict(
            type="buttons",
            buttons=[
                dict(
                    label="Play",
                    method="animate",
                    args=[None, dict(frame=dict(duration=300, redraw=True), fromcurrent=True)],
                ),
                dict(
                    label="Pause",
                    method="animate",
                    args=[[None], dict(frame=dict(duration=0, redraw=False), mode="immediate")],
                ),
            ],
            x=0.02,
            y=1.08,
        )
    ],
    # Manual scrubbing through the snapshots; the current iteration is shown as a label.
    sliders=[
        dict(
            currentvalue=dict(prefix="iteration="),
            pad=dict(t=50),
            steps=slider_steps,
        )
    ],
)
boundary_fig.show()