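Frame interpolation is the task of synthesizing many in-between images from a given set of images. The technique is often used for frame rate upsampling or for creating slow-motion video effects.
In this Colab, you will use the FILM model to perform frame interpolation. The Colab also provides code snippets for creating videos from the interpolated in-between images.
For more information on the FILM research, see:
- Google AI blog: Large Motion Frame Interpolation
- FILM project page: Frame Interpolation for Large Motion
Setup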
pip install mediapy
sudo apt-get install -y ffmpeg
import tensorflow as tf
import tensorflow_hub as hub
import requests
import numpy as np
from typing import Generator, Iterable, List, Optional
import mediapy as media
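Load the model from TFHub
To load a model from TensorFlow Hub, you need the TensorFlow Hub library (imported above as hub) and the model handle, which is the URL of the model's documentation page.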
model = hub.load("https://tfhub.dev/google/film/1")
Utility function to load images from a URL or from a local path
This function loads an image and prepares it for use by the model later.
_UINT8_MAX_F = float(np.iinfo(np.uint8).max)
def load_image(img_url: str):
  """Returns an image with shape [height, width, num_channels], with pixels in [0..1] range, and type np.float32."""

  if (img_url.startswith("https")):
    user_agent = {'User-agent': 'Colab Sample (https://tensorflow.org)'}
    response = requests.get(img_url, headers=user_agent)
    image_data = response.content
  else:
    image_data = tf.io.read_file(img_url)

  image = tf.io.decode_image(image_data, channels=3)
  image_numpy = tf.cast(image, dtype=tf.float32).numpy()
  return image_numpy / _UINT8_MAX_F
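The scaling step at the end of load_image can be tried in isolation. The following is a minimal sketch that uses a tiny synthetic uint8 array (the `pixels` array is made up for illustration and stands in for a decoded image file); it only assumes NumPy.

```python
import numpy as np

# Hypothetical 1x2 RGB image with uint8 pixels, standing in for a decoded file.
pixels = np.array([[[0, 128, 255], [255, 0, 64]]], dtype=np.uint8)

# Same scaling as load_image: cast to float32, then divide by the uint8 maximum.
uint8_max = float(np.iinfo(np.uint8).max)  # 255.0
normalized = pixels.astype(np.float32) / uint8_max

# The result keeps the (height, width, channels) shape and lies in [0, 1].
print(normalized.shape, normalized.dtype)
print(normalized.min(), normalized.max())
```

Dividing by the uint8 maximum rather than a hard-coded 255 keeps the intent explicit: the largest representable pixel value maps to exactly 1.0.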
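The input to the FILM model is a dictionary with the keys time, x0, and x1:
- time: the position of the interpolated frame. The midpoint is 0.5.
- x0: the initial frame.
- x1: the final frame.
Both frames need to be normalized (done in the load_image function above) so that each pixel is in the range [0..1].
time is a value in the range [0..1] that indicates where the generated image should lie between the two inputs. 0.5 is midway between the input images.
All three values also need a batch dimension.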
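The batch-dimension requirement can be illustrated without loading the model. This is a minimal sketch under stated assumptions: `frame_a` and `frame_b` are hypothetical placeholder arrays standing in for two normalized frames, and only the resulting shapes are of interest.

```python
import numpy as np

# Hypothetical stand-ins for two normalized frames of shape (height, width, 3).
frame_a = np.zeros((4, 6, 3), dtype=np.float32)
frame_b = np.ones((4, 6, 3), dtype=np.float32)
time = np.array([0.5], dtype=np.float32)  # shape (1,)

# np.expand_dims(..., axis=0) prepends a batch dimension of size 1.
batched = {
    'time': np.expand_dims(time, axis=0),   # shape (1, 1)
    'x0': np.expand_dims(frame_a, axis=0),  # shape (1, 4, 6, 3)
    'x1': np.expand_dims(frame_b, axis=0),  # shape (1, 4, 6, 3)
}

for key, value in batched.items():
  print(key, value.shape, value.dtype)
```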
# using images from the FILM repository (https://github.com/google-research/frame-interpolation/)
image_1_url = "https://github.com/google-research/frame-interpolation/blob/main/photos/one.png?raw=true"
image_2_url = "https://github.com/google-research/frame-interpolation/blob/main/photos/two.png?raw=true"
time = np.array([0.5], dtype=np.float32)
image1 = load_image(image_1_url)
image2 = load_image(image_2_url)
input = {
    'time': np.expand_dims(time, axis=0),  # adding the batch dimension to the time
    'x0': np.expand_dims(image1, axis=0),  # adding the batch dimension to the image
    'x1': np.expand_dims(image2, axis=0)   # adding the batch dimension to the image
}
mid_frame = model(input)
The model outputs several results, but the one you will use here is the image key, whose value is the interpolated frame.
print(mid_frame.keys())
dict_keys(['image', 'x0_warped', 'forward_residual_flow_pyramid', 'forward_flow_pyramid', 'x1_warped', 'backward_flow_pyramid', 'backward_residual_flow_pyramid'])
frames = [image1, mid_frame['image'][0].numpy(), image2]
media.show_images(frames, titles=['input image one', 'generated image', 'input image two'], height=250)
Create a video from the generated frames
media.show_video(frames, fps=3, title='FILM interpolated video')
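Define a frame interpolator library
As you can see, the transition is not very smooth.
To improve that, you need many more interpolated frames.
You could run the model several times on intermediate images by hand, but there is a better solution.
To generate many interpolated images and obtain a smoother video, you can create an interpolator library.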
"""A wrapper class for running a frame interpolation based on the FILM model on TFHub

Usage:
  interpolator = Interpolator()
  result_batch = interpolator(image_batch_0, image_batch_1, batch_dt)
  Where image_batch_0 and image_batch_1 are numpy tensors with TF standard
  (B,H,W,C) layout, batch_dt is the sub-frame time in range [0..1], (B,) layout.
"""
def _pad_to_align(x, align):
  """Pads image batch x so width and height divide by align.

  Args:
    x: Image batch to align.
    align: Number to align to.

  Returns:
    1) An image padded so width % align == 0 and height % align == 0.
    2) A bounding box that can be fed readily to tf.image.crop_to_bounding_box
      to undo the padding.
  """
  # Input checking.
  assert np.ndim(x) == 4
  assert align > 0, 'align must be a positive number.'

  height, width = x.shape[-3:-1]
  height_to_pad = (align - height % align) if height % align != 0 else 0
  width_to_pad = (align - width % align) if width % align != 0 else 0

  bbox_to_pad = {
      'offset_height': height_to_pad // 2,
      'offset_width': width_to_pad // 2,
      'target_height': height + height_to_pad,
      'target_width': width + width_to_pad
  }
  padded_x = tf.image.pad_to_bounding_box(x, **bbox_to_pad)
  bbox_to_crop = {
      'offset_height': height_to_pad // 2,
      'offset_width': width_to_pad // 2,
      'target_height': height,
      'target_width': width
  }
  return padded_x, bbox_to_crop
class Interpolator:
  """A class for generating interpolated frames between two input frames.

  Uses the Film model from TFHub
  """

  def __init__(self, align: int = 64) -> None:
    """Loads a saved model.

    Args:
      align: 'If >1, pad the input size so it divides with this before
        inference.'
    """
    self._model = hub.load("https://tfhub.dev/google/film/1")
    self._align = align

  def __call__(self, x0: np.ndarray, x1: np.ndarray,
               dt: np.ndarray) -> np.ndarray:
    """Generates an interpolated frame between given two batches of frames.

    All inputs should be np.float32 datatype.

    Args:
      x0: First image batch. Dimensions: (batch_size, height, width, channels)
      x1: Second image batch. Dimensions: (batch_size, height, width, channels)
      dt: Sub-frame time. Range [0,1]. Dimensions: (batch_size,)

    Returns:
      The result with dimensions (batch_size, height, width, channels).
    """
    if self._align is not None:
      x0, bbox_to_crop = _pad_to_align(x0, self._align)
      x1, _ = _pad_to_align(x1, self._align)

    inputs = {'x0': x0, 'x1': x1, 'time': dt[..., np.newaxis]}
    result = self._model(inputs, training=False)
    image = result['image']

    if self._align is not None:
      image = tf.image.crop_to_bounding_box(image, **bbox_to_crop)
    return image.numpy()
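The padding arithmetic used by _pad_to_align can be checked on its own, without TensorFlow. The sketch below mirrors the height and width computation for a single dimension; `pad_amounts` is a hypothetical helper introduced only for this illustration and is not part of the FILM code.

```python
from typing import Tuple


def pad_amounts(size: int, align: int) -> Tuple[int, int]:
  """Returns (total_padding, leading_offset) for one image dimension.

  Mirrors the arithmetic in _pad_to_align: pad up to the next multiple of
  align, and place half of the padding (rounded down) before the content.
  """
  remainder = size % align
  to_pad = (align - remainder) if remainder != 0 else 0
  return to_pad, to_pad // 2


for size in (250, 256, 1080):
  to_pad, offset = pad_amounts(size, 64)
  print(f'{size} -> {size + to_pad} (offset {offset})')
```

With the default align of 64, a dimension that is already a multiple of 64 is left untouched, and the same offset is later reused to crop the model output back to the original size.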
Frame and video generation utility functions
def _recursive_generator(
    frame1: np.ndarray, frame2: np.ndarray, num_recursions: int,
    interpolator: Interpolator) -> Generator[np.ndarray, None, None]:
  """Splits halfway to repeatedly generate more frames.

  Args:
    frame1: Input image 1.
    frame2: Input image 2.
    num_recursions: How many times to interpolate the consecutive image pairs.
    interpolator: The frame interpolator instance.

  Yields:
    The interpolated frames, including the first frame (frame1), but excluding
    the final frame2.
  """
  if num_recursions == 0:
    yield frame1
  else:
    # Adds the batch dimension to all inputs before calling the interpolator,
    # and remove it afterwards.
    time = np.full(shape=(1,), fill_value=0.5, dtype=np.float32)
    mid_frame = interpolator(
        np.expand_dims(frame1, axis=0), np.expand_dims(frame2, axis=0), time)[0]
    yield from _recursive_generator(frame1, mid_frame, num_recursions - 1,
                                    interpolator)
    yield from _recursive_generator(mid_frame, frame2, num_recursions - 1,
                                    interpolator)
def interpolate_recursively(
    frames: List[np.ndarray], num_recursions: int,
    interpolator: Interpolator) -> Iterable[np.ndarray]:
  """Generates interpolated frames by repeatedly interpolating the midpoint.

  Args:
    frames: List of input frames. Expected shape (H, W, 3). The colors should be
      in the range [0, 1] and in gamma space.
    num_recursions: Number of times to do recursive midpoint
      interpolation.
    interpolator: The frame interpolation model to use.

  Yields:
    The interpolated frames (including the inputs).
  """
  n = len(frames)
  for i in range(1, n):
    # Use the num_recursions argument rather than the global
    # times_to_interpolate, so the function honors its own parameter.
    yield from _recursive_generator(frames[i - 1], frames[i],
                                    num_recursions, interpolator)
  # Separately yield the final frame.
  yield frames[-1]
times_to_interpolate = 6
interpolator = Interpolator()
Running the interpolator
input_frames = [image1, image2]
frames = list(
    interpolate_recursively(input_frames, times_to_interpolate,
                            interpolator))
print(f'video with {len(frames)} frames')
media.show_video(frames, fps=30, title='FILM interpolated video')
video with 65 frames
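The frame count reported above follows from the recursion: each input pair is split in half num_recursions times, giving 2**num_recursions frames per pair plus the final frame. The sketch below reproduces that structure with plain time stamps instead of images, so it runs without the model; `midpoint_times` and `all_times` are hypothetical helpers written for this illustration only.

```python
from typing import Iterator, List


def midpoint_times(t0: float, t1: float, num_recursions: int) -> Iterator[float]:
  """Yields time stamps the way _recursive_generator yields frames.

  The start time t0 is included and the end time t1 is excluded.
  """
  if num_recursions == 0:
    yield t0
  else:
    mid = (t0 + t1) / 2.0
    yield from midpoint_times(t0, mid, num_recursions - 1)
    yield from midpoint_times(mid, t1, num_recursions - 1)


def all_times(num_inputs: int, num_recursions: int) -> List[float]:
  """Collects the time stamps for every consecutive pair, plus the last one."""
  times: List[float] = []
  for i in range(1, num_inputs):
    times.extend(midpoint_times(float(i - 1), float(i), num_recursions))
  times.append(float(num_inputs - 1))
  return times


times = all_times(num_inputs=2, num_recursions=6)
print(len(times))  # 65, i.e. (2 - 1) * 2**6 + 1
print(times[:3])
```

In general, n input frames and r recursions yield (n - 1) * 2**r + 1 frames, so each extra recursion roughly doubles both the video length and the number of model calls.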
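For more information, you can visit FILM's model repository.
Citation
If you find this model and code useful in your work, please acknowledge it appropriately by citing: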
@inproceedings{reda2022film,
title = {FILM: Frame Interpolation for Large Motion},
author = {Fitsum Reda and Janne Kontkanen and Eric Tabellion and Deqing Sun and Caroline Pantofaru and Brian Curless},
booktitle = {The European Conference on Computer Vision (ECCV)},
year = {2022}
}
@misc{film-tf,
title = {Tensorflow 2 Implementation of "FILM: Frame Interpolation for Large Motion"},
author = {Fitsum Reda and Janne Kontkanen and Eric Tabellion and Deqing Sun and Caroline Pantofaru and Brian Curless},
year = {2022},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/google-research/frame-interpolation} }
}