Skip to content

Commit

Permalink
wip jensen demo add axis projection #6
Browse files Browse the repository at this point in the history
  • Loading branch information
shensquared committed Jun 29, 2023
1 parent 29c4594 commit 5f66baf
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 17 deletions.
6 changes: 3 additions & 3 deletions demos/Jensen.html

Large diffs are not rendered by default.

75 changes: 61 additions & 14 deletions demos/scripts/Jensen.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,45 @@
import numpy as np
import plotly.graph_objects as go

n = 200
# Generate random data
f = lambda alpha, x: alpha * x * x + x + 3 * x

# Create NumPy arrays
alphas = np.linspace(-3, 3, 9)
x = np.random.randn(200) + 1
x = np.random.randn(n) + 1
xbar = np.average(x)
f_of_x_bar = f(alpha, xbar)
y = {}
ybar = {}
for alpha in alphas:
# mean_of_x = np.average(x)
y[str(alpha)] = f(alpha, x)
this_y = f(alpha, x)
y[str(alpha)] = this_y
ybar[str(alpha)] = np.average(this_y)
# Create figure and scatter plot
fig = go.Figure(
data=go.Scatter(
x=x,
y=y[str(alphas[0])],
mode="markers",
marker=dict(size=10, colorscale="Viridis", showscale=True),
hovertemplate="x: %{x}<br>y: %{y}<br>scalar: %{marker.color}",
),
data=[
go.Scatter(
x=x,
y=y[str(alphas[0])],
mode="markers",
marker=dict(size=10, colorscale="Viridis", showscale=False),
hovertemplate="x: %{x}<br>y: %{y}",
),
go.Scatter(
x=x,
y=np.zeros(n),
mode="markers",
xaxis="x",
yaxis="y2",
),
go.Scatter(
x=np.zeros(n),
y=y[str(alphas[0])],
xaxis="x2",
yaxis="y",
),
],
layout=go.Layout(
# title="Jensen's Inequality",
# xaxis=dict(range=[0, 1]),
Expand Down Expand Up @@ -55,7 +75,7 @@
sliders=[
dict(
active=0,
currentvalue={"prefix": "Alpha: "},
currentvalue={"prefix": r"$\alpha$: "},
pad={"t": 50},
steps=[
dict(
Expand Down Expand Up @@ -85,16 +105,43 @@
marker=dict(
size=10,
),
)
),
go.Scatter(
x=x,
y=np.zeros(n),
mode="markers",
name="Subplot 1",
xaxis="x",
yaxis="y2",
),
go.Scatter(
x=np.zeros(n),
y=y[str(i)],
mode="markers",
xaxis="x2",
yaxis="y",
),
],
)
for i in alphas
],
)

# Show the figure
fig.update_layout(title_text="Jensen's Inequality")
fig.show()
fig.update_layout(
title="Jensen's Inequality",
xaxis=dict(title="Sampled x", domain=[0, 0.95]),
yaxis=dict(title=r"$f(x)$", domain=[0.1, 1], tickformat="$%s$"),
xaxis2=dict(
title=None, domain=[1 - 0.03, 1], range=(-0.1, 0.1), showticklabels=False
),
yaxis2=dict(title=None, domain=[0, 0.03], range=(-0.1, 0.1), showticklabels=False),
showlegend=False,
)


# Create the figure
# fig.show()

import os

Expand Down

0 comments on commit 5f66baf

Please sign in to comment.