From 31110f41b568f64828471b7d2d4c1f9d68df4a23 Mon Sep 17 00:00:00 2001 From: Charles Duffy Date: Tue, 15 Aug 2023 14:48:01 -0500 Subject: [PATCH] output_writer: fix bug introduced in avoiding shell invocation calling clear --- src/lmql/runtime/output_writer.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/lmql/runtime/output_writer.py b/src/lmql/runtime/output_writer.py index 735da311..f778fa4f 100644 --- a/src/lmql/runtime/output_writer.py +++ b/src/lmql/runtime/output_writer.py @@ -57,14 +57,16 @@ def add_decoder_state(*args, **kwargs): async def add_interpreter_head_state(self, variable, head, prompt, where, trace, is_valid, is_final, mask, num_tokens, program_variables): if head == 0: if self.clear: - sys.stderr.write('\033c', flush=True) + sys.stderr.write('\033c') + sys.stderr.flush() if self.print_output: print(f"{prompt}\n\n valid={is_valid}, final={is_final}") def add_compiler_output(self, code): pass class StreamingOutputWriter: - def __init__(self, variable=None): + def __init__(self, variable=None, clear=True): + self.clear = clear self.variable = variable self.last_value = None @@ -87,8 +89,10 @@ async def add_interpreter_head_state(self, variable, head, prompt, where, trace, print(value[len(self.last_value):], end="", flush=True) self.last_value = value return - - sys.stderr.write('\033c', flush=True) # clear screen + + if clear: + sys.stderr.write('\033c') + sys.stderr.flush() print(f"{prompt}\n", end="\r") def add_compiler_output(self, code): pass