Hi, I am using a ktrain DistilBERT model for intent classification and am trying to build a custom graph component: from typing import Dict, Text, Any, List import os from typing import Dict, Any, Union from rasa.engine.graph import GraphComponent, ExecutionContext from rasa.engine.recipes.default_recipe import DefaultV1Recipe from rasa.engine.storage.resource import Resource from rasa.engine.storage.storage import ModelStorage from rasa.shared.nlu.training_data.message import Message from rasa.shared.nlu.training_data.training_data import TrainingData import ktrain
@DefaultV1Recipe.register(
    [DefaultV1Recipe.ComponentType.INTENT_CLASSIFIER], is_trainable=False
)
class KtrainIntentClassifier(GraphComponent):
    """Rasa 3 intent classifier wrapping a pre-trained ktrain text predictor.

    The predictor (e.g. a fine-tuned DistilBERT model saved with ktrain) is
    loaded from ``model_path`` and is never retrained by Rasa. The component is
    therefore registered as non-trainable, and Rasa builds it through
    :meth:`create` / :meth:`load`.

    NOTE: the ktrain model is NOT copied into the Rasa model archive, so
    ``model_path`` must exist wherever the Rasa model is served. A relative
    path is resolved against the process working directory.

    Configuration keys (defaults in :meth:`get_default_config`):
        model_path: directory passed to ``ktrain.load_predictor``.
        ranking_length: number of entries kept in ``intent_ranking``;
            ``0`` keeps all of them.
    """

    @staticmethod
    def get_default_config() -> Dict[Text, Any]:
        """Return the defaults that Rasa merges with the pipeline config."""
        return {
            "model_path": "distil_bert/distilbert_model_40epochs",
            "ranking_length": 10,
        }

    def __init__(
        self, component_config: Dict[Text, Any], *, predictor: Any = None
    ) -> None:
        """Merge ``component_config`` with defaults and set up the predictor.

        Args:
            component_config: pipeline configuration for this component.
            predictor: an already-loaded ktrain predictor. When given, no model
                is read from disk (useful for tests or sharing one model).
        """
        # GraphComponent has no __init__ that accepts a config, so the config
        # is stored here instead of being forwarded to super().
        self.component_config = {
            **self.get_default_config(),
            **(component_config or {}),
        }
        self.model = (
            predictor
            if predictor is not None
            else ktrain.load_predictor(self.component_config["model_path"])
        )

    @classmethod
    def create(
        cls,
        config: Dict[Text, Any],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
    ) -> "KtrainIntentClassifier":
        """Build a new component; the model is read from ``config["model_path"]``.

        ``model_storage``, ``resource`` and ``execution_context`` are required by
        the GraphComponent interface but unused: nothing is persisted in Rasa's
        model storage.
        """
        return cls(config)

    @classmethod
    def load(
        cls,
        config: Dict[Text, Any],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        **kwargs: Any,
    ) -> "KtrainIntentClassifier":
        """Load the component for inference; identical to :meth:`create`."""
        return cls.create(config, model_storage, resource, execution_context)

    def process(self, messages: List[Message]) -> List[Message]:
        """Set ``intent`` and ``intent_ranking`` on every message that has text.

        Args:
            messages: the messages to classify. A single ``Message`` is also
                accepted for callers written against the old one-message API.

        Returns:
            The same messages, annotated in place.
        """
        if isinstance(messages, Message):
            messages = [messages]
        for message in messages:
            text = message.get("text")
            # Messages without text have nothing for the model to classify.
            if text:
                self._classify(message, text)
        return messages

    def _classify(self, message: Message, text: Text) -> None:
        """Predict class probabilities for ``text`` and store them on ``message``."""
        # For a single string, ktrain's predict() returns the label itself (a
        # str), so indexing the result would give the label's first character.
        # Requesting probabilities instead yields a real confidence per class.
        probabilities = self.model.predict(text, return_proba=True)
        ranking = sorted(
            (
                {"name": name, "confidence": float(probability)}
                for name, probability in zip(
                    self.model.get_classes(), probabilities
                )
            ),
            key=lambda entry: entry["confidence"],
            reverse=True,
        )
        if not ranking:
            message.set("intent", {"name": None, "confidence": 0.0}, add_to_output=True)
            message.set("intent_ranking", [], add_to_output=True)
            return

        message.set("intent", dict(ranking[0]), add_to_output=True)
        length = self.component_config.get("ranking_length", 0)
        if length > 0:
            ranking = ranking[:length]
        message.set("intent_ranking", ranking, add_to_output=True)