Spaces:
Running
Running
| from collections import Counter | |
| import graphviz | |
| import penman | |
| from penman.models.noop import NoOpModel | |
| from mbart_amr.data.linearization import linearized2penmanstr | |
| from transformers import LogitsProcessorList | |
| import streamlit as st | |
| from utils import get_resources, LANGUAGES, translate | |
# Page title. The emoji (woman technologist) is written with named escapes so the
# source stays ASCII-only and cannot be mangled by a wrong encoding again.
st.title("\N{WOMAN}\N{ZERO WIDTH JOINER}\N{PERSONAL COMPUTER} Generate AMR from multilingual text")

# Input form: the sentence to parse and its source language. The widgets live in a
# form, so the rest of the script only acts on them once "Submit" is pressed
# (`submitted` is checked further down).
with st.form("input data"):
    text_col, lang_col = st.columns((4, 1))
    text = text_col.text_input(label="Input text")
    src_lang = lang_col.selectbox(label="Language", options=list(LANGUAGES.keys()), index=0)
    submitted = st.form_submit_button("Submit")
# Graphviz styling shared by every node in the visualization.
_NODE_ATTR = {"color": "#3aafa9", "style": "rounded,filled", "shape": "box", "fontcolor": "white"}


def _unique_node_labels(triples):
    """Return a display label for every variable that has an ``:instance`` triple.

    The label is the variable's concept. When a concept is used by more than one
    variable, each occurrence gets a running suffix so the labels stay distinguishable,
    e.g. ``t/talk-01`` and ``t2/talk-01`` become ``{"t": "talk-01 (1)", "t2": "talk-01 (2)"}``.
    Concepts that occur only once keep their bare name.

    :param triples: iterable of ``(source, role, target)`` triples (e.g. ``penman.Graph.triples``)
    :return: dict mapping variable -> label, in order of first appearance
    """
    instances = [(source, target) for source, role, target in triples if role == ":instance"]
    concept_counts = Counter(concept for _, concept in instances)
    seen = Counter()
    labels = {}
    for variable, concept in instances:
        if concept_counts[concept] > 1:
            seen[concept] += 1
            labels[variable] = f"{concept} ({seen[concept]})"
        else:
            labels[variable] = concept
    return labels


def _build_digraph(graph):
    """Build a Graphviz digraph that visualizes a decoded Penman graph.

    Every variable becomes exactly one node labelled by ``_unique_node_labels``, and
    every non-``:instance`` triple becomes an edge labelled with its role.

    Nodes get synthetic ids (``v0``, ``c0``, ...) and carry their text only as a label.
    Graphviz parses a colon in an edge endpoint as a ``node:port`` reference, so using
    the raw text as the id would break on constants such as ``"12:30"``.

    Constants (edge targets that are not variables, e.g. the ``-`` of ``:polarity -``)
    get a fresh node per occurrence, so unrelated attributes with the same value are
    not drawn as one shared node.

    :param graph: a decoded ``penman.Graph``
    :return: the populated ``graphviz.Digraph``
    """
    visualized = graphviz.Digraph(node_attr=_NODE_ATTR)
    labels = _unique_node_labels(graph.triples)
    # Triple sources are variables in Penman's graph model; instance variables are
    # included via `labels` even if they have no outgoing edges.
    variables = set(labels).union(source for source, _, _ in graph.triples)
    node_ids = {}

    def variable_node(variable):
        """Return the node id for *variable*, creating the node on first use."""
        if variable not in node_ids:
            node_ids[variable] = f"v{len(node_ids)}"
            label = labels.get(variable)
            visualized.node(node_ids[variable], label=variable if label is None else label)
        return node_ids[variable]

    # Create concept nodes up front so a graph without edges, such as `(b / boy)`,
    # is still drawn instead of rendering as an empty chart.
    for variable in labels:
        variable_node(variable)

    n_constants = 0
    for source, role, target in graph.triples:
        if role == ":instance":
            continue
        tail = variable_node(source)
        if target in variables:
            head = variable_node(target)
        else:
            head = f"c{n_constants}"
            n_constants += 1
            visualized.node(head, label=str(target))
        visualized.edge(tail, head, label=role)
    return visualized


def _show_invalid_graph(message, penman_str, linearized, exc):
    """Explain that the generated graph cannot be visualized and show the raw outputs.

    :param message: explanation shown above the Penman string
    :param penman_str: the Penman string derived from the model output
    :param linearized: the raw linearized model output
    :param exc: the exception that was raised; shown inside a collapsible "Error trace"
    """
    st.write(message)
    st.code(penman_str)
    st.write("The initial linearized output of the model was:")
    st.code(linearized)
    with st.expander("Error trace"):
        st.write(exc)


if submitted and not text.strip():
    # Nothing to parse: don't load the model or run generation on empty input.
    st.warning("Please enter some text to parse.")
elif submitted:
    # get_resources() is asked for the multilingual resources for any non-English source language.
    multilingual = src_lang != "English"
    model, tokenizer, logitsprocessor = get_resources(multilingual)
    gen_kwargs = {
        "max_length": model.config.max_length,
        "num_beams": model.config.num_beams,
        "logits_processor": LogitsProcessorList([logitsprocessor]),
    }
    linearized = translate(text, src_lang, model, tokenizer, **gen_kwargs)
    penman_str = linearized2penmanstr(linearized)

    # Any failure below is reported in the UI rather than crashing the app.
    try:
        graph = penman.decode(penman_str, model=NoOpModel())
    except Exception as exc:
        _show_invalid_graph(
            "The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
            " to a valid graph but note that this is invalid Penman.",
            penman_str,
            linearized,
            exc,
        )
    else:
        try:
            visualized = _build_digraph(graph)
        except Exception as exc:
            _show_invalid_graph(
                "The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
                " to a valid graph but note that this is probably invalid Penman.",
                penman_str,
                linearized,
                exc,
            )
        else:
            st.subheader("Graph visualization")
            st.graphviz_chart(visualized, use_container_width=True)

            # Download
            img = visualized.pipe(format="png")
            st.download_button("Download graph", img, file_name="amr_graph.png", mime="image/png")

            # Additional info
            st.subheader("Model output and Penman graph")
            st.write("The linearized output of the model (after some post-processing) is:")
            st.code(linearized)
            st.write("When converted into Penman, it looks like this:")
            st.code(penman.encode(graph))
########################
# Information, socials #
########################
# The heading emoji (black nib) is written with named escapes so the source stays
# ASCII-only and cannot be mangled by a wrong encoding again.
st.markdown("## Contact \N{BLACK NIB}\N{VARIATION SELECTOR-16}")
st.markdown("Would you like additional functionality in the demo? Or just want to get in touch?"
            " Give me a shout on [Twitter](https://twitter.com/BramVanroy)"
            " or add me on [LinkedIn](https://www.linkedin.com/in/bramvanroy/)!")