提交 85c52092 创建 作者: 宋海霞's avatar 宋海霞

init

上级
.vscode
*.pyc
output
model/*
data/*
*.om
*.pb
*.zip
*.tar.gz
*.png
*.jpg
*.mp4
*.wmv
*.pdf
*.mp3
.ipynb_checkpoints
.keep
*.bin
kernel_meta/
aclinit.json
*.whl
*.so
*.swp
*.lock
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "ac73ef47-ecee-4293-8696-69513eb76495",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/miniconda3/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
" setattr(self, word, getattr(machar, word).flat[0])\n",
"/usr/local/miniconda3/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
" return self._float_to_str(self.smallest_subnormal)\n",
"/usr/local/miniconda3/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
" setattr(self, word, getattr(machar, word).flat[0])\n",
"/usr/local/miniconda3/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
" return self._float_to_str(self.smallest_subnormal)\n"
]
}
],
"source": [
"import gradio as gr\n",
"import os\n",
"import mindspore\n",
"from mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer\n",
"from mindnlp.transformers import TextIteratorStreamer\n",
"from threading import Thread"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3b627d06-c502-4bfc-907b-79f59f499861",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Qwen2ForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`.`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.\n",
" - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).\n",
" - If you are not the owner of the model architecture class, please contact the model code owner to update it.\n",
"Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n"
]
}
],
"source": [
"SRC_PATH = os.getcwd()\n",
"MODEL_PATH = os.path.join(SRC_PATH, \"./model/MindSpore-Lab/DeepSeek-R1-Distill-Qwen-1.5B\")\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, mirror=\"modelers\", ms_dtype=mindspore.float16)\n",
"model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, mirror=\"modelers\", ms_dtype=mindspore.float16)\n",
"\n",
"system_prompt = \"You are a helpful and friendly chatbot\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "bfefb17a-8cb1-4333-ae8d-ab5f495c1d9d",
"metadata": {},
"outputs": [],
"source": [
"def build_input_from_chat_history(chat_history, msg: str):\n",
" messages = [{'role': 'system', 'content': system_prompt}]\n",
" for user_msg, ai_msg in chat_history:\n",
" messages.append({'role': 'user', 'content': user_msg})\n",
" messages.append({'role': 'assistant', 'content': ai_msg})\n",
" messages.append({'role': 'user', 'content': msg})\n",
" return messages"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "5b75b073-e251-4b7c-bc60-fc2b4fc931ae",
"metadata": {},
"outputs": [],
"source": [
"def predict(message, history):\n",
" \n",
"\n",
" # Formatting the input for the model.\n",
" messages = build_input_from_chat_history(history, message)\n",
" input_ids = tokenizer.apply_chat_template(\n",
" messages,\n",
" add_generation_prompt=True,\n",
" return_tensors=\"ms\",\n",
" tokenize=True\n",
" )\n",
" streamer = TextIteratorStreamer(tokenizer, timeout=300, skip_prompt=True, skip_special_tokens=True)\n",
" generate_kwargs = dict(\n",
" input_ids=input_ids,\n",
" streamer=streamer,\n",
" max_new_tokens=1024,\n",
" do_sample=True,\n",
" top_p=0.9,\n",
" temperature=0.1,\n",
" num_beams=1,\n",
" repetition_penalty=1.2\n",
" )\n",
" t = Thread(target=model.generate, kwargs=generate_kwargs)\n",
" t.start() # Starting the generation in a separate thread.\n",
" partial_message = \"\"\n",
" for new_token in streamer:\n",
" partial_message += new_token\n",
" if '</s>' in partial_message: # Breaking the loop if the stop token is generated.\n",
" break\n",
" yield partial_message"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "689f36be-c14b-495e-9315-80c7b09773b7",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/miniconda3/lib/python3.9/site-packages/gradio/analytics.py:106: UserWarning: IMPORTANT: You are using gradio version 4.44.0, however version 4.44.1 is available, please upgrade. \n",
"--------\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running on local URL: http://192.168.110.83:7862\n",
"\n",
"To create a public link, set `share=True` in `launch()`.\n"
]
},
{
"data": {
"text/html": [
"<div><iframe src=\"http://192.168.110.83:7862/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": []
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gr.ChatInterface(predict,\n",
" title=\"DeepSeek-R1-Distill-Qwen-1.5B\",\n",
" description=\"问几个问题\",\n",
" examples=['你是谁?', '你能做什么?']\n",
" ).launch(server_name=\"192.168.110.83\", server_port=7862) # Launching the web interface."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论