diff --git a/backends/huggingface_local_api.py b/backends/huggingface_local_api.py
index e9fac6abc831feaacd4e995bf795bba1ef8f8f82..678cc5ed094c07641451e3c1086f3f18e5e4067f 100644
--- a/backends/huggingface_local_api.py
+++ b/backends/huggingface_local_api.py
@@ -36,6 +36,8 @@ MODEL_DEEPSEEK_67B_CHAT = "deepseek-llm-67b-chat"
 MODEL_TULU_2_DPO_7B = "tulu-2-dpo-7b"
 MODEL_TULU_2_DPO_70B = "tulu-2-dpo-70b"
 MODEL_MIXTRAL_8X7B_INSTRUCT_V0_1 = "Mixtral-8x7B-Instruct-v0.1"
+MODEL_SUS_CHAT_34B = "SUS-Chat-34B"
+
 SUPPORTED_MODELS = [MODEL_MISTRAL_7B_INSTRUCT_V0_1, MODEL_RIIID_SHEEP_DUCK_LLAMA_2_70B_V1_1,
                     MODEL_RIIID_SHEEP_DUCK_LLAMA_2_13B, MODEL_FALCON_7B_INSTRUCT, MODEL_OPEN_ASSISTANT_12B,
@@ -43,7 +45,7 @@ SUPPORTED_MODELS = [MODEL_MISTRAL_7B_INSTRUCT_V0_1, MODEL_RIIID_SHEEP_DUCK_LLAMA
                     MODEL_LMSYS_VICUNA_13B, MODEL_LMSYS_VICUNA_33B, MODEL_LMSYS_VICUNA_7B, MODEL_GPT4ALL_13B_SNOOZY,
                     MODEL_CODELLAMA_34B_I, MODEL_ZEPHYR_7B_ALPHA, MODEL_ZEPHYR_7B_BETA, MODEL_OPENCHAT_3_5,
                     MODEL_YI_34B_CHAT, MODEL_DEEPSEEK_7B_CHAT, MODEL_DEEPSEEK_67B_CHAT, MODEL_TULU_2_DPO_7B,
-                    MODEL_TULU_2_DPO_70B, MODEL_MIXTRAL_8X7B_INSTRUCT_V0_1]
+                    MODEL_TULU_2_DPO_70B, MODEL_MIXTRAL_8X7B_INSTRUCT_V0_1, MODEL_SUS_CHAT_34B]
 
 NAME = "huggingface"
 
@@ -82,6 +84,9 @@ tulu_template = "{% for message in messages %}\n{% if message['role'] == 'user'
 DEEPSEEK = [MODEL_DEEPSEEK_7B_CHAT, MODEL_DEEPSEEK_67B_CHAT]
 # jinja template for deepseek format:
 deepseek_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = true %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
+SUSTECH = [MODEL_SUS_CHAT_34B]
+sus_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Human: ' + message['content'] + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ '### Assistant: ' + message['content'] }}{% endif %}{% if loop.last %}{{ '### Assistant: ' }}{% endif %}{% endfor %}"
+
 # templates currently have 'generation prompt' hardcoded
 # doesn't matter for clembench, but once added, templates can be pushed to HF and this block can be reduced
@@ -89,7 +94,7 @@ deepseek_template = "{% if not add_generation_prompt is defined %}{% set add_gen
 # but transformers==4.34.0 does not support this feature (at least not reliably)
 
 # due to issues with differences between fast and slow HF tokenizer classes, some models require the 'slow' class/arg
-SLOW_TOKENIZER = [MODEL_YI_34B_CHAT, MODEL_ORCA_2_13B]
+SLOW_TOKENIZER = [MODEL_YI_34B_CHAT, MODEL_ORCA_2_13B, MODEL_SUS_CHAT_34B]
 
 
 class HuggingfaceLocal(backends.Backend):
@@ -142,6 +147,8 @@ class HuggingfaceLocal(backends.Backend):
             hf_user_prefix = "deepseek-ai/"
         elif model_name in [MODEL_TULU_2_DPO_7B, MODEL_TULU_2_DPO_70B]:  # allenai models
             hf_user_prefix = "allenai/"
+        elif model_name in [MODEL_SUS_CHAT_34B]:  # SUSTech models
+            hf_user_prefix = "SUSTech/"
 
         hf_model_str = f"{hf_user_prefix}{model_name}"
@@ -173,6 +180,9 @@ class HuggingfaceLocal(backends.Backend):
                 self.tokenizer.chat_template = tulu_template
             elif model_name in DEEPSEEK:
                 self.tokenizer.chat_template = deepseek_template
+            elif model_name in SUSTECH:
+                self.tokenizer.chat_template = sus_template
+
         # load all models using their default configuration:
         self.model = AutoModelForCausalLM.from_pretrained(hf_model_str, device_map="auto", torch_dtype="auto",
@@ -205,9 +215,17 @@ class HuggingfaceLocal(backends.Backend):
             logger.info(f"Finished loading huggingface model: {model}")
             logger.info(f"Model device map: {self.model.hf_device_map}")
 
+        # log current given messages list:
+        # logger.info(f"Raw messages passed: {messages}")
+
         # deepcopy messages to prevent reference issues:
         current_messages = copy.deepcopy(messages)
+        # cull empty system message:
+        if current_messages[0]['role'] == "system":
+            if not current_messages[0]['content']:
+                del current_messages[0]
+
         # flatten consecutive user messages:
         for msg_idx, message in enumerate(current_messages):
             if msg_idx > 0 and message['role'] == "user" and current_messages[msg_idx - 1]['role'] == "user":
@@ -217,6 +235,9 @@ class HuggingfaceLocal(backends.Backend):
                 current_messages[msg_idx - 1]['content'] += f" {message['content']}"
                 del current_messages[msg_idx]
 
+        # log current flattened messages list:
+        # logger.info(f"Flattened messages: {current_messages}")
+
         # apply chat template & tokenize:
         prompt_tokens = self.tokenizer.apply_chat_template(current_messages, return_tensors="pt")
         prompt_tokens = prompt_tokens.to(self.device)
@@ -271,6 +292,9 @@ class HuggingfaceLocal(backends.Backend):
             # remove DeepSeek EOS token at the end of output:
             if response_text[-19:len(response_text)] == "<|end▁of▁sentence|>":
                 response_text = response_text[:-19]
+            # remove SUS EOS token at the end of output:
+            if response_text[-13:len(response_text)] == "<|endoftext|>":
+                response_text = response_text[:-13]
             else:
                 response_text = model_output.strip()
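
A minimal sketch of what the new `sus_template` produces, rendered with jinja2 directly (the same templating engine `tokenizer.apply_chat_template` uses) so the prompt format can be previewed without downloading the 34B model. The example `messages` list is hypothetical; only the template string is taken from the patch.

```python
# Sketch: preview the SUS-Chat-34B prompt format produced by sus_template.
# The messages list below is a made-up example, not part of clembench.
from jinja2 import Template

sus_template = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}{{ '### Human: ' + message['content'] + '\\n\\n' }}"
    "{% elif message['role'] == 'assistant' %}{{ '### Assistant: ' + message['content'] }}"
    "{% endif %}"
    "{% if loop.last %}{{ '### Assistant: ' }}{% endif %}"
    "{% endfor %}"
)

messages = [
    {"role": "user", "content": "What is the capital of France?"},
]

print(Template(sus_template).render(messages=messages))
# Should print roughly:
# ### Human: What is the capital of France?
#
# ### Assistant:
```

In the backend itself this rendering happens inside `self.tokenizer.apply_chat_template(current_messages, return_tensors="pt")`; the trailing `### Assistant: ` acts as the generation prompt, and the model's `<|endoftext|>` EOS token (13 characters) is stripped from the decoded output by the new check at the end of the patch.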