Text-to-Video retrieval with S3D MIL-NCE
!pip install -q opencv-python
import os
import tensorflow.compat.v2 as tf
import tensorflow_hub as hub
import numpy as np
import cv2
from IPython import display
import math
Import TF-Hub model
This tutorial demonstrates how to use the S3D MIL-NCE model from TensorFlow Hub to perform text-to-video retrieval, i.e. to find the most similar videos for a given text query.
The model has two signatures: one for generating video embeddings and one for generating text embeddings. We will use these embeddings to find the nearest neighbors in the embedding space.
# Load the model once from TF-Hub.
hub_handle = 'https://tfhub.dev/deepmind/mil-nce/s3d/1'
hub_model = hub.load(hub_handle)
def generate_embeddings(model, input_frames, input_words):
  """Generate embeddings from the model from video frames and input words."""
  # input_frames must be normalized in [0, 1] and of the shape Batch x T x H x W x 3
  vision_output = model.signatures['video'](tf.constant(tf.cast(input_frames, dtype=tf.float32)))
  text_output = model.signatures['text'](tf.constant(input_words))
  return vision_output['video_embedding'], text_output['text_embedding']
2024-03-09 14:50:17.063759: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
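As a quick sanity check (not part of the original notebook), the sketch below feeds a single all-zero clip and one dummy query through generate_embeddings and prints the embedding shapes. The frame dimensions follow the Batch x T x H x W x 3 convention noted in the function's comment; the dummy query string is purely illustrative.
# Optional, illustrative sanity check: verify the shapes returned by
# generate_embeddings using an all-zero clip and a single dummy query.
dummy_frames = np.zeros((1, 32, 224, 224, 3), dtype=np.float32)  # Batch x T x H x W x 3, values in [0, 1]
dummy_words = np.array(['a person doing something'])
dummy_video_embd, dummy_text_embd = generate_embeddings(hub_model, dummy_frames, dummy_words)
print(dummy_video_embd.shape, dummy_text_embd.shape)  # both are (batch, embedding_dim) tensors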
# @title Define video loading and visualization functions { display-mode: "form" }
# Utilities to open video files using CV2
def crop_center_square(frame):
  y, x = frame.shape[0:2]
  min_dim = min(y, x)
  start_x = (x // 2) - (min_dim // 2)
  start_y = (y // 2) - (min_dim // 2)
  return frame[start_y:start_y+min_dim, start_x:start_x+min_dim]
def load_video(video_url, max_frames=32, resize=(224, 224)):
  path = tf.keras.utils.get_file(os.path.basename(video_url)[-128:], video_url)
  cap = cv2.VideoCapture(path)
  frames = []
  try:
    while True:
      ret, frame = cap.read()
      if not ret:
        break
      frame = crop_center_square(frame)
      frame = cv2.resize(frame, resize)
      frame = frame[:, :, [2, 1, 0]]  # Convert BGR (OpenCV default) to RGB.
      frames.append(frame)

      if len(frames) == max_frames:
        break
  finally:
    cap.release()
  frames = np.array(frames)
  if len(frames) < max_frames:
    # Repeat the clip so that it contains at least max_frames frames.
    n_repeat = int(math.ceil(max_frames / float(len(frames))))
    frames = frames.repeat(n_repeat, axis=0)
  frames = frames[:max_frames]
  return frames / 255.0
def display_video(urls):
  html = '<table>'
  html += '<tr><th>Video 1</th><th>Video 2</th><th>Video 3</th></tr><tr>'
  for url in urls:
    html += '<td>'
    html += '<img src="{}" height="224">'.format(url)
    html += '</td>'
  html += '</tr></table>'
  return display.HTML(html)
def display_query_and_results_video(query, urls, scores):
  """Display a text query and the top result videos and scores."""
  sorted_ix = np.argsort(-scores)
  html = ''
  html += '<h2>Input query: <i>{}</i> </h2><div>'.format(query)
  html += 'Results: <div>'
  html += '<table>'
  html += '<tr><th>Rank #1, Score:{:.2f}</th>'.format(scores[sorted_ix[0]])
  html += '<th>Rank #2, Score:{:.2f}</th>'.format(scores[sorted_ix[1]])
  html += '<th>Rank #3, Score:{:.2f}</th></tr><tr>'.format(scores[sorted_ix[2]])
  for idx in sorted_ix:
    url = urls[idx]
    html += '<td>'
    html += '<img src="{}" height="224">'.format(url)
    html += '</td>'
  html += '</tr></table>'
  return html
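Before loading the actual example videos, a brief illustrative check (not part of the original notebook) can confirm what load_video returns: an array of shape (max_frames, 224, 224, 3) with values scaled to [0, 1]. The URL below is the same waterfall GIF that is loaded in the next cell.
# Illustrative check only: load one of the example GIFs and inspect the result.
sample_frames = load_video('https://upload.wikimedia.org/wikipedia/commons/b/b0/YosriAirTerjun.gif')
print(sample_frames.shape)                       # expected: (32, 224, 224, 3)
print(sample_frames.min(), sample_frames.max())  # frames are normalized to [0, 1]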
# @title Load example videos and define text queries { display-mode: "form" }
video_1_url = 'https://upload.wikimedia.org/wikipedia/commons/b/b0/YosriAirTerjun.gif' # @param {type:"string"}
video_2_url = 'https://upload.wikimedia.org/wikipedia/commons/e/e6/Guitar_solo_gif.gif' # @param {type:"string"}
video_3_url = 'https://upload.wikimedia.org/wikipedia/commons/3/30/2009-08-16-autodrift-by-RalfR-gif-by-wau.gif' # @param {type:"string"}
video_1 = load_video(video_1_url)
video_2 = load_video(video_2_url)
video_3 = load_video(video_3_url)
all_videos = [video_1, video_2, video_3]
query_1_video = 'waterfall' # @param {type:"string"}
query_2_video = 'playing guitar' # @param {type:"string"}
query_3_video = 'car drifting' # @param {type:"string"}
all_queries_video = [query_1_video, query_2_video, query_3_video]
all_videos_urls = [video_1_url, video_2_url, video_3_url]
display_video(all_videos_urls)
Downloading data from https://upload.wikimedia.org/wikipedia/commons/b/b0/YosriAirTerjun.gif
1207385/1207385 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
Downloading data from https://upload.wikimedia.org/wikipedia/commons/e/e6/Guitar_solo_gif.gif
1021622/1021622 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
Downloading data from https://upload.wikimedia.org/wikipedia/commons/3/30/2009-08-16-autodrift-by-RalfR-gif-by-wau.gif
1506603/1506603 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
Demonstrate text-to-video retrieval
# Prepare video inputs.
videos_np = np.stack(all_videos, axis=0)
# Prepare text input.
words_np = np.array(all_queries_video)
# Generate the video and text embeddings.
video_embd, text_embd = generate_embeddings(hub_model, videos_np, words_np)
# Scores between video and text are computed by dot products.
all_scores = np.dot(text_embd, tf.transpose(video_embd))
# Display results.
html = ''
for i, words in enumerate(words_np):
  html += display_query_and_results_video(words, all_videos_urls, all_scores[i, :])
  html += '<br>'
display.HTML(html)
Input query: waterfall
Results: Rank #1, Score: 4.71 | Rank #2, Score: -1.63 | Rank #3, Score: -4.17

Input query: playing guitar
Results: Rank #1, Score: 6.50 | Rank #2, Score: -1.79 | Rank #3, Score: -2.67

Input query: car drifting
Results: Rank #1, Score: 8.78 | Rank #2, Score: -1.07 | Rank #3, Score: -2.17
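As a small follow-up sketch (not in the original notebook), the same score matrix can be used to pick the best-matching video for each query programmatically instead of rendering HTML:
# Illustrative only: argmax over each row of the score matrix gives the index
# of the highest-scoring video for the corresponding text query.
best_match = np.argmax(all_scores, axis=1)
for query, idx in zip(all_queries_video, best_match):
  print('{!r} -> {}'.format(query, all_videos_urls[idx]))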