pip install "tensorflow-text==2.8.*"
import collections
import pathlib
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras import utils
from tensorflow.keras.layers import TextVectorization
import tensorflow_datasets as tfds
import tensorflow_text as tf_text
Example 1: Predict the tag for a Stack Overflow question
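As a first example, you will download a dataset of programming questions from Stack Overflow. Each question ("How do I sort a dictionary by value?") is labeled with exactly one tag (Python, CSharp, JavaScript, or Java). Your task is to develop a model that predicts the tag of a question.
First, download the Stack Overflow dataset using tf.keras.utils.get_file, and explore the directory structure: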
data_url = 'https://storage.googleapis.com/download.tensorflow.org/data/stack_overflow_16k.tar.gz'
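dataset_dir = utils.get_file(
    origin=data_url,
    untar=True,
    cache_dir='stack_overflow',
    cache_subdir='')
dataset_dir = pathlib.Path(dataset_dir).parent
list(dataset_dir.iterdir())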
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/stack_overflow_16k.tar.gz 6053888/6053168 [==============================] - 0s 0us/step 6062080/6053168 [==============================] - 0s 0us/step
[PosixPath('/tmp/.keras/train'), PosixPath('/tmp/.keras/stack_overflow_16k.tar.gz'), PosixPath('/tmp/.keras/test'), PosixPath('/tmp/.keras/README.md')]
train_dir = dataset_dir/'train'
list(train_dir.iterdir())
[PosixPath('/tmp/.keras/train/python'), PosixPath('/tmp/.keras/train/javascript'), PosixPath('/tmp/.keras/train/csharp'), PosixPath('/tmp/.keras/train/java')]
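The train/csharp, train/java, train/python and train/javascript directories contain many text files, each of which is a Stack Overflow question.
Print an example file and inspect the data: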
sample_file = train_dir/'python/1755.txt'
with open(sample_file) as f:
  print(f.read())
why does this blank program print true x=true.def stupid():. x=false.stupid().print x
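Next, you will load the data off disk and prepare it into a format suitable for training. To do so, you will use the tf.keras.utils.text_dataset_from_directory utility to create a labeled tf.data.Dataset. If you're new to tf.data, it's a powerful collection of tools for building input pipelines. (Learn more in the tf.data: Build TensorFlow input pipelines guide.)
The tf.keras.utils.text_dataset_from_directory API expects a directory structure with one subdirectory per class, as follows:
train/
  csharp/
    1.txt
    2.txt
  java/
    1.txt
    2.txt
  javascript/
    1.txt
    2.txt
  python/
    1.txt
    2.txt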
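The following minimal pure-Python sketch shows how labels follow from this layout. It is not the Keras implementation. It assumes, as the class_names output later in this section suggests, that class names are the subdirectory names sorted alphanumerically. Each file's integer label is then the position of its directory in that sorted list. The helper `infer_labels` and the temporary files are hypothetical.

```python
import pathlib
import tempfile


def infer_labels(root):
    """Map each .txt file under `root` to (text, integer label).

    Class names are the subdirectory names, sorted alphanumerically;
    a file's label is the index of its parent directory in that list.
    """
    root = pathlib.Path(root)
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    examples = []
    for label, name in enumerate(class_names):
        for path in sorted((root / name).glob("*.txt")):
            examples.append((path.read_text(), label))
    return class_names, examples


# Build a tiny layout that mirrors train/<class>/<file>.txt.
with tempfile.TemporaryDirectory() as tmp:
    for cls in ["python", "java", "csharp", "javascript"]:
        (pathlib.Path(tmp) / cls).mkdir()
        (pathlib.Path(tmp) / cls / "1.txt").write_text(f"a {cls} question")
    class_names, examples = infer_labels(tmp)

print(class_names)  # ['csharp', 'java', 'javascript', 'python']
print(examples[3])  # ('a python question', 3)
```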
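The Stack Overflow dataset has already been divided into training and test sets, but it lacks a validation set.
Create a validation set using an 80:20 split of the training data. To do this, use tf.keras.utils.text_dataset_from_directory with validation_split set to 0.2 (that is, 20%):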
batch_size = 32
seed = 42
raw_train_ds = utils.text_dataset_from_directory(
    train_dir,
    batch_size=batch_size,
    validation_split=0.2,
    subset='training',
    seed=seed)
Found 8000 files belonging to 4 classes. Using 6400 files for training.
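As the previous cell output shows, there are 8,000 examples in the training folder, of which you will use 80% (or 6,400) for training. As you will see shortly, you can train a model by passing a tf.data.Dataset directly to Model.fit.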
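The split sizes follow from simple arithmetic. The sketch below makes two assumptions. First, the validation count is `int(validation_split * num_files)`. Second, a final partial batch, if any, counts as a step. With these numbers it reproduces the 6,400/1,600 split, the 200 steps per epoch seen in the training logs, and the 250 evaluation steps on the 8,000 test files.

```python
import math

num_files = 8000          # "Found 8000 files belonging to 4 classes."
validation_split = 0.2
batch_size = 32

num_val = int(validation_split * num_files)
num_train = num_files - num_val
train_steps = math.ceil(num_train / batch_size)
val_steps = math.ceil(num_val / batch_size)
test_steps = math.ceil(num_files / batch_size)  # the test folder also has 8000 files

print(num_train, num_val)                   # 6400 1600
print(train_steps, val_steps, test_steps)   # 200 50 250
```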
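Note: To increase the difficulty of the classification problem, the dataset author replaced occurrences of the words Python, CSharp, JavaScript, or Java in the programming questions with the word blank.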
for text_batch, label_batch in raw_train_ds.take(1):
  for i in range(10):
    print("Question: ", text_batch.numpy()[i])
    print("Label:", label_batch.numpy()[i])
Question: b'"my tester is going to the wrong constructor i am new to programming so if i ask a question that can be easily fixed, please forgive me. my program has a tester class with a main. when i send that to my regularpolygon class, it sends it to the wrong constructor. i have two constructors. 1 without perameters..public regularpolygon(). {. mynumsides = 5;. mysidelength = 30;. }//end default constructor...and my second, with perameters. ..public regularpolygon(int numsides, double sidelength). {. mynumsides = numsides;. mysidelength = sidelength;. }// end constructor...in my tester class i have these two lines:..regularpolygon shape = new regularpolygon(numsides, sidelength);. shape.menu();...numsides and sidelength were declared and initialized earlier in the testing class...so what i want to happen, is the tester class sends numsides and sidelength to the second constructor and use it in that class. but it only uses the default constructor, which therefor ruins the whole rest of the program. can somebody help me?..for those of you who want to see more of my code: here you go..public double vertexangle(). {. system.out.println(""the vertex angle method: "" + mynumsides);// prints out 5. system.out.println(""the vertex angle method: "" + mysidelength); // prints out 30.. double vertexangle;. vertexangle = ((mynumsides - 2.0) / mynumsides) * 180.0;. return vertexangle;. }//end method vertexangle..public void menu().{. system.out.println(mynumsides); // prints out what the user puts in. system.out.println(mysidelength); // prints out what the user puts in. gotographic();. calcr(mynumsides, mysidelength);. calcr(mynumsides, mysidelength);. print(); .}// end menu...this is my entire tester class:..public static void main(string[] arg).{. int numsides;. double sidelength;. scanner keyboard = new scanner(system.in);.. system.out.println(""welcome to the regular polygon program!"");. system.out.println();.. 
system.out.print(""enter the number of sides of the polygon ==> "");. numsides = keyboard.nextint();. system.out.println();.. system.out.print(""enter the side length of each side ==> "");. sidelength = keyboard.nextdouble();. system.out.println();.. regularpolygon shape = new regularpolygon(numsides, sidelength);. shape.menu();.}//end main...for testing it i sent it numsides 4 and sidelength 100."\n' Label: 1 Question: b'"blank code slow skin detection this code changes the color space to lab and using a threshold finds the skin area of an image. but it\'s ridiculously slow. i don\'t know how to make it faster ? ..from colormath.color_objects import *..def skindetection(img, treshold=80, color=[255,20,147]):.. print img.shape. res=img.copy(). for x in range(img.shape[0]):. for y in range(img.shape[1]):. rgbimg=rgbcolor(img[x,y,0],img[x,y,1],img[x,y,2]). labimg=rgbimg.convert_to(\'lab\', debug=false). if (labimg.lab_l > treshold):. res[x,y,:]=color. else: . res[x,y,:]=img[x,y,:].. return res"\n' Label: 3 Question: b'"option and validation in blank i want to add a new option on my system where i want to add two text files, both rental.txt and customer.txt. inside each text are id numbers of the customer, the videotape they need and the price...i want to place it as an option on my code. right now i have:...add customer.rent return.view list.search.exit...i want to add this as my sixth option. say for example i ordered a video, it would display the price and would let me confirm the price and if i am going to buy it or not...here is my current code:.. import blank.io.*;. import blank.util.arraylist;. import static blank.lang.system.out;.. public class rentalsystem{. static bufferedreader input = new bufferedreader(new inputstreamreader(system.in));. static file file = new file(""file.txt"");. static arraylist<string> list = new arraylist<string>();. static int rows;.. public static void main(string[] args) throws exception{. introduction();. 
system.out.print(""nn"");. login();. system.out.print(""nnnnnnnnnnnnnnnnnnnnnn"");. introduction();. string repeat;. do{. loadfile();. system.out.print(""nwhat do you want to do?nn"");. system.out.print(""n - - - - - - - - - - - - - - - - - - - - - - -"");. system.out.print(""nn | 1. add customer | 2. rent return |n"");. system.out.print(""n - - - - - - - - - - - - - - - - - - - - - - -"");. system.out.print(""nn | 3. view list | 4. search |n"");. system.out.print(""n - - - - - - - - - - - - - - - - - - - - - - -"");. system.out.print(""nn | 5. exit |n"");. system.out.print(""n - - - - - - - - - -"");. system.out.print(""nnchoice:"");. int choice = integer.parseint(input.readline());. switch(choice){. case 1:. writedata();. break;. case 2:. rentdata();. break;. case 3:. viewlist();. break;. case 4:. search();. break;. case 5:. system.out.println(""goodbye!"");. system.exit(0);. default:. system.out.print(""invalid choice: "");. break;. }. system.out.print(""ndo another task? [y/n] "");. repeat = input.readline();. }while(repeat.equals(""y""));.. if(repeat!=""y"") system.out.println(""ngoodbye!"");.. }.. public static void writedata() throws exception{. system.out.print(""nname: "");. string cname = input.readline();. system.out.print(""address: "");. string add = input.readline();. system.out.print(""phone no.: "");. string pno = input.readline();. system.out.print(""rental amount: "");. string ramount = input.readline();. system.out.print(""tapenumber: "");. string tno = input.readline();. system.out.print(""title: "");. string title = input.readline();. system.out.print(""date borrowed: "");. string dborrowed = input.readline();. system.out.print(""due date: "");. string ddate = input.readline();. createline(cname, add, pno, ramount,tno, title, dborrowed, ddate);. rentdata();. }.. public static void createline(string name, string address, string phone , string rental, string tapenumber, string title, string borrowed, string due) throws exception{. 
filewriter fw = new filewriter(file, true);. fw.write(""nname: ""+name + ""naddress: "" + address +""nphone no.: ""+ phone+""nrentalamount: ""+rental+""ntape no.: ""+ tapenumber+""ntitle: ""+ title+""ndate borrowed: ""+borrowed +""ndue date: ""+ due+"":rn"");. fw.close();. }.. public static void loadfile() throws exception{. try{. list.clear();. fileinputstream fstream = new fileinputstream(file);. bufferedreader br = new bufferedreader(new inputstreamreader(fstream));. rows = 0;. while( br.ready()). {. list.add(br.readline());. rows++;. }. br.close();. } catch(exception e){. system.out.println(""list not yet loaded."");. }. }.. public static void viewlist(){. system.out.print(""n~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");. system.out.print("" |list of all costumers|"");. system.out.print(""~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");. for(int i = 0; i <rows; i++){. system.out.println(list.get(i));. }. }. public static void rentdata()throws exception. { system.out.print(""n~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");. system.out.print("" |rent data list|"");. system.out.print(""~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");. system.out.print(""nenter customer name: "");. string cname = input.readline();. system.out.print(""date borrowed: "");. string dborrowed = input.readline();. system.out.print(""due date: "");. string ddate = input.readline();. system.out.print(""return date: "");. string rdate = input.readline();. system.out.print(""rent amount: "");. string ramount = input.readline();.. system.out.print(""you pay:""+ramount);... }. public static void search()throws exception. { system.out.print(""n~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");. system.out.print("" |search costumers|"");. system.out.print(""~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");. system.out.print(""nenter costumer name: "");. string cname = input.readline();. boolean found = false;.. for(int i=0; i < rows; i++){. string temp[] = list.get(i).split("","");.. if(cname.equals(temp[0])){. 
system.out.println(""search result:nyou are "" + temp[0] + "" from "" + temp[1] + "".""+ temp[2] + "".""+ temp[3] + "".""+ temp[4] + "".""+ temp[5] + "" is "" + temp[6] + "".""+ temp[7] + "" is "" + temp[8] + ""."");. found = true;. }. }.. if(!found){. system.out.print(""no results."");. }.. }.. public static boolean evaluate(string uname, string pass){. if (uname.equals(""admin"")&&pass.equals(""12345"")) return true;. else return false;. }.. public static string login()throws exception{. bufferedreader input=new bufferedreader(new inputstreamreader(system.in));. int counter=0;. do{. system.out.print(""username:"");. string uname =input.readline();. system.out.print(""password:"");. string pass =input.readline();.. boolean accept= evaluate(uname,pass);.. if(accept){. break;. }else{. system.out.println(""incorrect username or password!"");. counter ++;. }. }while(counter<3);.. if(counter !=3) return ""login successful"";. else return ""login failed"";. }. public static void introduction() throws exception{.. system.out.println("" - - - - - - - - - - - - - - - - - - - - - - - - -"");. system.out.println("" ! r e n t a l !"");. system.out.println("" ! ~ ~ ~ ~ ~ ! ================= ! ~ ~ ~ ~ ~ !"");. system.out.println("" ! s y s t e m !"");. system.out.println("" - - - - - - - - - - - - - - - - - - - - - - - - -"");. }..}"\n' Label: 1 Question: b'"exception: dynamic sql generation for the updatecommand is not supported against a selectcommand that does not return any key i dont know what is the problem this my code : ..string nomtable;..datatable listeetablissementtable = new datatable();.datatable listeinteretstable = new datatable();.dataset ds = new dataset();.sqldataadapter da;.sqlcommandbuilder cmdb;..private void listeinterets_click(object sender, eventargs e).{. nomtable = ""listeinteretstable"";. d.cnx.open();. da = new sqldataadapter(""select nome from offices"", d.cnx);. ds = new dataset();. da.fill(ds, nomtable);. 
datagridview1.datasource = ds.tables[nomtable];.}..private void sauvgarder_click(object sender, eventargs e).{. d.cnx.open();. cmdb = new sqlcommandbuilder(da);. da.update(ds, nomtable);. d.cnx.close();.}"\n' Label: 0 Question: b'"parameter with question mark and super in blank, i\'ve come across a method that is formatted like this:..public final subscription subscribe(final action1<? super t> onnext, final action1<throwable> onerror) {.}...in the first parameter, what does the question mark and super mean?"\n' Label: 1 Question: b'call two objects wsdl the first time i got a very strange wsdl. ..i would like to call the object (interface - invoicecheck_out) do you know how?....i would like to call the object (variable) do you know how?..try to call (it`s ok)....try to call (how call this?)\n' Label: 0 Question: b"how to correctly make the icon for systemtray in blank using icon sizes of any dimension for systemtray doesn't look good overall. .what is the correct way of making icons for windows system tray?..screenshots: http://imgur.com/zsibwn9..icon: http://imgur.com/vsh4zo8\n" Label: 0 Question: b'"is there a way to check a variable that exists in a different script than the original one? i\'m trying to check if a variable, which was previously set to true in 2.py in 1.py, as 1.py is only supposed to continue if the variable is true...2.py..import os..completed = false..#some stuff here..completed = true...1.py..import 2 ..if completed == true. #do things...however i get a syntax error at ..if completed == true"\n' Label: 3 Question: b'"blank control flow i made a number which asks for 2 numbers with blank and responds with the corresponding message for the case. how come it doesnt work for the second number ? .regardless what i enter for the second number , i am getting the message ""your number is in the range 0-10""...using system;.using system.collections.generic;.using system.linq;.using system.text;..namespace consoleapplication1.{. class program. {. 
static void main(string[] args). {. string myinput; // declaring the type of the variables. int myint;.. string number1;. int number;... console.writeline(""enter a number"");. myinput = console.readline(); //muyinput is a string which is entry input. myint = int32.parse(myinput); // myint converts the string into an integer.. if (myint > 0). console.writeline(""your number {0} is greater than zero."", myint);. else if (myint < 0). console.writeline(""your number {0} is less than zero."", myint);. else. console.writeline(""your number {0} is equal zero."", myint);.. console.writeline(""enter another number"");. number1 = console.readline(); . number = int32.parse(myinput); .. if (number < 0 || number == 0). console.writeline(""your number {0} is less than zero or equal zero."", number);. else if (number > 0 && number <= 10). console.writeline(""your number {0} is in the range from 0 to 10."", number);. else. console.writeline(""your number {0} is greater than 10."", number);.. console.writeline(""enter another number"");.. }. } .}"\n' Label: 0 Question: b'"credentials cannot be used for ntlm authentication i am getting org.apache.commons.httpclient.auth.invalidcredentialsexception: credentials cannot be used for ntlm authentication: exception in eclipse..whether it is possible mention eclipse to take system proxy settings directly?..public class httpgetproxy {. private static final string proxy_host = ""proxy.****.com"";. private static final int proxy_port = 6050;.. public static void main(string[] args) {. httpclient client = new httpclient();. httpmethod method = new getmethod(""https://kodeblank.org"");.. hostconfiguration config = client.gethostconfiguration();. config.setproxy(proxy_host, proxy_port);.. string username = ""*****"";. string password = ""*****"";. credentials credentials = new usernamepasswordcredentials(username, password);. authscope authscope = new authscope(proxy_host, proxy_port);.. 
client.getstate().setproxycredentials(authscope, credentials);.. try {. client.executemethod(method);.. if (method.getstatuscode() == httpstatus.sc_ok) {. string response = method.getresponsebodyasstring();. system.out.println(""response = "" + response);. }. } catch (ioexception e) {. e.printstacktrace();. } finally {. method.releaseconnection();. }. }.}...exception:... dec 08, 2017 1:41:39 pm . org.apache.commons.httpclient.auth.authchallengeprocessor selectauthscheme. info: ntlm authentication scheme selected. dec 08, 2017 1:41:39 pm org.apache.commons.httpclient.httpmethoddirector executeconnect. severe: credentials cannot be used for ntlm authentication: . org.apache.commons.httpclient.usernamepasswordcredentials. org.apache.commons.httpclient.auth.invalidcredentialsexception: credentials . cannot be used for ntlm authentication: . enter code here . org.apache.commons.httpclient.usernamepasswordcredentials. at org.apache.commons.httpclient.auth.ntlmscheme.authenticate(ntlmscheme.blank:332). at org.apache.commons.httpclient.httpmethoddirector.authenticateproxy(httpmethoddirector.blank:320). at org.apache.commons.httpclient.httpmethoddirector.executeconnect(httpmethoddirector.blank:491). at org.apache.commons.httpclient.httpmethoddirector.executewithretry(httpmethoddirector.blank:391). at org.apache.commons.httpclient.httpmethoddirector.executemethod(httpmethoddirector.blank:171). at org.apache.commons.httpclient.httpclient.executemethod(httpclient.blank:397). at org.apache.commons.httpclient.httpclient.executemethod(httpclient.blank:323). at httpgetproxy.main(httpgetproxy.blank:31). dec 08, 2017 1:41:39 pm org.apache.commons.httpclient.httpmethoddirector processproxyauthchallenge. info: failure authenticating with ntlm @proxy.****.com:6050"\n' Label: 1
The labels are 0, 1, 2 or 3. To see which string label each one corresponds to, check the class_names property on the dataset:
for i, label in enumerate(raw_train_ds.class_names):
  print("Label", i, "corresponds to", label)
Label 0 corresponds to csharp Label 1 corresponds to java Label 2 corresponds to javascript Label 3 corresponds to python
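Next, you will create the validation and test sets using tf.keras.utils.text_dataset_from_directory. You will use the remaining 1,600 questions from the training set for validation.
Note: When using the validation_split and subset arguments of tf.keras.utils.text_dataset_from_directory, make sure to either specify a random seed or pass shuffle=False, so that the validation and training splits do not overlap.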
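A small sketch shows why the seed matters. For illustration only, it assumes that each call shuffles the file list with the given seed. It then takes the first 80% for training and the last 20% for validation. The function `split_files` is hypothetical. Both calls use the same seed, so they shuffle identically and the subsets are disjoint. With different seeds, or independent shuffling, the two subsets could overlap.

```python
import random


def split_files(files, validation_split, subset, seed):
    """Shuffle `files` with `seed`, then return the training or validation part."""
    files = list(files)
    random.Random(seed).shuffle(files)
    cut = len(files) - int(validation_split * len(files))
    if subset == "training":
        return files[:cut]
    if subset == "validation":
        return files[cut:]
    raise ValueError(f"unknown subset: {subset!r}")


files = [f"{i}.txt" for i in range(10)]
train = split_files(files, 0.2, "training", seed=42)
val = split_files(files, 0.2, "validation", seed=42)
print(len(train), len(val))        # 8 2
print(set(train).isdisjoint(val))  # True
```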
# Create a validation set.
raw_val_ds = utils.text_dataset_from_directory(
    train_dir,
    batch_size=batch_size,
    validation_split=0.2,
    subset='validation',
    seed=seed)
Found 8000 files belonging to 4 classes. Using 1600 files for validation.
test_dir = dataset_dir/'test'
# Create a test set.
raw_test_ds = utils.text_dataset_from_directory(
    test_dir,
    batch_size=batch_size)
Found 8000 files belonging to 4 classes.
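Next, you will standardize, tokenize, and vectorize the data using the tf.keras.layers.TextVectorization layer.
- Standardization refers to preprocessing the text, typically to remove punctuation or HTML elements to simplify the dataset.
- Tokenization refers to splitting strings into tokens (for example, splitting a sentence into individual words by splitting on whitespace).
- Vectorization refers to converting tokens into numbers so they can be fed into a neural network.
All of these tasks can be accomplished with this layer. (You can learn more about each of them in the tf.keras.layers.TextVectorization API docs.) Note that:
- The default standardization converts text to lowercase and removes punctuation (standardize='lower_and_strip_punctuation').
- The default tokenizer splits on whitespace (split='whitespace').
- The default vectorization mode is 'int' (output_mode='int'), which outputs one integer index per token.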
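The sketch below gives a rough illustration of these defaults. It lowercases text, deletes the ASCII punctuation characters in `string.punctuation`, and splits on whitespace. It approximates TextVectorization's built-in behavior but is not guaranteed to match it, for example on non-ASCII punctuation. The function names are hypothetical.

```python
import string

# Approximation of standardize='lower_and_strip_punctuation':
# lowercase, then delete ASCII punctuation characters.
_STRIP_PUNCT = str.maketrans("", "", string.punctuation)


def standardize(text):
    return text.lower().translate(_STRIP_PUNCT)


def split_whitespace(text):
    # Approximation of split='whitespace'.
    return text.split()


tokens = split_whitespace(standardize("How do I sort a dictionary by value?"))
print(tokens)  # ['how', 'do', 'i', 'sort', 'a', 'dictionary', 'by', 'value']
```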
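You will use TextVectorization in two modes:
- First, you will use the 'binary' vectorization mode to build a bag-of-words model.
- Then, you will use the 'int' mode with a 1D ConvNet.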
VOCAB_SIZE = 10000
binary_vectorize_layer = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode='binary')
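For the 'int' mode, in addition to the maximum vocabulary size, you need to set an explicit maximum sequence length (MAX_SEQUENCE_LENGTH). This causes the layer to pad or truncate sequences to exactly output_sequence_length values:
MAX_SEQUENCE_LENGTH = 250
int_vectorize_layer = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode='int',
    output_sequence_length=MAX_SEQUENCE_LENGTH)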
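The padding and truncation rule can be sketched in a few lines. The sketch assumes that short sequences are right-padded with the padding id 0 and that long sequences keep their first tokens. This is consistent with the trailing zeros in the 'int' output shown later. `pad_or_truncate` is a hypothetical helper.

```python
def pad_or_truncate(ids, length, pad_id=0):
    """Return exactly `length` ids: keep the first ones, right-pad short inputs."""
    return (list(ids) + [pad_id] * length)[:length]


print(pad_or_truncate([55, 6, 2], 5))          # [55, 6, 2, 0, 0]
print(pad_or_truncate([1, 2, 3, 4, 5, 6], 5))  # [1, 2, 3, 4, 5]
```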
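Next, call TextVectorization.adapt to fit the state of the preprocessing layer to the dataset. This causes the layer to build an index of strings to integers.
Note: When calling TextVectorization.adapt, use only your training data. Using the test set would leak information.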
# Make a text-only dataset (without labels), then call `TextVectorization.adapt`.
train_text = raw_train_ds.map(lambda text, labels: text)
binary_vectorize_layer.adapt(train_text)
int_vectorize_layer.adapt(train_text)
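Conceptually, adapting an 'int' vectorizer means two things: counting tokens in the training texts, then assigning integer ids by frequency. The sketch below assumes index 0 is the padding token '' and index 1 is the OOV token '[UNK]'. The remaining slots go to the most frequent tokens. Tie-breaking and standardization in the real layer may differ. `build_vocabulary` and `encode` are hypothetical helpers. Only the training texts are used to build the vocabulary.

```python
from collections import Counter


def build_vocabulary(texts, max_tokens):
    """Sketch of vocabulary building for an 'int' vectorizer.

    Index 0 is the padding token '' and index 1 the OOV token '[UNK]';
    the remaining max_tokens - 2 slots go to the most frequent tokens.
    """
    counts = Counter(tok for text in texts for tok in text.lower().split())
    most_common = [tok for tok, _ in counts.most_common(max_tokens - 2)]
    return ["", "[UNK]"] + most_common


def encode(text, vocab):
    """Map each token to its vocabulary index, using 1 for unknown tokens."""
    index = {tok: i for i, tok in enumerate(vocab)}
    return [index.get(tok, 1) for tok in text.lower().split()]


train_texts = ["sort a dict", "sort a list", "a tuple"]
vocab = build_vocabulary(train_texts, max_tokens=4)
print(vocab)                          # ['', '[UNK]', 'a', 'sort']
print(encode("sort a tuple", vocab))  # [3, 2, 1]
```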
def binary_vectorize_text(text, label):
  text = tf.expand_dims(text, -1)
  return binary_vectorize_layer(text), label

def int_vectorize_text(text, label):
  text = tf.expand_dims(text, -1)
  return int_vectorize_layer(text), label
# Retrieve a batch (of 32 questions and labels) from the dataset.
text_batch, label_batch = next(iter(raw_train_ds))
first_question, first_label = text_batch[0], label_batch[0]
print("Question", first_question)
print("Label", first_label)
Question tf.Tensor(b'"what is the difference between these two ways to create an element? var a = document.createelement(\'div\');..a.id = ""mydiv"";...and..var a = document.createelement(\'div\').id = ""mydiv"";...what is the difference between them such that the first one works and the second one doesn\'t?"\n', shape=(), dtype=string) Label tf.Tensor(2, shape=(), dtype=int32)
print("'binary' vectorized question:",
binary_vectorize_text(first_question, first_label)[0])
'binary' vectorized question: tf.Tensor([[1. 1. 0. ... 0. 0. 0.]], shape=(1, 10000), dtype=float32)
print("'int' vectorized question:",
int_vectorize_text(first_question, first_label)[0])
'int' vectorized question: tf.Tensor( [[ 55 6 2 410 211 229 121 895 4 124 32 245 43 5 1 1 5 1 1 6 2 410 211 191 318 14 2 98 71 188 8 2 199 71 178 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]], shape=(1, 250), dtype=int64)
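As shown above, TextVectorization's 'binary' mode returns an array denoting which tokens exist at least once in the input. The 'int' mode, by contrast, replaces each token with an integer, so it preserves their order.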
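The difference between the two encodings can be sketched with plain lists of token ids. The ids here are arbitrary, and `multi_hot` is a hypothetical helper. A multi-hot vector only records presence, so repeated tokens and word order are lost. The 'int' sequence keeps both.

```python
def multi_hot(token_ids, vocab_size):
    """'binary'-style output: 1.0 at every id that occurs at least once."""
    vector = [0.0] * vocab_size
    for i in token_ids:
        vector[i] = 1.0
    return vector


ids = [3, 2, 3, 4]            # an 'int'-style sequence: order and repeats kept
print(multi_hot(ids, 5))      # [0.0, 0.0, 1.0, 1.0, 1.0]
print(multi_hot([4, 3, 2], 5) == multi_hot(ids, 5))  # True: order is lost
```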
You can look up the token (string) that each integer corresponds to by calling TextVectorization.get_vocabulary on the layer:
print("1289 ---> ", int_vectorize_layer.get_vocabulary()[1289])
print("313 ---> ", int_vectorize_layer.get_vocabulary()[313])
print("Vocabulary size: {}".format(len(int_vectorize_layer.get_vocabulary())))
1289 ---> roman 313 ---> source Vocabulary size: 10000
As the final preprocessing step, apply the TextVectorization layers you created earlier to the training, validation and test sets:
binary_train_ds = raw_train_ds.map(binary_vectorize_text)
binary_val_ds = raw_val_ds.map(binary_vectorize_text)
binary_test_ds = raw_test_ds.map(binary_vectorize_text)
int_train_ds = raw_train_ds.map(int_vectorize_text)
int_val_ds = raw_val_ds.map(int_vectorize_text)
int_test_ds = raw_test_ds.map(int_vectorize_text)
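Use the following two important methods when loading data to make sure that I/O does not become blocking.
- Dataset.cache keeps data in memory after it's loaded off disk. This ensures the dataset does not become a bottleneck while training your model. If your dataset is too large to fit into memory, you can also use this method to create a performant on-disk cache, which is more efficient to read than many small files.
- Dataset.prefetch overlaps data preprocessing and model execution while training.
You can learn more about both methods, as well as how to cache data to disk, in the Prefetching section of the Better performance with the tf.data API guide.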
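Dataset.cache amounts to keeping elements after the first pass. Dataset.prefetch is essentially a bounded producer/consumer buffer. The following pure-Python sketch illustrates the prefetch idea conceptually, using a background thread and `queue.Queue`. tf.data implements this natively and differently, and `prefetch` here is a hypothetical helper.

```python
import queue
import threading


def prefetch(iterable, buffer_size=2):
    """Yield items from `iterable` while a background thread produces ahead.

    Production of later items overlaps with consumption of earlier ones,
    and at most `buffer_size` produced items wait in the queue.
    """
    buffer = queue.Queue(maxsize=buffer_size)
    done = object()  # Sentinel marking the end of the input.

    def producer():
        try:
            for item in iterable:
                buffer.put(item)
        finally:
            buffer.put(done)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buffer.get()
        if item is done:
            return
        yield item


batches = list(prefetch(range(5)))
print(batches)  # [0, 1, 2, 3, 4]
```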
AUTOTUNE = tf.data.AUTOTUNE

def configure_dataset(dataset):
  return dataset.cache().prefetch(buffer_size=AUTOTUNE)
binary_train_ds = configure_dataset(binary_train_ds)
binary_val_ds = configure_dataset(binary_val_ds)
binary_test_ds = configure_dataset(binary_test_ds)
int_train_ds = configure_dataset(int_train_ds)
int_val_ds = configure_dataset(int_val_ds)
int_test_ds = configure_dataset(int_test_ds)
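For the 'binary' vectorized data, define a simple bag-of-words linear model, then configure and train it:
binary_model = tf.keras.Sequential([layers.Dense(4)])

binary_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])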
history = binary_model.fit(
binary_train_ds, validation_data=binary_val_ds, epochs=10)
Epoch 1/10 200/200 [==============================] - 1s 4ms/step - loss: 1.1207 - accuracy: 0.6450 - val_loss: 0.9159 - val_accuracy: 0.7750 Epoch 2/10 200/200 [==============================] - 1s 3ms/step - loss: 0.7795 - accuracy: 0.8163 - val_loss: 0.7511 - val_accuracy: 0.8006 Epoch 3/10 200/200 [==============================] - 1s 3ms/step - loss: 0.6278 - accuracy: 0.8619 - val_loss: 0.6652 - val_accuracy: 0.8169 Epoch 4/10 200/200 [==============================] - 1s 3ms/step - loss: 0.5341 - accuracy: 0.8869 - val_loss: 0.6117 - val_accuracy: 0.8294 Epoch 5/10 200/200 [==============================] - 1s 3ms/step - loss: 0.4680 - accuracy: 0.9041 - val_loss: 0.5751 - val_accuracy: 0.8369 Epoch 6/10 200/200 [==============================] - 1s 3ms/step - loss: 0.4177 - accuracy: 0.9181 - val_loss: 0.5485 - val_accuracy: 0.8413 Epoch 7/10 200/200 [==============================] - 1s 3ms/step - loss: 0.3775 - accuracy: 0.9292 - val_loss: 0.5284 - val_accuracy: 0.8400 Epoch 8/10 200/200 [==============================] - 1s 3ms/step - loss: 0.3442 - accuracy: 0.9373 - val_loss: 0.5129 - val_accuracy: 0.8425 Epoch 9/10 200/200 [==============================] - 1s 3ms/step - loss: 0.3160 - accuracy: 0.9425 - val_loss: 0.5007 - val_accuracy: 0.8444 Epoch 10/10 200/200 [==============================] - 1s 3ms/step - loss: 0.2917 - accuracy: 0.9497 - val_loss: 0.4910 - val_accuracy: 0.8419
Next, you will use the 'int' vectorized layer to build a 1D ConvNet:
def create_model(vocab_size, num_labels):
  model = tf.keras.Sequential([
      layers.Embedding(vocab_size, 64, mask_zero=True),
      layers.Conv1D(64, 5, padding="valid", activation="relu", strides=2),
      layers.GlobalMaxPooling1D(),
      layers.Dense(num_labels)
  ])
  return model
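The single-channel, single-filter sketch below illustrates what the Conv1D and GlobalMaxPooling1D layers compute. The real layer uses 64 filters over 64 embedding channels. The sketch applies a sliding dot product (cross-correlation, no kernel flip) with 'valid' padding, stride 2 and ReLU. It then takes the maximum over time. It ignores the mask produced by mask_zero=True. The output-length formula (L - k) // s + 1 gives 123 time steps for a 250-token input with kernel size 5 and stride 2. The helper names are hypothetical.

```python
def conv1d_valid(xs, kernel, stride=1, bias=0.0):
    """Single-channel 1D cross-correlation with 'valid' padding, then ReLU."""
    k = len(kernel)
    out = []
    for start in range(0, len(xs) - k + 1, stride):
        window = xs[start:start + k]
        z = sum(w * x for w, x in zip(kernel, window)) + bias
        out.append(max(0.0, z))  # activation="relu"
    return out


def global_max_pool(xs):
    """Keep the largest activation over the time axis."""
    return max(xs)


x = [1.0, 0.0, 2.0, 1.0, 3.0, 0.0, 1.0]
features = conv1d_valid(x, kernel=[1.0, -1.0, 1.0], stride=2)
print(features)                   # [3.0, 4.0, 4.0]
print(global_max_pool(features))  # 4.0

# Output length for a 250-token input, kernel_size=5, strides=2, 'valid' padding.
steps = (250 - 5) // 2 + 1
print(steps)  # 123
```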
# `vocab_size` is `VOCAB_SIZE + 1` since `0` is used additionally for padding.
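int_model = create_model(vocab_size=VOCAB_SIZE + 1, num_labels=4)
int_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])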
history = int_model.fit(int_train_ds, validation_data=int_val_ds, epochs=5)
Epoch 1/5 200/200 [==============================] - 2s 5ms/step - loss: 1.1306 - accuracy: 0.4970 - val_loss: 0.7709 - val_accuracy: 0.6919 Epoch 2/5 200/200 [==============================] - 1s 4ms/step - loss: 0.6226 - accuracy: 0.7592 - val_loss: 0.5376 - val_accuracy: 0.7969 Epoch 3/5 200/200 [==============================] - 1s 4ms/step - loss: 0.3735 - accuracy: 0.8863 - val_loss: 0.4675 - val_accuracy: 0.8163 Epoch 4/5 200/200 [==============================] - 1s 4ms/step - loss: 0.2061 - accuracy: 0.9503 - val_loss: 0.4679 - val_accuracy: 0.8163 Epoch 5/5 200/200 [==============================] - 1s 4ms/step - loss: 0.1014 - accuracy: 0.9820 - val_loss: 0.4937 - val_accuracy: 0.8119
print("Linear model on binary vectorized data:")
print(binary_model.summary())
Linear model on binary vectorized data: Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 4) 40004 ================================================================= Total params: 40,004 Trainable params: 40,004 Non-trainable params: 0 _________________________________________________________________ None
print("ConvNet model on int vectorized data:")
print(int_model.summary())
ConvNet model on int vectorized data: Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= embedding (Embedding) (None, None, 64) 640064 conv1d (Conv1D) (None, None, 64) 20544 global_max_pooling1d (Globa (None, 64) 0 lMaxPooling1D) dense_1 (Dense) (None, 4) 260 ================================================================= Total params: 660,868 Trainable params: 660,868 Non-trainable params: 0 _________________________________________________________________ None
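You can check the parameter counts in these summaries by hand, using three rules:
- A Dense layer has inputs * units weights plus one bias per unit.
- An Embedding has one 64-wide row per index.
- A Conv1D has kernel_size * input_channels * filters weights plus one bias per filter.

The sketch below redoes that arithmetic with the values used in this section.

```python
VOCAB_SIZE = 10000
EMBED_DIM = 64
FILTERS, KERNEL_SIZE = 64, 5
NUM_LABELS = 4

# Dense on the 10,000-wide multi-hot vector.
dense_binary = VOCAB_SIZE * NUM_LABELS + NUM_LABELS

# The Embedding table has VOCAB_SIZE + 1 rows, as passed to create_model.
embedding = (VOCAB_SIZE + 1) * EMBED_DIM
conv1d = KERNEL_SIZE * EMBED_DIM * FILTERS + FILTERS
dense_int = FILTERS * NUM_LABELS + NUM_LABELS

print(dense_binary)                    # 40004
print(embedding, conv1d, dense_int)    # 640064 20544 260
print(embedding + conv1d + dense_int)  # 660868
```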
binary_loss, binary_accuracy = binary_model.evaluate(binary_test_ds)
int_loss, int_accuracy = int_model.evaluate(int_test_ds)
print("Binary model accuracy: {:2.2%}".format(binary_accuracy))
print("Int model accuracy: {:2.2%}".format(int_accuracy))
250/250 [==============================] - 1s 3ms/step - loss: 0.5179 - accuracy: 0.8163 250/250 [==============================] - 1s 2ms/step - loss: 0.5309 - accuracy: 0.8070 Binary model accuracy: 81.63% Int model accuracy: 80.70%
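In the code above, you applied tf.keras.layers.TextVectorization to the dataset before feeding text to the model. If you want your model to process raw strings (for example, to simplify deployment), you can include the TextVectorization layer inside your model: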
export_model = tf.keras.Sequential(
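    [binary_vectorize_layer, binary_model])

export_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])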
# Test it with `raw_test_ds`, which yields raw strings
loss, accuracy = export_model.evaluate(raw_test_ds)
print("Accuracy: {:2.2%}".format(accuracy))
250/250 [==============================] - 1s 4ms/step - loss: 0.5179 - accuracy: 0.8163 Accuracy: 81.63%
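Now your model can take raw strings as input and predict a score for each label using Model.predict. Define a function to find the label with the maximum score: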
def get_string_labels(predicted_scores_batch):
  predicted_int_labels = tf.math.argmax(predicted_scores_batch, axis=1)
  predicted_labels = tf.gather(raw_train_ds.class_names, predicted_int_labels)
  return predicted_labels
inputs = [
    "how do I extract keys from a dict into a list?",  # 'python'
    "debug public static void main(string[] args) {...}",  # 'java'
]
predicted_scores = export_model.predict(inputs)
predicted_labels = get_string_labels(predicted_scores)
for question, label in zip(inputs, predicted_labels):
  print("Question: ", question)
  print("Predicted label: ", label.numpy())
Question: how do I extract keys from a dict into a list? Predicted label: b'python' Question: debug public static void main(string[] args) {...} Predicted label: b'java'
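When choosing where to apply the tf.keras.layers.TextVectorization layer, keep the performance differences in mind. Using it outside of your model enables asynchronous CPU processing and buffering of your data when training on a GPU. So, if you are training your model on a GPU, you probably want to use this option while developing the model to get the best performance. Then, when you are ready to prepare for deployment, switch to including the TextVectorization layer inside your model.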
Example 2: Predict the author of Iliad translations
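The following example uses tf.data.TextLineDataset to load examples from text files, and TensorFlow Text to preprocess the data. You will use three different English translations of the same work, Homer's Iliad, and train a model to identify the translator given a single line of text.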
DIRECTORY_URL = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'
FILE_NAMES = ['cowper.txt', 'derby.txt', 'butler.txt']
for name in FILE_NAMES:
  text_dir = utils.get_file(name, origin=DIRECTORY_URL + name)

parent_dir = pathlib.Path(text_dir).parent
list(parent_dir.iterdir())
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/cowper.txt 819200/815980 [==============================] - 0s 0us/step 827392/815980 [==============================] - 0s 0us/step Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/derby.txt 811008/809730 [==============================] - 0s 0us/step 819200/809730 [==============================] - 0s 0us/step Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/butler.txt 811008/807992 [==============================] - 0s 0us/step 819200/807992 [==============================] - 0s 0us/step [PosixPath('/home/kbuilder/.keras/datasets/kandinsky5.jpg'), PosixPath('/home/kbuilder/.keras/datasets/spa-eng.zip'), PosixPath('/home/kbuilder/.keras/datasets/fashion-mnist'), PosixPath('/home/kbuilder/.keras/datasets/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg'), PosixPath('/home/kbuilder/.keras/datasets/facades'), PosixPath('/home/kbuilder/.keras/datasets/butler.txt'), PosixPath('/home/kbuilder/.keras/datasets/flower_photos'), PosixPath('/home/kbuilder/.keras/datasets/train.csv'), PosixPath('/home/kbuilder/.keras/datasets/facades.tar.gz'), PosixPath('/home/kbuilder/.keras/datasets/derby.txt'), PosixPath('/home/kbuilder/.keras/datasets/spa-eng'), PosixPath('/home/kbuilder/.keras/datasets/320px-Felis_catus-cat_on_snow.jpg'), PosixPath('/home/kbuilder/.keras/datasets/HIGGS.csv.gz'), PosixPath('/home/kbuilder/.keras/datasets/flower_photos.tar.gz'), PosixPath('/home/kbuilder/.keras/datasets/mnist.npz'), PosixPath('/home/kbuilder/.keras/datasets/YellowLabradorLooking_new.jpg'), PosixPath('/home/kbuilder/.keras/datasets/shakespeare.txt'), PosixPath('/home/kbuilder/.keras/datasets/cifar-10-batches-py.tar.gz'), PosixPath('/home/kbuilder/.keras/datasets/cowper.txt'), PosixPath('/home/kbuilder/.keras/datasets/iris_test.csv'), PosixPath('/home/kbuilder/.keras/datasets/cifar-10-batches-py'), 
PosixPath('/home/kbuilder/.keras/datasets/cats_and_dogs_filtered'), PosixPath('/home/kbuilder/.keras/datasets/eval.csv'), PosixPath('/home/kbuilder/.keras/datasets/cats_and_dogs.zip'), PosixPath('/home/kbuilder/.keras/datasets/iris_training.csv')]
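Previously, with tf.keras.utils.text_dataset_from_directory, all the contents of a file were treated as a single example. Here, you will use tf.data.TextLineDataset instead. It is designed to create a tf.data.Dataset from a text file in which each example is a line of text from the original file.
Iterate through these files, loading each one into its own dataset. Each example needs to be labeled individually, so use Dataset.map to apply a labeler function to each one. This iterates over every example in the dataset and returns (example, label) pairs.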
def labeler(example, index):
  return example, tf.cast(index, tf.int64)

labeled_data_sets = []

for i, file_name in enumerate(FILE_NAMES):
  lines_dataset = tf.data.TextLineDataset(str(parent_dir/file_name))
  labeled_dataset = lines_dataset.map(lambda ex: labeler(ex, i))
  labeled_data_sets.append(labeled_dataset)
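The labeling step can be mimicked on plain Python lists. In this sketch, the hypothetical `label_lines` pairs each line with the index of the file it came from. It follows the same order in which FILE_NAMES is enumerated, so cowper.txt lines get label 0, derby.txt 1 and butler.txt 2. The file contents are made-up stand-ins.

```python
def label_lines(files):
    """Pair every line of every file with that file's index, in order."""
    labeled = []
    for index, lines in enumerate(files):
        labeled.extend((line, index) for line in lines)
    return labeled


# Made-up stand-ins for the lines of cowper.txt, derby.txt and butler.txt.
files = [["cowper line 1", "cowper line 2"], ["derby line 1"], ["butler line 1"]]
print(label_lines(files))
# [('cowper line 1', 0), ('cowper line 2', 0), ('derby line 1', 1), ('butler line 1', 2)]
```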
Next, combine these labeled datasets into a single dataset using Dataset.concatenate, and shuffle it with Dataset.shuffle:
BUFFER_SIZE = 50000

all_labeled_data = labeled_data_sets[0]
for labeled_dataset in labeled_data_sets[1:]:
  all_labeled_data = all_labeled_data.concatenate(labeled_dataset)

all_labeled_data = all_labeled_data.shuffle(
    BUFFER_SIZE, reshuffle_each_iteration=False)
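Dataset.shuffle does not shuffle the whole dataset at once. It fills a buffer of BUFFER_SIZE elements, emits random picks from that buffer, and refills it from the input. The sketch below is a pure-Python approximation of that idea; the seed handling and end-of-input behavior of tf.data may differ. It shows two extremes. A buffer at least as large as the dataset gives a full shuffle, while a buffer of 1 leaves the order unchanged. `buffered_shuffle` is hypothetical.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Approximate Dataset.shuffle: emit random picks from a bounded buffer."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer = []
    for item in items:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Emit a random buffered element; its slot is refilled next.
            yield buffer.pop(rng.randrange(len(buffer)))
    rng.shuffle(buffer)  # Drain what is left in random order.
    yield from buffer


data = list(range(10))
shuffled = list(buffered_shuffle(data, buffer_size=len(data), seed=0))
print(sorted(shuffled) == data)  # True: same elements, possibly reordered
print(list(buffered_shuffle(data, buffer_size=1)) == data)  # True: no mixing
```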
Print out a few examples as before. The dataset hasn't been batched yet, so each entry in all_labeled_data corresponds to one data point:
for text, label in all_labeled_data.take(10):
  print("Sentence: ", text.numpy())
  print("Label:", label.numpy())
Sentence: b'And like two lofty pines in death they lay.' Label: 1 Sentence: b'Men and robed matrons, who shall seek the Gods' Label: 0 Sentence: b'So vaunted he, but Juno with disdain' Label: 0 Sentence: b'First turning, slew the mighty Bathycles,' Label: 1 Sentence: b'flames, while the women and children are carried into captivity; when' Label: 2 Sentence: b'finished. When they had heaped up the barrow they went back again into' Label: 2 Sentence: b"He said; and from th' applauding ranks of Greece" Label: 1 Sentence: b'the fate of Hector; at length Minerva descends to the aid of Achilles.' Label: 1 Sentence: b'sit down and keep your eyes on the horses; they are speeding towards' Label: 2 Sentence: b'dearest of his friends has fallen. But I can see not a man among the' Label: 2
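Now, instead of preprocessing the text dataset with tf.keras.layers.TextVectorization, you will use the TensorFlow Text APIs to standardize and tokenize the data, build a vocabulary, and map the tokens to integers with tf.lookup.StaticVocabularyTable so they can be fed to the model. (Learn more about TensorFlow Text.)
- TensorFlow Text provides various tokenizers. In this example, you will use tf_text.UnicodeScriptTokenizer to tokenize the dataset.
- You will use Dataset.map to apply the tokenization to the dataset.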
tokenizer = tf_text.UnicodeScriptTokenizer()
def tokenize(text, unused_label):
    lower_case = tf_text.case_fold_utf8(text)
    return tokenizer.tokenize(lower_case)
tokenized_ds = all_labeled_data.map(tokenize)
for text_batch in tokenized_ds.take(5):
    print("Tokens: ", text_batch.numpy())
Tokens: [b'and' b'like' b'two' b'lofty' b'pines' b'in' b'death' b'they' b'lay' b'.'] Tokens: [b'men' b'and' b'robed' b'matrons' b',' b'who' b'shall' b'seek' b'the' b'gods'] Tokens: [b'so' b'vaunted' b'he' b',' b'but' b'juno' b'with' b'disdain'] Tokens: [b'first' b'turning' b',' b'slew' b'the' b'mighty' b'bathycles' b','] Tokens: [b'flames' b',' b'while' b'the' b'women' b'and' b'children' b'are' b'carried' b'into' b'captivity' b';' b'when']
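UnicodeScriptTokenizer splits on whitespace and wherever the Unicode script changes; punctuation belongs to a different script class than Latin letters, which is why "lay." becomes b'lay' and b'.' above. The regex below is a rough, ASCII-oriented stand-in for case_fold_utf8 followed by that tokenizer, useful for reasoning about token boundaries without TensorFlow Text. It is an approximation under that assumption (approx_tokenize is a hypothetical helper): on digits, non-Latin text, or mixed scripts its output can differ from the real tokenizer.

```python
import re

# Runs of word characters form one token; runs of other non-space
# characters (punctuation) form another. ASCII-oriented approximation only,
# not the real Unicode-script-based tokenizer.
_TOKEN_RE = re.compile(r"\w+|[^\w\s]+")


def approx_tokenize(text):
    """Case-fold the text, then split words from punctuation."""
    return _TOKEN_RE.findall(text.casefold())


print(approx_tokenize("And like two lofty pines in death they lay."))
print(approx_tokenize("So vaunted he, but Juno with disdain"))
```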
Next, build a vocabulary by sorting the tokens by frequency and keeping the top VOCAB_SIZE tokens:
tokenized_ds = configure_dataset(tokenized_ds)
vocab_dict = collections.defaultdict(lambda: 0)
for toks in tokenized_ds.as_numpy_iterator():
    for tok in toks:
        vocab_dict[tok] += 1
# Sort tokens by count, most frequent first, and keep the top VOCAB_SIZE.
vocab = sorted(vocab_dict.items(), key=lambda x: x[1], reverse=True)
vocab = [token for token, count in vocab]
vocab = vocab[:VOCAB_SIZE]
vocab_size = len(vocab)
print("Vocab size: ", vocab_size)
print("First five vocab entries:", vocab[:5])
Vocab size: 10000 First five vocab entries: [b',', b'the', b'and', b"'", b'of']
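The counting loop above is a frequency count followed by a stable sort. In plain Python the same idea can be sketched with collections.Counter; build_vocab is a hypothetical helper name used only for illustration. Counter.most_common orders equal counts by first appearance, which matches the stable sorted(..., reverse=True) used in the tutorial code.

```python
import collections


def build_vocab(token_lists, max_size):
    """Return up to `max_size` tokens, most frequent first."""
    counts = collections.Counter()
    for tokens in token_lists:
        counts.update(tokens)
    # Ties keep first-seen order, like a stable descending sort.
    return [token for token, _ in counts.most_common(max_size)]


lines = [["the", "cat"], ["the", "dog"], ["a", "cat"]]
print(build_vocab(lines, 2))
print(build_vocab(lines, 10))
```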
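To convert the tokens into integers, use the vocab list to create a tf.lookup.StaticVocabularyTable. The vocabulary tokens are mapped to integers in the half-open range [2, vocab_size + 2), that is, 2 through vocab_size + 1. The values 0 and 1 are kept out of this range, following the TextVectorization convention of 0 for padding and 1 for out-of-vocabulary (OOV) tokens.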
keys = vocab
values = range(2, len(vocab) + 2) # Reserve `0` for padding, `1` for OOV tokens.
init = tf.lookup.KeyValueTensorInitializer(
    keys, values, key_dtype=tf.string, value_dtype=tf.int64)
num_oov_buckets = 1
vocab_table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets)
def preprocess_text(text, label):
    standardized = tf_text.case_fold_utf8(text)
    tokenized = tokenizer.tokenize(standardized)
    vectorized = vocab_table.lookup(tokenized)
    return vectorized, label
example_text, example_label = next(iter(all_labeled_data))
print("Sentence: ", example_text.numpy())
vectorized_text, example_label = preprocess_text(example_text, example_label)
print("Vectorized sentence: ", vectorized_text.numpy())
Sentence: b'And like two lofty pines in death they lay.' Vectorized sentence: [ 4 158 104 409 7806 13 134 27 249 7]
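One detail of tf.lookup.StaticVocabularyTable is worth checking here. Its documentation describes OOV keys as landing in bucket hash(key) % num_oov_buckets + vocab_size, where vocab_size is the number of entries in the initializer, not the smallest unused value. Read that way, with values 2 to vocab_size + 1 and num_oov_buckets = 1, an unknown token would receive id vocab_size rather than 1, an id already used by a vocabulary entry. The pure-Python emulation below sketches that documented rule for a single bucket so it can be reasoned about without TensorFlow; static_vocab_lookup is a hypothetical helper, and it is worth confirming the behavior with vocab_table.lookup on an unseen token in your own environment.

```python
def static_vocab_lookup(tokens, keys, values):
    """Emulate StaticVocabularyTable.lookup with num_oov_buckets=1.

    Known keys map to their initializer value. Unknown keys go to the single
    OOV bucket, whose id follows the documented formula
    hash(key) % num_oov_buckets + vocab_size, i.e. 0 + len(keys) here.
    """
    table = dict(zip(keys, values))
    oov_id = len(keys)  # bucket 0 of 1, offset by the initializer size
    return [table.get(tok, oov_id) for tok in tokens]


keys = ["the", "and", "of"]
values = range(2, len(keys) + 2)  # 2, 3, 4, as in the tutorial
ids = static_vocab_lookup(["the", "of", "zebra"], keys, values)
print(ids)  # "zebra" shares its id with "and"
```

If unknown tokens should really map to 1, one option to consider is a tf.lookup.StaticHashTable built from the same initializer with default_value=1.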
Now, use Dataset.map to apply preprocess_text to every example in the dataset:
all_encoded_data = all_labeled_data.map(preprocess_text)
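With the Keras TextVectorization layer, batched data also comes out padded to a common length; here you need to batch and pad it yourself. Padding is required because the examples inside a batch need to have the same size and shape, but the examples in these datasets are not all the same size: each line of text has a different number of words. Split the data with Dataset.skip and Dataset.take, then batch and pad it with Dataset.padded_batch: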
train_data = all_encoded_data.skip(VALIDATION_SIZE).shuffle(BUFFER_SIZE)
validation_data = all_encoded_data.take(VALIDATION_SIZE)
train_data = train_data.padded_batch(BATCH_SIZE)
validation_data = validation_data.padded_batch(BATCH_SIZE)
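When called without padded_shapes, as here, Dataset.padded_batch pads every component of a batch to the length of the longest element in that batch, using 0 for numeric data, so different batches can have different widths. The pure-Python function below sketches the same per-batch right-padding on lists of token ids; padded_batches is a hypothetical helper name, not a TensorFlow API.

```python
def padded_batches(sequences, batch_size, pad_value=0):
    """Group sequences into batches and right-pad each batch to its longest row.

    The final batch may be smaller, as with padded_batch's default
    drop_remainder=False.
    """
    for start in range(0, len(sequences), batch_size):
        batch = sequences[start:start + batch_size]
        width = max(len(seq) for seq in batch)
        yield [list(seq) + [pad_value] * (width - len(seq)) for seq in batch]


encoded = [[4, 158, 104], [7], [13, 134], [27]]
batches = list(padded_batches(encoded, batch_size=2))
print(batches)
```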
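validation_data and train_data are no longer collections of (example, label) pairs, but collections of batches. Each batch is a pair of (many examples, many labels) represented as arrays.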
sample_text, sample_labels = next(iter(validation_data))
print("Text batch shape: ", sample_text.shape)
print("Label batch shape: ", sample_labels.shape)
print("First text example: ", sample_text[0])
print("First label example: ", sample_labels[0])
Text batch shape: (64, 17) Label batch shape: (64,) First text example: tf.Tensor( [ 4 158 104 409 7806 13 134 27 249 7 0 0 0 0 0 0 0], shape=(17,), dtype=int64) First label example: tf.Tensor(1, shape=(), dtype=int64)
Since you use 0 for padding and 1 for out-of-vocabulary (OOV) tokens, the vocabulary size grows by two:
vocab_size += 2
train_data = configure_dataset(train_data)
validation_data = configure_dataset(validation_data)
model = create_model(vocab_size=vocab_size, num_labels=3)
history = model.fit(train_data, validation_data=validation_data, epochs=3)
Epoch 1/3 697/697 [==============================] - 29s 9ms/step - loss: 0.5210 - accuracy: 0.7706 - val_loss: 0.3762 - val_accuracy: 0.8382 Epoch 2/3 697/697 [==============================] - 3s 4ms/step - loss: 0.2883 - accuracy: 0.8832 - val_loss: 0.3643 - val_accuracy: 0.8482 Epoch 3/3 697/697 [==============================] - 3s 4ms/step - loss: 0.1962 - accuracy: 0.9250 - val_loss: 0.3924 - val_accuracy: 0.8470
loss, accuracy = model.evaluate(validation_data)
print("Loss: ", loss)
print("Accuracy: {:2.2%}".format(accuracy))
79/79 [==============================] - 1s 2ms/step - loss: 0.3924 - accuracy: 0.8470 Loss: 0.3924245238304138 Accuracy: 84.70%
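To make the model capable of taking raw strings as input, create a Keras TextVectorization layer that performs the same steps as your custom preprocessing function. Since you have already built a vocabulary, use TextVectorization.set_vocabulary instead of TextVectorization.adapt, which would build a new vocabulary.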
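For tokens that are in the vocabulary, set_vocabulary yields the same ids as the lookup table built earlier, which is what lets the already-trained model be reused. The sketch below emulates that index layout in plain Python, assuming the layer is created with the default mask_token='' (id 0) and oov_token='[UNK]' (id 1), which the layer places ahead of the supplied vocabulary; text_vectorization_ids is a hypothetical helper, not part of Keras.

```python
def text_vectorization_ids(vocab, tokens):
    """Sketch of int-mode TextVectorization ids after set_vocabulary(vocab).

    Assumes the default mask token '' (id 0) and OOV token '[UNK]' (id 1).
    """
    full_vocab = ["", "[UNK]"] + list(vocab)
    index = {tok: i for i, tok in enumerate(full_vocab)}
    return [index.get(tok, 1) for tok in tokens]


vocab = [",", "the", "and", "'", "of"]  # first five entries printed earlier
# Ids from the earlier lookup table: values range(2, len(vocab) + 2).
table_ids = {tok: i for i, tok in enumerate(vocab, start=2)}
print(text_vectorization_ids(vocab, ["the", "of", "zebra"]))
```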
preprocess_layer = TextVectorization(
export_model = tf.keras.Sequential(
[preprocess_layer, model,
# Create a test dataset of raw strings.
test_ds = all_labeled_data.take(VALIDATION_SIZE).batch(BATCH_SIZE)
test_ds = configure_dataset(test_ds)
loss, accuracy = export_model.evaluate(test_ds)
print("Loss: ", loss)
print("Accuracy: {:2.2%}".format(accuracy))
79/79 [==============================] - 6s 8ms/step - loss: 0.5273 - accuracy: 0.7944 Loss: 0.5273075699806213 Accuracy: 79.44%
inputs = [
    "Join'd to th' Ionians with their flowing robes,",  # Label: 1
    "the allies, and his armour flashed about him so that he seemed to all",  # Label: 2
    "And with loud clangor of his arms he fell.",  # Label: 0
]
predicted_scores = export_model.predict(inputs)
predicted_labels = tf.math.argmax(predicted_scores, axis=1)
for input, label in zip(inputs, predicted_labels):
    print("Question: ", input)
    print("Predicted label: ", label.numpy())
Question: Join'd to th' Ionians with their flowing robes, Predicted label: 1 Question: the allies, and his armour flashed about him so that he seemed to all Predicted label: 2 Question: And with loud clangor of his arms he fell. Predicted label: 0
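predicted_scores has one row per input string and one column per class (three here), so taking the argmax along axis=1 picks the highest-scoring class for each input. The plain-Python sketch below applies the same reduction to made-up scores; predict_labels is a hypothetical helper. It also shows why the output activation does not matter for this step: an increasing function such as a sigmoid, applied to each score, keeps the largest score largest.

```python
import math


def predict_labels(batch_scores):
    """Return the index of the highest score in each row (argmax over axis 1).

    Ties go to the first index in this sketch.
    """
    return [max(range(len(row)), key=row.__getitem__) for row in batch_scores]


scores = [[0.1, 2.3, -0.4], [1.7, 0.2, 0.9], [-1.0, 0.0, 3.1]]
print(predict_labels(scores))
# Applying a sigmoid element-wise leaves the argmax unchanged.
probs = [[1 / (1 + math.exp(-s)) for s in row] for row in scores]
print(predict_labels(probs))
```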
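Download more datasets using TensorFlow Datasets (TFDS)
You can download many more datasets from TensorFlow Datasets.
In this example, you will use the IMDB Large Movie Review Dataset to train a sentiment classification model: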
# Training set.
train_ds = tfds.load(
# Validation set.
val_ds = tfds.load(
for review_batch, label_batch in val_ds.take(1):
    for i in range(5):
        print("Review: ", review_batch[i].numpy())
        print("Label: ", label_batch[i].numpy())
Review: b"Instead, go to the zoo, buy some peanuts and feed 'em to the monkeys. Monkeys are funny. People with amnesia who don't say much, just sit there with vacant eyes are not all that funny.<br /><br />Black comedy? There isn't a black person in it, and there isn't one funny thing in it either.<br /><br />Walmart buys these things up somehow and puts them on their dollar rack. It's labeled Unrated. I think they took out the topless scene. They may have taken out other stuff too, who knows? All we know is that whatever they took out, isn't there any more.<br /><br />The acting seemed OK to me. There's a lot of unfathomables tho. It's supposed to be a city? It's supposed to be a big lake? If it's so hot in the church people are fanning themselves, why are they all wearing coats?" Label: 0 Review: b'Well, was Morgan Freeman any more unusual as God than George Burns? This film sure was better than that bore, "Oh, God". I was totally engrossed and LMAO all the way through. Carrey was perfect as the out of sorts anchorman wannabe, and Aniston carried off her part as the frustrated girlfriend in her usual well played performance. I, for one, don\'t consider her to be either ugly or untalented. I think my favorite scene was when Carrey opened up the file cabinet thinking it could never hold his life history. See if you can spot the file in the cabinet that holds the events of his bathroom humor: I was rolling over this one. Well written and even better played out, this comedy will go down as one of this funnyman\'s best.' Label: 1 Review: b'I remember stumbling upon this special while channel-surfing in 1965. I had never heard of Barbra before. When the show was over, I thought "This is probably the best thing on TV I will ever see in my life." 42 years later, that has held true. There is still nothing so amazing, so honestly astonishing as the talent that was displayed here. 
You can talk about all the super-stars you want to, this is the most superlative of them all!<br /><br />You name it, she can do it. Comedy, pathos, sultry seduction, ballads, Barbra is truly a story-teller. Her ability to pull off anything she attempts is legendary. But this special was made in the beginning, and helped to create the legend that she quickly became. In spite of rising so far in such a short time, she has fulfilled the promise, revealing more of her talents as she went along. But they are all here from the very beginning. You will not be disappointed in viewing this.' Label: 1 Review: b"Firstly, I would like to point out that people who have criticised this film have made some glaring errors. Anything that has a rating below 6/10 is clearly utter nonsense.<br /><br />Creep is an absolutely fantastic film with amazing film effects. The actors are highly believable, the narrative thought provoking and the horror and graphical content extremely disturbing. <br /><br />There is much mystique in this film. Many questions arise as the audience are revealed to the strange and freakish creature that makes habitat in the dark rat ridden tunnels. How was 'Craig' created and what happened to him?<br /><br />A fantastic film with a large chill factor. A film with so many unanswered questions and a film that needs to be appreciated along with others like 28 Days Later, The Bunker, Dog Soldiers and Deathwatch.<br /><br />Look forward to more of these fantastic films!!" Label: 1 Review: b"I'm sorry but I didn't like this doc very much. I can think of a million ways it could have been better. The people who made it obviously don't have much imagination. The interviews aren't very interesting and no real insight is offered. The footage isn't assembled in a very informative way, either. It's too bad because this is a movie that really deserves spellbinding special features. One thing I'll say is that Isabella Rosselini gets more beautiful the older she gets. 
All considered, this only gets a '4.'" Label: 0
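Note: Because the IMDB labels are binary (0 for a negative review, 1 for a positive one), the model uses tf.keras.losses.BinaryCrossentropy instead of tf.keras.losses.SparseCategoricalCrossentropy.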
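With two classes, one logit z passed through a sigmoid carries the same information as two softmax logits [0, z], because softmax([0, z])[1] = e^z / (1 + e^z) = sigmoid(z). So binary cross-entropy with num_labels=1 is equivalent to sparse categorical cross-entropy over two outputs. The pure-Python check below assumes the model's final Dense layer outputs a raw logit; the loss functions here are written out by hand for illustration and are not the Keras implementations.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def binary_cross_entropy(logit, label):
    """BCE for one example whose model output is a single raw logit."""
    p = sigmoid(logit)
    return -(label * math.log(p) + (1 - label) * math.log(1 - p))


def sparse_categorical_cross_entropy(logits, label):
    """Softmax cross-entropy for one example with an integer class label."""
    log_norm = math.log(sum(math.exp(x) for x in logits))
    return log_norm - logits[label]


for z in (-3.0, 0.0, 1.2):
    for y in (0, 1):
        print(z, y, binary_cross_entropy(z, y), sparse_categorical_cross_entropy([0.0, z], y))
```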
vectorize_layer = TextVectorization(
# Make a text-only dataset (without labels), then call `TextVectorization.adapt`.
train_text = train_ds.map(lambda text, labels: text)
vectorize_layer.adapt(train_text)
def vectorize_text(text, label):
    text = tf.expand_dims(text, -1)
    return vectorize_layer(text), label
train_ds = train_ds.map(vectorize_text)
val_ds = val_ds.map(vectorize_text)
# Configure datasets for performance as before.
train_ds = configure_dataset(train_ds)
val_ds = configure_dataset(val_ds)
model = create_model(vocab_size=VOCAB_SIZE + 1, num_labels=1)
model.summary()
Model: "sequential_5" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= embedding_2 (Embedding) (None, None, 64) 640064 conv1d_2 (Conv1D) (None, None, 64) 20544 global_max_pooling1d_2 (Glo (None, 64) 0 balMaxPooling1D) dense_3 (Dense) (None, 1) 65 ================================================================= Total params: 660,673 Trainable params: 660,673 Non-trainable params: 0 _________________________________________________________________
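The parameter counts in this summary can be reproduced from the layer shapes, a quick way to sanity-check a model definition. The arithmetic below assumes VOCAB_SIZE = 10000 (consistent with the 640,064 embedding parameters, i.e. 10,001 rows of 64 values) and a Conv1D kernel size of 5, which is inferred from the 20,544 convolution parameters rather than shown in this section.

```python
# Assumed layer sizes: VOCAB_SIZE + 1 = 10001 embedding rows, 64-dim embeddings,
# a Conv1D with 64 filters and kernel size 5 (inferred from its parameter
# count), and a Dense layer with a single output unit.
vocab_rows = 10000 + 1
embed_dim = 64
filters, kernel_size = 64, 5

embedding_params = vocab_rows * embed_dim                  # one vector per id
conv_params = kernel_size * embed_dim * filters + filters  # weights + biases
pooling_params = 0                                         # max pooling has no weights
dense_params = filters * 1 + 1                             # 64 weights + 1 bias

total = embedding_params + conv_params + pooling_params + dense_params
print(embedding_params, conv_params, dense_params, total)
```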
history = model.fit(train_ds, validation_data=val_ds, epochs=3)
Epoch 1/3 313/313 [==============================] - 3s 7ms/step - loss: 0.5339 - accuracy: 0.6650 - val_loss: 0.3707 - val_accuracy: 0.8270 Epoch 2/3 313/313 [==============================] - 1s 4ms/step - loss: 0.2980 - accuracy: 0.8690 - val_loss: 0.3181 - val_accuracy: 0.8598 Epoch 3/3 313/313 [==============================] - 1s 4ms/step - loss: 0.1819 - accuracy: 0.9305 - val_loss: 0.3239 - val_accuracy: 0.8620
loss, accuracy = model.evaluate(val_ds)
print("Loss: ", loss)
print("Accuracy: {:2.2%}".format(accuracy))
79/79 [==============================] - 0s 2ms/step - loss: 0.3239 - accuracy: 0.8620 Loss: 0.32394784688949585 Accuracy: 86.20%
export_model = tf.keras.Sequential(
[vectorize_layer, model,
# 0 --> negative review
# 1 --> positive review
inputs = [
    "This is a fantastic movie.",
    "This is a bad movie.",
    "This movie was so bad that it was good.",
    "I will never say yes to watching this movie.",
]
predicted_scores = export_model.predict(inputs)
predicted_labels = [int(round(x[0])) for x in predicted_scores]
for input, label in zip(inputs, predicted_labels):
    print("Question: ", input)
    print("Predicted label: ", label)
Question: This is a fantastic movie. Predicted label: 1 Question: This is a bad movie. Predicted label: 0 Question: This movie was so bad that it was good. Predicted label: 0 Question: I will never say yes to watching this movie. Predicted label: 0
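int(round(x[0])) turns each score into a 0/1 label. This acts as a 0.5 threshold only if the exported model ends in a sigmoid, so that scores are probabilities in [0, 1]; that is assumed here, and for raw logits the cut-off would be 0 instead. Python's round() rounds half to even, so a score of exactly 0.5 becomes 0. The sketch below (to_binary_labels is a hypothetical helper) compares this rule with an explicit "greater than 0.5" test on made-up scores.

```python
def to_binary_labels(scores):
    """Map sigmoid probabilities to 0/1 labels exactly as int(round(p)) does.

    round() rounds half to even, so exactly 0.5 maps to 0 and anything above
    0.5 maps to 1 (for scores in [0, 1]).
    """
    return [int(round(p)) for p in scores]


scores = [0.97, 0.12, 0.5, 0.51]
print(to_binary_labels(scores))
# For probabilities, this matches an explicit "greater than 0.5" rule.
print([int(p > 0.5) for p in scores])
```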
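This tutorial demonstrated several ways to load and preprocess text. As a next step, you can explore other TensorFlow Text text-preprocessing tutorials, such as: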
