Skip to content

Commit

Permalink
refactor how files are opened and handled
Browse files Browse the repository at this point in the history
  • Loading branch information
jminardi committed Dec 22, 2016
1 parent 2bba7cc commit e63ff8c
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 41 deletions.
82 changes: 44 additions & 38 deletions mecode/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,7 @@ def __init__(self, outfile=None, print_lines=True, header=None, footer=None,
lineending insertion.
"""
# string file name
self.outfile = outfile if is_str(outfile) else None
# file descriptor
if not is_str(outfile) and outfile is not None:
# assume arg outfile is passed in a file descriptor
self.out_fd = outfile
else:
self.out_fd = None

self.outfile = outfile
self.print_lines = print_lines
self.header = header
self.footer = footer
Expand All @@ -172,12 +164,6 @@ def __init__(self, outfile=None, print_lines=True, header=None, footer=None,
self.x_axis = x_axis
self.y_axis = y_axis
self.z_axis = z_axis
if lineend == 'os':
self._open_as_binary = False
self.lineend = '\n'
else:
self._open_as_binary = True
self.lineend = lineend

self._current_position = defaultdict(float)
self.is_relative = True
Expand All @@ -195,6 +181,23 @@ def __init__(self, outfile=None, print_lines=True, header=None, footer=None,
self._socket = None
self._p = None

# If the user passes in a line ending then we need to open the output
# file in binary mode, otherwise python will try to be smart and
# convert line endings in a platform dependent way.
if lineend == 'os':
mode = 'w+'
self.lineend = '\n'
else:
mode = 'wb+'
self.lineend = lineend

if is_str(outfile):
self.out_fd = open(outfile, mode)
elif outfile is not None: # if outfile not str assume it is an open file
self.out_fd = outfile
else:
self.out_fd = None

if setup:
self.setup()

Expand Down Expand Up @@ -312,14 +315,10 @@ def teardown(self, wait=True):
if self.out_fd is not None:
if self.aerotech_include is True:
with open(os.path.join(HERE, 'footer.txt')) as fd:
lines = fd.readlines()
lines = [encode2To3(x.rstrip()+self.lineend) for x in lines]
self.out_fd.writelines(lines)
self._write_out(lines=fd.readlines())
if self.footer is not None:
with open(self.footer) as fd:
lines = fd.readlines()
lines = [encode2To3(x.rstrip()+self.lineend) for x in lines]
self.out_fd.writelines(lines)
self._write_out(lines=fd.readlines())
self.out_fd.close()
if self._socket is not None:
self._socket.close()
Expand Down Expand Up @@ -849,9 +848,8 @@ def view(self, backend='mayavi'):
def write(self, statement_in, resp_needed=False):
if self.print_lines:
print(statement_in)
self._write_out(statement_in)
statement = encode2To3(statement_in + self.lineend)
if self.out_fd is not None:
self.out_fd.write(statement)
if self.direct_write is True:
if self.direct_write_mode == 'socket':
if self._socket is None:
Expand Down Expand Up @@ -896,6 +894,23 @@ def rename_axis(self, x=None, y=None, z=None):

# Private Interface ######################################################

def _write_out(self, line=None, lines=None):
""" Writes given `line` or `lines` to the output file.
"""
# Only write if user requested an output file.
if self.out_fd is None:
return

if lines is not None:
for line in lines:
self._write_out(line)

line = line.rstrip() + self.lineend # add lineend character
if 'b' in self.out_fd.mode: # encode the string to binary if needed
line = encode2To3(line)
self.out_fd.write(line)


def _meander_passes(self, minor, spacing):
if minor > 0:
passes = math.ceil(minor / spacing)
Expand All @@ -907,21 +922,12 @@ def _meander_spacing(self, minor, spacing):
return minor / self._meander_passes(minor, spacing)

def _write_header(self):
outfile = self.outfile
if outfile is not None or self.out_fd is not None:
if self.out_fd is None: # open it if it is a path
mode = 'wb+' if self._open_as_binary else 'w+'
self.out_fd = open(outfile, mode)
if self.aerotech_include is True:
with open(os.path.join(HERE, 'header.txt')) as fd:
lines = fd.readlines()
lines = [encode2To3(x.rstrip()+self.lineend) for x in lines]
self.out_fd.writelines(lines)
if self.header is not None:
with open(self.header) as fd:
lines = fd.readlines()
lines = [encode2To3(x.rstrip()+self.lineend) for x in lines]
self.out_fd.writelines(lines)
if self.aerotech_include is True:
with open(os.path.join(HERE, 'header.txt')) as fd:
self._write_out(lines=fd.readlines())
if self.header is not None:
with open(self.header) as fd:
self._write_out(lines=fd.readlines())

def _format_args(self, x=None, y=None, z=None, **kwargs):
d = self.output_digits
Expand Down
21 changes: 18 additions & 3 deletions mecode/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def getGClass(self):
return G

def setUp(self):
self.outfile = TemporaryFile()
self.outfile = TemporaryFile('w+')
self.g = self.getGClass()(outfile=self.outfile, print_lines=False,
aerotech_include=False)
self.expected = ""
Expand All @@ -47,7 +47,9 @@ def assert_output(self):
self.expected = [x.strip() for x in self.expected if x.strip()]
self.outfile.seek(0)
lines = self.outfile.readlines()
lines = [decode2To3(x).strip() for x in lines if x.strip()]
if 'b' in self.outfile.mode:
lines = [decode2To3(x) for x in lines]
lines = [x.strip() for x in lines if x.strip()]
self.assertListEqual(lines, self.expected)
self.expected = string_rep

Expand Down Expand Up @@ -112,10 +114,13 @@ def test_dwell(self):
self.assert_output()

def test_setup(self):
self.outfile.close()
self.outfile = TemporaryFile()
self.g = G(outfile=self.outfile, print_lines=False)
self.expected = ""
self.expect_cmd(open(os.path.join(HERE, '../header.txt')).read())
with open(os.path.join(HERE, '../header.txt')) as f:
lines = f.read()
self.expect_cmd(lines)
self.expect_cmd('G91')
self.assert_output()

Expand Down Expand Up @@ -672,6 +677,16 @@ def test_output_digits(self):
""")
self.assert_output()

def test_open_in_binary(self):
outfile = TemporaryFile('wb+')
g = self.getGClass()(outfile=outfile, print_lines=False,
aerotech_include=False)
g.move(10,10)
outfile.seek(0)
lines = outfile.readlines()
assert(type(lines[0]) == bytes)
outfile.close()


if __name__ == '__main__':
unittest.main()

0 comments on commit e63ff8c

Please sign in to comment.