tfr.keras.pipeline.MultiLabelDatasetBuilder
Builds datasets for multi-task training.
Inherits From: BaseDatasetBuilder, AbstractDatasetBuilder
tfr.keras.pipeline.MultiLabelDatasetBuilder(
    context_feature_spec: Dict[str, Union[tf.io.FixedLenFeature, tf.io.VarLenFeature, tf.io.RaggedFeature]],
    example_feature_spec: Dict[str, Union[tf.io.FixedLenFeature, tf.io.VarLenFeature, tf.io.RaggedFeature]],
    mask_feature_name: str,
    label_spec: Dict[str, Tuple[str, tf.io.FixedLenFeature]],
    hparams: tfr.keras.pipeline.DatasetHparams,
    sample_weight_spec: Optional[Tuple[str, tf.io.FixedLenFeature]] = None
)
This supports a single dataset with multiple labels provided as a dict. The
case of multiple datasets is not yet handled by the current code. The dataset
builder may be extended when that use case arises.
Example usage:
import tensorflow as tf
import tensorflow_ranking as tfr

# Assumed padding label value for illustration; use the value your pipeline expects.
_PADDING_LABEL = -1.0

context_feature_spec = {}
example_feature_spec = {
    "example_feature_1": tf.io.FixedLenFeature(
        shape=(1,), dtype=tf.float32, default_value=0.0)
}
mask_feature_name = "list_mask"
# Each label spec is a (label feature name, feature spec) tuple.
label_spec_tuple = ("utility",
                    tf.io.FixedLenFeature(
                        shape=(1,),
                        dtype=tf.float32,
                        default_value=_PADDING_LABEL))
# Both tasks read the same "utility" label feature here.
label_spec = {"task1": label_spec_tuple, "task2": label_spec_tuple}
weight_spec = ("weight",
               tf.io.FixedLenFeature(
                   shape=(1,), dtype=tf.float32, default_value=1.))
dataset_hparams = tfr.keras.pipeline.DatasetHparams(
    train_input_pattern="train.dat",
    valid_input_pattern="valid.dat",
    train_batch_size=128,
    valid_batch_size=128)
dataset_builder = tfr.keras.pipeline.MultiLabelDatasetBuilder(
    context_feature_spec,
    example_feature_spec,
    mask_feature_name,
    label_spec,
    dataset_hparams,
    sample_weight_spec=weight_spec)
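To make the shape of label_spec more concrete, the following is a minimal, library-free sketch of how one parsed feature dictionary could be split into a label dict keyed by task name. The helper split_features_and_labels and the plain-list feature values are hypothetical and used only for illustration; this is not the builder's actual implementation.

```python
from typing import Any, Dict, Tuple

# Hypothetical stand-in for label_spec: task name -> (label feature name, spec).
# The spec is replaced by None because this sketch does no parsing.
label_spec: Dict[str, Tuple[str, Any]] = {
    "task1": ("utility", None),
    "task2": ("utility", None),
}


def split_features_and_labels(parsed, label_spec):
    """Returns (features, labels), with labels keyed by task name."""
    label_names = {name for name, _ in label_spec.values()}
    # One label entry per task; tasks may share the same label feature.
    labels = {task: parsed[name] for task, (name, _) in label_spec.items()}
    # Everything that is not a label stays in the feature dict.
    features = {k: v for k, v in parsed.items() if k not in label_names}
    return features, labels


# Toy "parsed" batch: one list with two examples.
parsed = {
    "example_feature_1": [[0.3], [0.7]],
    "utility": [[1.0], [0.0]],
}
features, labels = split_features_and_labels(parsed, label_spec)
print(sorted(labels))
```

A model trained on such data would typically expose one output per task name, so that each output can be matched to its entry in the label dict.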
Args
context_feature_spec | Maps context (aka, query) names to feature specs.
example_feature_spec | Maps example (aka, document) names to feature specs.
mask_feature_name | If set, populates the feature dictionary with this name; the corresponding value is a tf.bool Tensor of shape [batch_size, list_size] indicating whether each example is an actual example or padding.
label_spec | A dict that maps task names to label specs. Each label spec has a label name and a tf.io.FixedLenFeature spec.
hparams | A tfr.keras.pipeline.DatasetHparams object containing dataset hyperparameters.
sample_weight_spec | Optional tuple of a weight feature name and its tf.io.FixedLenFeature spec, used for per-example weights.
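The mask described for mask_feature_name can be pictured with a small, library-free sketch that pads variable-length lists to a fixed list_size and records which slots hold actual examples. The helper pad_with_mask is hypothetical, and both the truncation of overlong lists and the convention that True marks an actual example are assumptions made for this illustration.

```python
def pad_with_mask(lists, list_size, padding_value=0.0):
    """Pads each list to list_size and returns (padded, mask).

    Assumption: mask is True for actual examples and False for padded slots.
    """
    padded, mask = [], []
    for values in lists:
        kept = values[:list_size]  # assumed: truncate lists longer than list_size
        n_pad = list_size - len(kept)
        padded.append(kept + [padding_value] * n_pad)
        mask.append([True] * len(kept) + [False] * n_pad)
    return padded, mask


# Two lists (batch_size=2) padded to list_size=3.
padded, mask = pad_with_mask([[0.3, 0.7], [0.1]], list_size=3)
print(mask)  # [[True, True, False], [True, False, False]]
```

The resulting mask has shape [batch_size, list_size], matching the shape stated in the argument description.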
Methods
build_signatures
build_signatures(
    model: tf.keras.Model
) -> Any
See AbstractDatasetBuilder.
build_train_dataset
build_train_dataset() -> tf.data.Dataset
See AbstractDatasetBuilder.
build_valid_dataset
build_valid_dataset() -> tf.data.Dataset
See AbstractDatasetBuilder.
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2023-08-18 UTC.