Skip to content

Commit

Permalink
Added rcount/ccount to plot_surface(), providing an alternative to rs…
Browse files Browse the repository at this point in the history
…tride/cstride
  • Loading branch information
WeatherGod committed Dec 13, 2016
1 parent 984e9b0 commit 58d3de2
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 17 deletions.
15 changes: 15 additions & 0 deletions doc/users/whats_new/plot_surface.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
`rcount` and `ccount` for `plot_surface()`
------------------------------------------

As of v2.0, mplot3d's :func:`~mpl_toolkits.mplot3d.axes3d.plot_surface` now
accepts `rcount` and `ccount` arguments for controlling the sampling of the
input data for plotting. These arguments specify the maximum number of
evenly spaced samples to take from the input data. These arguments are
also the new default sampling method for the function, and is
considered a style change.

The old `rstride` and `cstride` arguments, which specified the size of the
evenly spaced samples, become the default when 'classic' mode is invoked,
and are still available for use. There are no plans for deprecating these
arguments.

2 changes: 1 addition & 1 deletion examples/mplot3d/surface3d_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
Z = np.sin(R)

# Plot the surface.
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.coolwarm,
surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm,
linewidth=0, antialiased=False)

# Customize the z axis.
Expand Down
2 changes: 1 addition & 1 deletion examples/mplot3d/surface3d_demo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@
z = 10 * np.outer(np.ones(np.size(u)), np.cos(v))

# Plot the surface
ax.plot_surface(x, y, z, rstride=4, cstride=4, color='b')
ax.plot_surface(x, y, z, color='b')

plt.show()
3 changes: 1 addition & 2 deletions examples/mplot3d/surface3d_demo3.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@
colors[x, y] = colortuple[(x + y) % len(colortuple)]

# Plot the surface with face colors taken from the array we made.
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=colors,
linewidth=0)
surf = ax.plot_surface(X, Y, Z, facecolors=colors, linewidth=0)

# Customize the z axis.
ax.set_zlim(-1, 1)
Expand Down
2 changes: 1 addition & 1 deletion examples/mplot3d/surface3d_radial_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
X, Y = R*np.cos(P), R*np.sin(P)

# Plot the surface.
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=plt.cm.YlGnBu_r)
ax.plot_surface(X, Y, Z, cmap=plt.cm.YlGnBu_r)

# Tweak the limits and add latex math labels.
ax.set_zlim(0, 1)
Expand Down
89 changes: 81 additions & 8 deletions lib/mpl_toolkits/mplot3d/axes3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1553,15 +1553,28 @@ def plot_surface(self, X, Y, Z, *args, **kwargs):
The `rstride` and `cstride` kwargs set the stride used to
sample the input data to generate the graph. If 1k by 1k
arrays are passed in the default values for the strides will
result in a 100x100 grid being plotted.
arrays are passed in, the default values for the strides will
result in a 100x100 grid being plotted. Defaults to 10.
Raises a ValueError if both stride and count kwargs
(see next section) are provided.
The `rcount` and `ccount` kwargs supersedes `rstride` and
`cstride` for default sampling method for surface plotting.
These arguments will determine at most how many evenly spaced
samples will be taken from the input data to generate the graph.
This is the default sampling method unless using the 'classic'
style. Will raise ValueError if both stride and count are
specified.
Added in v2.0.0.
============= ================================================
Argument Description
============= ================================================
*X*, *Y*, *Z* Data values as 2D arrays
*rstride* Array row stride (step size), defaults to 10
*cstride* Array column stride (step size), defaults to 10
*rstride* Array row stride (step size)
*cstride* Array column stride (step size)
*rcount* Use at most this many rows, defaults to 50
*ccount* Use at most this many columns, defaults to 50
*color* Color of the surface patches
*cmap* A colormap for the surface patches.
*facecolors* Face colors for the individual patches
Expand All @@ -1582,8 +1595,30 @@ def plot_surface(self, X, Y, Z, *args, **kwargs):
X, Y, Z = np.broadcast_arrays(X, Y, Z)
rows, cols = Z.shape

has_stride = 'rstride' in kwargs or 'cstride' in kwargs
has_count = 'rcount' in kwargs or 'ccount' in kwargs

if has_stride and has_count:
raise ValueError("Cannot specify both stride and count arguments")

rstride = kwargs.pop('rstride', 10)
cstride = kwargs.pop('cstride', 10)
rcount = kwargs.pop('rcount', 50)
ccount = kwargs.pop('ccount', 50)

if rcParams['_internal.classic_mode']:
# Strides have priority over counts in classic mode.
# So, only compute strides from counts
# if counts were explicitly given
if has_count:
rstride = int(np.ceil(rows / rcount))
cstride = int(np.ceil(cols / ccount))
else:
# If the strides are provided then it has priority.
# Otherwise, compute the strides from the counts.
if not has_stride:
rstride = int(np.ceil(rows / rcount))
cstride = int(np.ceil(cols / ccount))

if 'facecolors' in kwargs:
fcolors = kwargs.pop('facecolors')
Expand Down Expand Up @@ -1733,7 +1768,21 @@ def plot_wireframe(self, X, Y, Z, *args, **kwargs):
The `rstride` and `cstride` kwargs set the stride used to
sample the input data to generate the graph. If either is 0
the input data in not sampled along this direction producing a
3D line plot rather than a wireframe plot.
3D line plot rather than a wireframe plot. The stride arguments
are only used by default if in the 'classic' mode. They are
now superseded by `rcount` and `ccount`. Will raise ValueError
if both stride and count are used.
` The `rcount` and `ccount` kwargs supersedes `rstride` and
`cstride` for default sampling method for wireframe plotting.
These arguments will determine at most how many evenly spaced
samples will be taken from the input data to generate the graph.
This is the default sampling method unless using the 'classic'
style. Will raise ValueError if both stride and count are
specified. If either is zero, then the input data is not sampled
along this direction, producing a 3D line plot rather than a
wireframe plot.
Added in v2.0.0.
========== ================================================
Argument Description
Expand All @@ -1742,6 +1791,8 @@ def plot_wireframe(self, X, Y, Z, *args, **kwargs):
*Z*
*rstride* Array row stride (step size), defaults to 1
*cstride* Array column stride (step size), defaults to 1
*rcount* Use at most this many rows, defaults to 50
*ccount* Use at most this many columns, defaults to 50
========== ================================================
Keyword arguments are passed on to
Expand All @@ -1750,15 +1801,37 @@ def plot_wireframe(self, X, Y, Z, *args, **kwargs):
Returns a :class:`~mpl_toolkits.mplot3d.art3d.Line3DCollection`
'''

rstride = kwargs.pop("rstride", 1)
cstride = kwargs.pop("cstride", 1)

had_data = self.has_data()
Z = np.atleast_2d(Z)
# FIXME: Support masked arrays
X, Y, Z = np.broadcast_arrays(X, Y, Z)
rows, cols = Z.shape

has_stride = 'rstride' in kwargs or 'cstride' in kwargs
has_count = 'rcount' in kwargs or 'ccount' in kwargs

if has_stride and has_count:
raise ValueError("Cannot specify both stride and count arguments")

rstride = kwargs.pop('rstride', 1)
cstride = kwargs.pop('cstride', 1)
rcount = kwargs.pop('rcount', 50)
ccount = kwargs.pop('ccount', 50)

if rcParams['_internal.classic_mode']:
# Strides have priority over counts in classic mode.
# So, only compute strides from counts
# if counts were explicitly given
if has_count:
rstride = int(np.ceil(rows / rcount)) if rcount else 0
cstride = int(np.ceil(cols / ccount)) if ccount else 0
else:
# If the strides are provided then it has priority.
# Otherwise, compute the strides from the counts.
if not has_stride:
rstride = int(np.ceil(rows / rcount)) if rcount else 0
cstride = int(np.ceil(cols / ccount)) if ccount else 0

# We want two sets of lines, one running along the "rows" of
# Z and another set of lines running along the "columns" of Z.
# This transpose will make it easy to obtain the columns.
Expand Down
21 changes: 17 additions & 4 deletions lib/mpl_toolkits/tests/test_mplot3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def f(t):
R = np.sqrt(X ** 2 + Y ** 2)
Z = np.sin(R)

surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
surf = ax.plot_surface(X, Y, Z, rcount=40, ccount=40,
linewidth=0, antialiased=False)

ax.set_zlim3d(-1, 1)
Expand Down Expand Up @@ -141,7 +141,7 @@ def test_surface3d():
X, Y = np.meshgrid(X, Y)
R = np.sqrt(X ** 2 + Y ** 2)
Z = np.sin(R)
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.coolwarm,
surf = ax.plot_surface(X, Y, Z, rcount=40, ccount=40, cmap=cm.coolwarm,
lw=0, antialiased=False)
ax.set_zlim(-1.01, 1.01)
fig.colorbar(surf, shrink=0.5, aspect=5)
Expand Down Expand Up @@ -194,7 +194,7 @@ def test_wireframe3d():
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
X, Y, Z = axes3d.get_test_data(0.05)
ax.plot_wireframe(X, Y, Z, rstride=10, cstride=10)
ax.plot_wireframe(X, Y, Z, rcount=13, ccount=13)


@image_comparison(baseline_images=['wireframe3dzerocstride'], remove_text=True,
Expand All @@ -203,7 +203,7 @@ def test_wireframe3dzerocstride():
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
X, Y, Z = axes3d.get_test_data(0.05)
ax.plot_wireframe(X, Y, Z, rstride=10, cstride=0)
ax.plot_wireframe(X, Y, Z, rcount=13, ccount=0)


@image_comparison(baseline_images=['wireframe3dzerorstride'], remove_text=True,
Expand All @@ -214,6 +214,7 @@ def test_wireframe3dzerorstride():
X, Y, Z = axes3d.get_test_data(0.05)
ax.plot_wireframe(X, Y, Z, rstride=0, cstride=10)


@cleanup
def test_wireframe3dzerostrideraises():
fig = plt.figure()
Expand All @@ -222,6 +223,18 @@ def test_wireframe3dzerostrideraises():
with assert_raises(ValueError):
ax.plot_wireframe(X, Y, Z, rstride=0, cstride=0)


@cleanup
def test_mixedsamplesraises():
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
X, Y, Z = axes3d.get_test_data(0.05)
with assert_raises(ValueError):
ax.plot_wireframe(X, Y, Z, rstride=10, ccount=50)
with assert_raises(ValueError):
ax.plot_surface(X, Y, Z, cstride=50, rcount=10)


@image_comparison(baseline_images=['quiver3d'], remove_text=True)
def test_quiver3d():
fig = plt.figure()
Expand Down

0 comments on commit 58d3de2

Please sign in to comment.