Skip to content

Commit

Permalink
add get_query_token_embedding()
Browse files Browse the repository at this point in the history
  • Loading branch information
Rui Zhang authored Aug 22, 2019
1 parent 085fcbc commit 25cf8b5
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion model/schema_interaction_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,23 @@ def encode_schema_bow_simple(self, input_schema):
schema_states.append(input_schema.column_name_embedder_bow(column_name, surface_form=False, column_name_token_embedder=self.column_name_token_embedder))
input_schema.set_column_name_embeddings(schema_states)
return schema_states


def get_query_token_embedding(self, output_token, input_schema):
if input_schema:
# print('output_token', output_token)
# assert self.output_embedder.in_vocabulary(output_token) or input_schema.in_vocabulary(output_token, surface_form=True)
# TODO
if not (self.output_embedder.in_vocabulary(output_token) or input_schema.in_vocabulary(output_token, surface_form=True)):
output_token = 'value'
if self.output_embedder.in_vocabulary(output_token):
output_token_embedding = self.output_embedder(output_token)
else:
# output_token_embedding = input_schema.column_name_embedder(output_token, surface_form=True, column_name_token_embedder=self.column_name_token_embedder)
output_token_embedding = input_schema.column_name_embedder(output_token, surface_form=True)
else:
output_token_embedding = self.output_embedder(output_token)
return output_token_embedding

def train_step(self, interaction, max_generation_length, snippet_alignment_probability=1.):
""" Trains the interaction-level model on a single interaction.
Expand Down

0 comments on commit 25cf8b5

Please sign in to comment.