Skip to content

Commit

Permalink
Attempt at better variables names
Browse files Browse the repository at this point in the history
geom.plot         -> geom._plot_unit
geom._translation -> geom._aes_renames
layer             -> pinfo                 # (parameter to geom.plot)
geom._rename_aes  -> geom._do_aes_renames

**Also**
- Make `colour -> color` the only automatic renaming
  • Loading branch information
has2k1 committed Mar 31, 2014
1 parent ea6405d commit 63b2c61
Show file tree
Hide file tree
Showing 18 changed files with 184 additions and 171 deletions.
55 changes: 29 additions & 26 deletions ggplot/geoms/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class geom(object):
"""Base class of all Geoms"""
VALID_AES = {}
VALID_AES = set() # TODO: use DEFAULT_AES dict instead
REQUIRED_AES = set()
DEFAULT_PARAMS = dict()

Expand All @@ -19,54 +19,58 @@ class geom(object):
params = None

_groups = set()
_translations = dict()
_aes_renames = dict()

def __init__(self, *args, **kwargs):
# new dicts for each geom
self.aes, self.data = self._aes_and_data(args, kwargs)
self.aes, self.data = self._find_aes_and_data(args, kwargs)
self.manual_aes = {}
self.params = deepcopy(self.DEFAULT_PARAMS)
for k, v in kwargs.items():
if k in self.VALID_AES:
self.manual_aes[k] = v
elif k in self.DEFAULT_PARAMS:
self.params[k] = v

# NOTE: Deal with unknown parameters.
# Throw an exception or save them for
# the layer?

def plot_layer(self, data, ax):
self._verify_aesthetics(data)

# NOTE: This is the correct check however with aes
# set in ggplot(), self.aes is empty
# groups = groups & set(self.aes) & set(data.columns)
try:
groups = self._groups & set(data.columns)
except AttributeError:
groups = set()
# This should be correct when the layer passes
# a sanitized dataframe
groups = self._groups & set(data.columns)

for _data in self._get_grouped_data(data, groups):
_data = dict((k, v) for k, v in _data.items()
for pinfo in self._get_grouped_data(data, groups):
pinfo = dict((k, v) for k, v in pinfo.items() # at layer level!
if k in self.VALID_AES)
_data.update(self.manual_aes)
self._rename_aes(_data)
self.plot(_data, ax)
pinfo.update(self.manual_aes) # at layer level!!

self._do_aes_renames(pinfo)
self._plot_unit(pinfo, ax)

def __radd__(self, gg):
gg = deepcopy(gg)
gg.geoms.append(self)
return gg

def _verify_aesthetics(self, layer):
def _verify_aesthetics(self, data):
"""
Check if all the required aesthetics have been specified
Raise an Exception if an aesthetic is missing
"""
missing_aes = self.REQUIRED_AES - set(layer)
missing_aes = self.REQUIRED_AES - set(data.columns)
if missing_aes:
msg = '{} requires the following missing aesthetics: {}'
raise Exception(msg.format(
self.__class__.__name__, ', '.join(missing_aes)))

def _aes_and_data(self, args, kwargs):
def _find_aes_and_data(self, args, kwargs):
"""
Identify the aes and data objects.
Expand All @@ -77,6 +81,7 @@ def _aes_and_data(self, args, kwargs):
- kwargs is a dictionary
Note: This is a helper function for self.__init__
It modifies the kwargs
"""
passed_aes = {}
data = None
Expand All @@ -95,33 +100,31 @@ def _aes_and_data(self, args, kwargs):
if 'mapping' in kwargs and passed_aes:
raise Exception(aes_err)
elif not passed_aes and 'mapping' in kwargs:
passed_aes = kwargs['mapping']
passed_aes = kwargs.pop('mapping')

if data is None and 'data' in kwargs:
data = kwargs['data']
data = kwargs.pop('data')

valid_aes = {}
for k, v in passed_aes.items():
if k in self.VALID_AES:
valid_aes[k] = v
return valid_aes, data

def _rename_aes(self, layer):
def _do_aes_renames(self, layer):
"""
Convert ggplot2 API names to matplotlib names
"""
# apply to all geoms
_translations = {'colour': 'color', 'linetype': 'linestyle'}
if 'colour' in layer:
layer['color'] = layer.pop('colour')

def _rename_fn(old, new):
# TODO: Sort out potential cyclic renames
# e.g fill -> color, color -> edgecolor
for old, new in self._aes_renames.items():
if old in layer:
layer[new] = layer.pop(old)

for k, v in _translations.items():
_rename_fn(k, v)
for k, v in self._translations.items():
_rename_fn(k, v)

def _get_grouped_data(self, data, groups):
"""
Split the data into groups.
Expand Down
11 changes: 6 additions & 5 deletions ggplot/geoms/geom_abline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@
import pandas as pd

class geom_abline(geom):
VALID_AES = {'x', 'color', 'linestyle', 'alpha', 'size'}
VALID_AES = {'x', 'color', 'linetype', 'alpha', 'size'}
DEFAULT_PARAMS = {'stat': 'abline', 'position': 'identity', 'slope': 1.0, 'intercept': 0.0, 'label': ''}

_groups = {'color', 'linestyle', 'alpha'}
_aes_renames = {'linetype': 'linestyle'}

def plot(self, layer, ax):
x = layer.pop(x)
def _plot_unit(self, pinfo, ax):
x = pinfo.pop(x)
slope = self.params['slope']
intercept = self.params['intercept']
layer['label'] = self.params['label']
pinfo['label'] = self.params['label']
if isinstance(x[0], Timestamp):
ax.set_autoscale_on(False)
ax.plot(ax.get_xlim(),ax.get_ylim())
Expand All @@ -24,5 +25,5 @@ def plot(self, layer, ax):
step = ((stop-start)) / 100.0
x_rng = np.arange(start, stop, step)
y_rng = x_rng * slope + intercept
ax.plot(x_rng, y_rng, **layer)
ax.plot(x_rng, y_rng, **pinfo)

11 changes: 6 additions & 5 deletions ggplot/geoms/geom_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ class geom_area(geom):
DEFAULT_PARAMS = {'stat': 'identity', 'position': 'stack'}

_groups = {'color', 'alpha'}
_aes_renames = {'linetype': 'linestyle'}

def plot(self, layer, ax):
x = layer.pop('x')
y1 = layer.pop('ymin')
y2 = layer.pop('ymax')
ax.fill_between(x, y1, y2, **layer)
def _plot_unit(self, pinfo, ax):
x = pinfo.pop('x')
y1 = pinfo.pop('ymin')
y2 = pinfo.pop('ymax')
ax.fill_between(x, y1, y2, **pinfo)

16 changes: 9 additions & 7 deletions ggplot/geoms/geom_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@ class geom_bar(geom):
DEFAULT_PARAMS = {'stat': 'bin', 'position':'stack'}

_groups = {'color'}
def plot(self, layer, ax):
x = layer.pop('x')
if 'weight' not in layer:
_aes_renames = {'linetype': 'linestyle'}

def _plot_unit(self, pinfo, ax):
x = pinfo.pop('x')
if 'weight' not in pinfo:
counts = pd.value_counts(x)
labels = counts.index.tolist()
weights = counts.tolist()
else:
# TODO: pretty sure this isn't right
weights = layer.pop('weight')
weights = pinfo.pop('weight')
if not isinstance(x[0], Timestamp):
labels = x
else:
Expand All @@ -38,10 +40,10 @@ def plot(self, layer, ax):
labels, weights = np.array(labels)[idx], np.array(weights)[idx]
labels = sorted(labels)

layer['edgecolor'] = layer.pop('color', '#333333')
layer['color'] = layer.pop('fill', '#333333')
pinfo['edgecolor'] = pinfo.pop('color', '#333333')
pinfo['color'] = pinfo.pop('fill', '#333333')

ax.bar(indentation, weights, width, **layer)
ax.bar(indentation, weights, width, **pinfo)
ax.autoscale()
ax.set_xticks(indentation+width/2)
ax.set_xticklabels(labels)
15 changes: 8 additions & 7 deletions ggplot/geoms/geom_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ class geom_density(geom):
REQUIRED_AES = {'x'}
DEFAULT_PARAMS = {'stat': 'density', 'position': 'identity', 'label': ''}

_groups = {'color', 'linestyle', 'alpha'}
_groups = {'color', 'linetype', 'alpha'}
_aes_renames = {'linetype': 'linestyle'}

def plot(self, layer, ax):
x = layer.pop('x')
fill = layer.pop('fill', None)
layer['label'] = self.params['label']
def _plot_unit(self, pinfo, ax):
x = pinfo.pop('x')
fill = pinfo.pop('fill', None)
pinfo['label'] = self.params['label']

try:
float(x[0])
Expand All @@ -32,6 +33,6 @@ def plot(self, layer, ax):
step = (top - bottom) / 1000.0
x = np.arange(bottom, top, step)
y = kde.evaluate(x)
ax.plot(x, y, **layer)
ax.plot(x, y, **pinfo)
if fill:
ax.fill_between(x, y1=np.zeros(len(x)), y2=y, **layer)
ax.fill_between(x, y1=np.zeros(len(x)), y2=y, **pinfo)
22 changes: 12 additions & 10 deletions ggplot/geoms/geom_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,29 @@ class geom_histogram(geom):
DEFAULT_PARAMS = {'stat': 'bin', 'position': 'stack', 'label': ''}

_groups = {'color', 'alpha', 'shape'}
_aes_renames = {'linetype': 'linestyle'}

def __init__(self, *args, **kwargs):
super(geom_histogram, self).__init__(*args, **kwargs)
self._warning_printed = False

def plot(self, layer, ax):
layer['label'] = self.params['label']
def _plot_unit(self, pinfo, ax):
pinfo['label'] = self.params['label']

if 'binwidth' in layer:
binwidth = layer.pop('binwidth')
if 'binwidth' in pinfo:
binwidth = pinfo.pop('binwidth')
try:
binwidth = float(binwidth)
bottom = np.nanmin(layer['x'])
top = np.nanmax(layer['x'])
layer['bins'] = np.arange(bottom, top + binwidth, binwidth)
bottom = np.nanmin(pinfo['x'])
top = np.nanmax(pinfo['x'])
pinfo['bins'] = np.arange(bottom, top + binwidth, binwidth)
except:
pass
if 'bins' not in layer:
layer['bins'] = 30
if 'bins' not in pinfo:
pinfo['bins'] = 30
if not self._warning_printed:
sys.stderr.write("binwidth defaulted to range/30. " +
"Use 'binwidth = x' to adjust this.\n")
self._warning_printed = True

ax.hist(**layer)
ax.hist(**pinfo)
10 changes: 5 additions & 5 deletions ggplot/geoms/geom_hline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ class geom_hline(geom):
DEFAULT_PARAMS = {'stat': 'hline', 'position': 'identity', 'show_guide': False,
'label': ''}

_groups = {'color', 'alpha', 'linestyle'}
_translations = {'size': 'linewidth'}
_groups = {'color', 'alpha', 'linetype'}
_aes_renames = {'size': 'linewidth', 'linetype': 'linestyle'}

def plot(self, layer, ax):
layer['label'] = self.params['label']
ax.axhline(**layer)
def _plot_unit(self, pinfo, ax):
pinfo['label'] = self.params['label']
ax.axhline(**pinfo)


28 changes: 14 additions & 14 deletions ggplot/geoms/geom_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,35 @@ class geom_line(geom):
REQUIRED_AES = {'x', 'y'}
DEFAULT_PARAMS = {'stat': 'identity', 'position': 'identity', 'label': ''}

_groups = {'color', 'alpha', 'linestyle'}
_translations = {'size': 'linewidth'}
_groups = {'color', 'alpha', 'linetype'}
_aes_renames = {'size': 'linewidth', 'linetype': 'linestyle'}

def __init__(self, *args, **kwargs):
super(geom_line, self).__init__(*args, **kwargs)
self._warning_printed = False

def plot(self, layer, ax):
x = layer.pop('x')
y = layer.pop('y')
layer['label'] = self.params['label']
def _plot_unit(self, pinfo, ax):
x = pinfo.pop('x')
y = pinfo.pop('y')
pinfo['label'] = self.params['label']

if 'linewidth' in layer and isinstance(layer['linewidth'], list):
if 'linewidth' in pinfo and isinstance(pinfo['linewidth'], list):
# ggplot also supports aes(size=...) but the current mathplotlib
# is not. See https://github.com/matplotlib/matplotlib/issues/2658
layer['linewidth'] = 4
pinfo['linewidth'] = 4
if not self._warning_printed:
msg = "'geom_line()' currenty does not support the mapping of " +\
"size ('aes(size=<var>'), using size=4 as a replacement.\n" +\
"Use 'geom_line(size=x)' to set the size for the whole line.\n"
sys.stderr.write(msg)
self._warning_printed = True
if 'linestyle' in layer and 'color' not in layer:
layer['color'] = 'k'
if 'group' not in layer:
ax.plot(x, y, **layer)
if 'linestyle' in pinfo and 'color' not in pinfo:
pinfo['color'] = 'k'
if 'group' not in pinfo:
ax.plot(x, y, **pinfo)
else:
g = layer.pop('group')
g = pinfo.pop('group')
for k, v in groupby(sorted(zip(x, y, g), key=itemgetter(2)),
key=itemgetter(2)):
x_g, y_g, _ = zip(*v)
ax.plot(x_g, y_g, **layer)
ax.plot(x_g, y_g, **pinfo)
6 changes: 3 additions & 3 deletions ggplot/geoms/geom_now_its_art.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
class geom_now_its_art(geom):
VALID_AES = {'x', 'y'}

def plot(self, data, ax):
x = np.array(layer['x'])
y = np.array(layer['y'])
def _plot_unit(self, data, ax):
x = np.array(pinfo['x'])
y = np.array(pinfo['y'])

img = mpimg.imread(os.path.join(_ROOT, 'bird.png'))
# plt.imshow(img, alpha=0.5, extent=[x.min(), x.max(), y.min(), y.max()])
Expand Down
16 changes: 8 additions & 8 deletions ggplot/geoms/geom_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,18 @@ class geom_point(geom):
DEFAULT_PARAMS = {'stat': 'identity', 'position': 'identity', 'cmap':None, 'label': ''}

_groups = {'color', 'shape', 'alpha'}
_translations = {'size': 's', 'shape': 'marker'}
_aes_renames = {'size': 's', 'shape': 'marker'}

def plot(self, layer, ax):
layer['label'] = self.params['label']
def _plot_unit(self, pinfo, ax):
pinfo['label'] = self.params['label']
# for some reason, scatter doesn't default to the same color styles
# as the axes.color_cycle
if "color" not in layer and self.params['cmap'] is None:
layer["color"] = mpl.rcParams.get("axes.color_cycle", ["#333333"])[0]
if "color" not in pinfo and self.params['cmap'] is None:
pinfo["color"] = mpl.rcParams.get("axes.color_cycle", ["#333333"])[0]

if self.params['position'] == 'jitter':
layer['x'] *= np.random.uniform(.9, 1.1, len(layer['x']))
layer['y'] *= np.random.uniform(.9, 1.1, len(layer['y']))
pinfo['x'] *= np.random.uniform(.9, 1.1, len(pinfo['x']))
pinfo['y'] *= np.random.uniform(.9, 1.1, len(pinfo['y']))

ax.scatter(**layer)
ax.scatter(**pinfo)

Loading

0 comments on commit 63b2c61

Please sign in to comment.