Skip to content

Retrieve

labridge.tools.instrument.retrieve

labridge.tools.instrument.retrieve.InstrumentRetrieverTool

Bases: RetrieverBaseTool

Source code in labridge/tools/instrument/retrieve.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
class InstrumentRetrieverTool(RetrieverBaseTool):
	r"""
	Tool that retrieves instrument docs with an `InstrumentRetriever` and renders them as text.

	Args:
		llm (LLM): Forwarded to the `InstrumentRetriever`. Defaults to None.
		embed_model (BaseEmbedding): Forwarded to the `InstrumentRetriever`. Defaults to None.
		metadata_mode (MetadataMode): The metadata mode passed to `get_content` when the
			retrieved nodes are rendered into the tool output. Defaults to `MetadataMode.NONE`.
	"""
	def __init__(
		self,
		llm: LLM = None,
		embed_model: BaseEmbedding = None,
		metadata_mode: MetadataMode = MetadataMode.NONE,
	):
		instrument_retriever = InstrumentRetriever(
			llm=llm,
			embed_model=embed_model,
		)
		# These attributes are assigned before the base initializer runs; the order is kept as is.
		self.metadata_mode = metadata_mode
		self.super_user_manager = InstrumentSuperUserManager()
		super().__init__(
			retriever=instrument_retriever,
			name=InstrumentRetrieverTool.__name__,
			retrieve_fn=InstrumentRetriever.retrieve
		)

	def log(self, log_dict: dict) -> ToolLog:
		r"""
		Build the tool log from the reference infos recorded for a retrieval.

		Args:
			log_dict (dict): Must contain `TOOL_LOG_REF_INFO_KEY`, mapped to the
				`InstrumentInfo` list produced by `get_ref_info`.

		Returns:
			ToolLog: A log with `log_to_user` set to None and a `log_to_system` dict holding
				the operation description and the dumped instrument infos.
		"""
		ref_infos: List[InstrumentInfo] = log_dict[TOOL_LOG_REF_INFO_KEY]
		instrument_infos = [info.dumps() for info in ref_infos]

		log_to_system = {
			TOOL_OP_DESCRIPTION: f"Use the {self.metadata.name} to retrieve the instrument docs.",
			TOOL_REFERENCES: instrument_infos,
		}

		return ToolLog(
			tool_name=self.metadata.name,
			log_to_user=None,
			log_to_system=log_to_system,
		)

	def _retrieve(self, retrieve_kwargs: dict) -> List[NodeWithScore]:
		r""" Call the retriever's `retrieve` with the given keyword arguments. """
		return self._retriever.retrieve(**retrieve_kwargs)

	async def _aretrieve(self, retrieve_kwargs: dict) -> List[NodeWithScore]:
		r""" Call the retriever's `aretrieve` with the given keyword arguments and await it. """
		return await self._retriever.aretrieve(**retrieve_kwargs)

	@staticmethod
	def _instrument_id_of(node: NodeWithScore):
		r""" Return the node's instrument name from its metadata, or its node id if the name is absent. """
		return node.metadata.get(INSTRUMENT_NAME_KEY, node.node_id)

	def get_ref_info(self, nodes: List[NodeWithScore]) -> List[RefInfoBase]:
		r"""
		Get the reference infos from the retrieved nodes.

		Args:
			nodes (List[NodeWithScore]): The retrieved nodes.

		Returns:
			List[RefInfoBase]: One `InstrumentInfo` per distinct instrument, in order of first
				appearance. Nodes identified as `INSTRUMENT_ROOT_NODE_NAME` are skipped.
		"""
		instrument_infos = []
		instrument_set = set()
		for node in nodes:
			instrument_id = self._instrument_id_of(node)
			# TODO: Add node type and filter.
			if instrument_id == INSTRUMENT_ROOT_NODE_NAME or instrument_id in instrument_set:
				continue
			instrument_set.add(instrument_id)
			super_users = self.super_user_manager.get_super_users(
				instrument_id=instrument_id,
			)
			instrument_infos.append(
				InstrumentInfo(
					instrument_id=instrument_id,
					super_users=super_users,
				)
			)
		return instrument_infos

	def _nodes_to_tool_output(self, nodes: List[NodeWithScore]) -> Tuple[str, dict]:
		r"""
		Output the retrieved contents in a specific format.

		The nodes are grouped by instrument, and each group is rendered as an
		`Instrument Name: ...` line followed by its numbered contents.

		Args:
			nodes (List[NodeWithScore]): The retrieved nodes.

		Returns:
			Tuple[str, dict]: The formatted output, and a log dict mapping
				`TOOL_LOG_REF_INFO_KEY` to the result of `get_ref_info`.
		"""
		ref_infos = self.get_ref_info(nodes=nodes)
		log_dict = {
			TOOL_LOG_REF_INFO_KEY: ref_infos,
		}

		# Group the nodes by instrument, keeping the order of first appearance.
		instrument_docs = {}
		for node in nodes:
			instrument_id = self._instrument_id_of(node)
			# TODO: Add node type and filter.
			if instrument_id == INSTRUMENT_ROOT_NODE_NAME:
				continue
			instrument_docs.setdefault(instrument_id, []).append(node)

		pieces = ["Have retrieved the docs of several relevant instruments:\n\n"]

		# Check the grouped result rather than `nodes`: the notice is also needed
		# when every retrieved node was filtered out as the root node.
		if not instrument_docs:
			pieces.append("No relevant instrument contents found.")

		for instrument_id, instrument_nodes in instrument_docs.items():
			pieces.append(f"Instrument Name: {instrument_id}\n")
			for idx, node in enumerate(instrument_nodes, start=1):
				pieces.append(
					f"Retrieved content {idx}:\n"
					f"{node.node.get_content(metadata_mode=self.metadata_mode)}\n"
				)
			# Blank line between instruments.
			pieces.append("\n")
		return "".join(pieces), log_dict

labridge.tools.instrument.retrieve.InstrumentRetrieverTool.get_ref_info(nodes)

Get the reference information of the instruments that appear in the retrieved nodes.

Source code in labridge/tools/instrument/retrieve.py
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def get_ref_info(self, nodes: List[NodeWithScore]) -> List[RefInfoBase]:
	r"""
	Collect the reference infos of the instruments found in the retrieved nodes.

	Args:
		nodes (List[NodeWithScore]): The retrieved nodes.

	Returns:
		List[RefInfoBase]: One `InstrumentInfo` per distinct instrument, in order of first
			appearance. Nodes identified as `INSTRUMENT_ROOT_NODE_NAME` are skipped.
	"""
	seen_ids = set()
	collected = []
	for retrieved in nodes:
		# Fall back to the node id when the metadata carries no instrument name.
		current_id = retrieved.metadata.get(INSTRUMENT_NAME_KEY, retrieved.node_id)
		# TODO: Add node type and filter.
		if current_id != INSTRUMENT_ROOT_NODE_NAME and current_id not in seen_ids:
			seen_ids.add(current_id)
			collected.append(
				InstrumentInfo(
					instrument_id=current_id,
					super_users=self.super_user_manager.get_super_users(
						instrument_id=current_id,
					),
				)
			)
	return collected