def conceptual_captions(*, data_dir="conceptual_captions", num_train, num_val):
  def iter_index(index_path):
    with open(index_path) as f:
      for line in f:
        caption, url = line.strip().split('\t')
        yield caption, url

  def download_image_urls(data_dir, urls):
    ex = concurrent.futures.ThreadPoolExecutor(max_workers=100)
    def save_image(url):
      hash = hashlib.sha1(url.encode())
      # Name the files after the hash of the URL.
      file_path = data_dir/f'{hash.hexdigest()}.jpeg'
      if file_path.exists():
        # Only download each file once.
        return file_path

      try:
        result = requests.get(url, timeout=5)
      except Exception:
        file_path = None
      else:
        file_path.write_bytes(result.content)
      return file_path

    result = []
    out_paths = ex.map(save_image, urls)
    for file_path in tqdm.tqdm(out_paths, total=len(urls)):
      result.append(file_path)

    return result

  def ds_from_index_file(index_path, data_dir, count):
    data_dir.mkdir(exist_ok=True)
    index = list(itertools.islice(iter_index(index_path), count))
    captions = [caption for caption, url in index]
    urls = [url for caption, url in index]

    paths = download_image_urls(data_dir, urls)

    new_captions = []
    new_paths = []
    for cap, path in zip(captions, paths):
      if path is None:
        # Download failed, so skip this pair.
        continue
      new_captions.append(cap)
      new_paths.append(path)

    new_paths = [str(p) for p in new_paths]

    ds = tf.data.Dataset.from_tensor_slices((new_paths, new_captions))
    ds = ds.map(lambda path, cap: (path, cap[tf.newaxis]))  # 1 caption per image
    return ds

  data_dir = pathlib.Path(data_dir)
  train_index_path = tf.keras.utils.get_file(
      origin='https://storage.googleapis.com/gcc-data/Train/GCC-training.tsv',
      cache_subdir=data_dir,
      cache_dir='.')

  val_index_path = tf.keras.utils.get_file(
      origin='https://storage.googleapis.com/gcc-data/Validation/GCC-1.1.0-Validation.tsv',
      cache_subdir=data_dir,
      cache_dir='.')

  train_raw = ds_from_index_file(train_index_path, data_dir=data_dir/'train', count=num_train)
  test_raw = ds_from_index_file(val_index_path, data_dir=data_dir/'val', count=num_val)

  return train_raw, test_raw
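With the helpers above in place, downloading a subset of the dataset is a single call. A minimal sketch; the sample sizes below are illustrative placeholders, and the full Conceptual Captions index is far larger:

# Download a small, illustrative subset of the dataset.
train_raw, test_raw = conceptual_captions(num_train=10000, num_val=5000)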
# Use the top 5000 words for a vocabulary.
vocabulary_size = 5000
tokenizer = tf.keras.layers.TextVectorization(
    max_tokens=vocabulary_size,
    standardize=standardize,
    ragged=True)
# Learn the vocabulary from the caption data.
tokenizer.adapt(train_raw.map(lambda fp, txt: txt).unbatch().batch(1024))

# Tokenize a couple of example captions to see the result.
t = tokenizer([['a cat in a hat'], ['a robot dog']])
t
# Create mappings for words to indices and indices to words.
word_to_index = tf.keras.layers.StringLookup(
    mask_token="",
    vocabulary=tokenizer.get_vocabulary())
index_to_word = tf.keras.layers.StringLookup(
    mask_token="",
    vocabulary=tokenizer.get_vocabulary(),
    invert=True)
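As a quick sanity check, the two lookups invert each other. A hypothetical round trip; any word outside the learned vocabulary comes back as '[UNK]':

# Round-trip a word through the two lookup layers (illustrative).
idx = word_to_index('cat')    # scalar token ID
word = index_to_word(idx)     # the string 'cat', or '[UNK]' if out of vocabulary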
def prepare_dataset(ds, tokenizer, batch_size=32, shuffle_buffer=1000):
  # Load the images and make batches.
  ds = (ds
        .shuffle(10000)
        .map(lambda path, caption: (load_image(path), caption))
        .apply(tf.data.experimental.ignore_errors())
        .batch(batch_size))

  def to_tensor(inputs, labels):
    # Convert the ragged token tensors to dense, zero-padded tensors.
    (images, in_tok), out_tok = inputs, labels
    return (images, in_tok.to_tensor()), out_tok.to_tensor()

  return (ds
          .map(match_shapes, tf.data.AUTOTUNE)
          .unbatch()
          .shuffle(shuffle_buffer)
          .batch(batch_size)
          .map(prepare_txt, tf.data.AUTOTUNE)
          .map(to_tensor, tf.data.AUTOTUNE))
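For reference, a sketch of wiring the raw datasets from earlier through this function, keeping the default batch_size and shuffle_buffer:

# Build the training and test input pipelines (sketch).
train_ds = prepare_dataset(train_raw, tokenizer)
test_ds = prepare_dataset(test_raw, tokenizer)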
def save_dataset(ds, save_path, image_model, tokenizer, shards=10, batch_size=32):
  # Load the images and make batches.
  ds = (ds
        .map(lambda path, caption: (load_image(path), caption))
        .apply(tf.data.experimental.ignore_errors())
        .batch(batch_size))

  # Run the feature extractor on each batch.
  # Don't do this in a .map, because tf.data runs on the CPU.
  def gen():
    for (images, captions) in tqdm.tqdm(ds):
      feature_maps = image_model(images)

      feature_maps, captions = match_shapes(feature_maps, captions)
      yield feature_maps, captions

  # Wrap the generator in a new tf.data.Dataset.
  new_ds = tf.data.Dataset.from_generator(
      gen,
      output_signature=(
          tf.TensorSpec(shape=image_model.output_shape),
          tf.TensorSpec(shape=(None,), dtype=tf.string)))

  # Apply the tokenization.
  new_ds = (new_ds
            .map(prepare_txt, tf.data.AUTOTUNE)
            .unbatch()
            .shuffle(1000))

  # Save the dataset into shard files.
  def shard_func(i, item):
    return i % shards
  new_ds.enumerate().save(save_path, shard_func=shard_func)

def load_dataset(save_path, batch_size=32, shuffle=1000, cycle_length=2):
  def custom_reader_func(datasets):
    datasets = datasets.shuffle(1000)
    return datasets.interleave(lambda x: x, cycle_length=cycle_length)

  ds = tf.data.Dataset.load(save_path, reader_func=custom_reader_func)

  def drop_index(i, x):
    return x

  ds = (ds
        .map(drop_index, tf.data.AUTOTUNE)
        .shuffle(shuffle)
        .padded_batch(batch_size)
        .prefetch(tf.data.AUTOTUNE))
  return ds
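As an alternative to prepare_dataset above, here is a sketch of caching the extracted features once and training from the saved shards. It assumes `image_model` is the pretrained feature extractor built earlier in the tutorial, and the cache paths are placeholders:

# Precompute feature maps once, then reload the cached shards (sketch).
save_dataset(train_raw, 'train_cache', image_model, tokenizer)
save_dataset(test_raw, 'test_cache', image_model, tokenizer)

train_ds = load_dataset('train_cache')
test_ds = load_dataset('test_cache')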
class TokenOutput(tf.keras.layers.Layer):
  def __init__(self, tokenizer, banned_tokens=('', '[UNK]', '[START]'), **kwargs):
    super().__init__()

    self.dense = tf.keras.layers.Dense(
        units=tokenizer.vocabulary_size(), **kwargs)
    self.tokenizer = tokenizer
    self.banned_tokens = banned_tokens

    self.bias = None

  def adapt(self, ds):
    counts = collections.Counter()
    vocab_dict = {name: id
                  for id, name in enumerate(self.tokenizer.get_vocabulary())}

    for tokens in tqdm.tqdm(ds):
      counts.update(tokens.numpy().flatten())

    counts_arr = np.zeros(shape=(self.tokenizer.vocabulary_size(),))
    counts_arr[np.array(list(counts.keys()), dtype=np.int32)] = list(counts.values())

    counts_arr = counts_arr[:]
    for token in self.banned_tokens:
      counts_arr[vocab_dict[token]] = 0

    total = counts_arr.sum()
    p = counts_arr/total
    p[counts_arr==0] = 1.0
    log_p = np.log(p)  # log(1) == 0

    entropy = -(log_p*p).sum()

    print()
    print(f"Uniform entropy: {np.log(self.tokenizer.vocabulary_size()):0.2f}")
    print(f"Marginal entropy: {entropy:0.2f}")

    self.bias = log_p
    self.bias[counts_arr==0] = -1e9

  def call(self, x):
    x = self.dense(x)
    # TODO(b/250038731): Fix this.
    # An Add layer doesn't work because of the different shapes.
    # This clears the mask, that's okay because it prevents keras from rescaling
    # the losses.
    return x + self.bias
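The layer only becomes useful once `adapt` has seen the token distribution. A sketch, assuming `train_ds` yields ((images, input_tokens), label_tokens) batches as built above, so the labels are the output tokens to count:

output_layer = TokenOutput(tokenizer, banned_tokens=('', '[UNK]', '[START]'))
# Count output-token frequencies to initialize the output bias (sketch).
output_layer.adapt(train_ds.map(lambda inputs, labels: labels))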
A seq_embedding layer, to convert batches of token IDs to vectors of shape (batch, sequence, channels).
A stack of DecoderLayers that process the text and image data.
An output_layer which returns a pointwise prediction of what the next word should be.
class Captioner(tf.keras.Model):
  @classmethod
  def add_method(cls, fun):
    setattr(cls, fun.__name__, fun)
    return fun

  def __init__(self, tokenizer, feature_extractor, output_layer, num_layers=1,
               units=256, max_length=50, num_heads=1, dropout_rate=0.1):
    super().__init__()
    self.feature_extractor = feature_extractor
    self.tokenizer = tokenizer
    self.word_to_index = tf.keras.layers.StringLookup(
        mask_token="",
        vocabulary=tokenizer.get_vocabulary())
    self.index_to_word = tf.keras.layers.StringLookup(
        mask_token="",
        vocabulary=tokenizer.get_vocabulary(),
        invert=True)

    self.seq_embedding = SeqEmbedding(
        vocab_size=tokenizer.vocabulary_size(),
        depth=units,
        max_length=max_length)

    self.decoder_layers = [
        DecoderLayer(units, num_heads=num_heads, dropout_rate=dropout_rate)
        for n in range(num_layers)]

    self.output_layer = output_layer
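Putting the pieces together, a hypothetical constructor call, assuming the `image_model` feature extractor and the adapted `output_layer` from above; the hyperparameter values are illustrative:

# Assemble the captioning model (illustrative hyperparameters).
model = Captioner(tokenizer, feature_extractor=image_model,
                  output_layer=output_layer, units=256, dropout_rate=0.5,
                  num_layers=2, num_heads=2)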