MinDiffKernel abstract base class.
model_remediation.min_diff.losses.MinDiffKernel(
tile_input: bool = True
)
| Arguments | |
|---|---|
| `tile_input` | Boolean indicating whether to tile inputs before computing the kernel (see below for details). |
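To be implemented by subclasses:

`call()`: contains the logic for the kernel tensor calculation.

Example subclass implementation:

    import tensorflow as tf
    from tensorflow_model_remediation.min_diff.losses import MinDiffKernel

    class GaussKernel(MinDiffKernel):

        def call(self, x, y):
            # Gaussian kernel over the last (feature) axis: [N, M, D] -> [N, M].
            return tf.exp(-tf.reduce_sum(tf.square(x - y), axis=2) / 0.01)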
"Tiling" is a way of expanding the rank of the input tensors so that their dimensions work for the operations we need.
If `x` and `y` have shapes `[N, D]` and `[M, D]` respectively, tiling expands them to shapes `[N, ?, D]` and `[?, M, D]`, so that TensorFlow broadcasting makes elementwise operations between them produce tensors of shape `[N, M, D]`.
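As a rough illustration of the shape bookkeeping, the sketch below uses NumPy in place of TensorFlow, since NumPy's broadcasting rules behave the same way for this case. It assumes the `?` axes are size-1 axes inserted with `np.newaxis` (the library may tile differently internally); the array sizes are arbitrary and the 0.01 scale is taken from the example kernel above.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((4, 3))  # shape [N, D] with N=4, D=3
y = rng.random((5, 3))  # shape [M, D] with M=5

# Insert size-1 axes so the inputs line up for broadcasting
# (assumption: these stand in for the "?" dimensions in the text).
x_tiled = x[:, np.newaxis, :]  # [N, 1, D]
y_tiled = y[np.newaxis, :, :]  # [1, M, D]

# Elementwise ops broadcast to [N, M, D]; reducing over D gives [N, M].
diff = x_tiled - y_tiled
kernel = np.exp(-np.sum(np.square(diff), axis=2) / 0.01)

print(diff.shape)    # (4, 5, 3)
print(kernel.shape)  # (4, 5)
```

Entry `kernel[i, j]` compares row `i` of `x` with row `j` of `y`, which is exactly the pairwise `[N, M]` result that `call()` is expected to return.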
Methods
call
@abc.abstractmethod
call(
    x: types.TensorType, y: types.TensorType
)
Invokes the MinDiffKernel instance.
| Arguments | |
|---|---|
| `x` | `tf.Tensor` of shape `[N, M, D]`. |
| `y` | `tf.Tensor` of shape `[N, M, D]`. |
This method contains the logic for computing the kernel. It must be implemented by subclasses.
| Returns | |
|---|---|
| `tf.Tensor` of shape `[N, M]`. | |
from_config
@classmethod
from_config(
    config
)
Creates a MinDiffKernel instance from the config.
Any subclass with additional attributes or a different initialization
signature will need to override this method or get_config.
| Returns | |
|---|---|
| A new `MinDiffKernel` instance corresponding to `config`. | |
get_config
get_config()
Creates a config dictionary for the MinDiffKernel instance.
Any subclass with additional attributes will need to override this method.
When doing so, users will most likely want to first call `super().get_config()` and then add their own attributes to the returned dictionary.
| Returns | |
|---|---|
| A config dictionary for the `MinDiffKernel` instance. | |
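The override pattern described for `get_config` and `from_config` can be sketched without TensorFlow. The classes below are hypothetical stand-ins: `KernelBase` assumes the base config holds only `tile_input` and that the default `from_config` simply calls `cls(**config)` (the usual Keras convention, not confirmed by this page), and `CustomKernel` with its `kernel_length` attribute is purely illustrative.

```python
from typing import Any, Dict


class KernelBase:
    """Hypothetical stand-in for MinDiffKernel's config handling."""

    def __init__(self, tile_input: bool = True):
        self.tile_input = tile_input

    def get_config(self) -> Dict[str, Any]:
        # Assumed base config: just the constructor argument.
        return {"tile_input": self.tile_input}

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "KernelBase":
        # Assumed default: feed the config back into the constructor.
        return cls(**config)


class CustomKernel(KernelBase):
    """Hypothetical subclass with one extra attribute, kernel_length."""

    def __init__(self, kernel_length: float = 0.1, tile_input: bool = True):
        super().__init__(tile_input=tile_input)
        self.kernel_length = kernel_length

    def get_config(self) -> Dict[str, Any]:
        # Start from the base config, then add this subclass's attributes.
        config = super().get_config()
        config["kernel_length"] = self.kernel_length
        return config


original = CustomKernel(kernel_length=0.5, tile_input=False)
config = original.get_config()
restored = CustomKernel.from_config(config)
print(config)  # {'tile_input': False, 'kernel_length': 0.5}
```

Because the keys returned by `get_config` match the subclass's `__init__` parameters, the inherited `from_config` can rebuild the object unchanged; a subclass whose constructor arguments differ from its config keys would also need to override `from_config`.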
__call__
__call__(
x: types.TensorType, y: Optional[types.TensorType] = None
) -> types.TensorType
Invokes the kernel instance.
| Arguments | |
|---|---|
| `x` | `tf.Tensor` of shape `[N, D]` (if tiling input) or `[N, M, D]` (if not tiling input). |
| `y` | Optional `tf.Tensor` of shape `[M, D]` (if tiling input) or `[N, M, D]` (if not tiling input). |
If `y` is `None`, it is set to be the same as `x`:

    if y is None:
        y = x

Inputs are tiled if `self.tile_input` is `True` and left as-is otherwise.
| Returns | |
|---|---|
| `tf.Tensor` of shape `[N, M]`. | |
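To make the dispatch in `__call__` concrete, here is a minimal pure-Python sketch. `KernelSketch` and `GaussKernelSketch` are hypothetical names, nested lists stand in for tensors, and the tiling step copies rows explicitly instead of relying on broadcasting. It only mirrors the behavior described above (defaulting `y` to `x`, tiling when `tile_input` is true, then delegating to `call()`) and is not the library's source.

```python
import abc
import math
from typing import List

Matrix = List[List[float]]         # e.g. an [N, D] input or the [N, M] result
Tensor3 = List[List[List[float]]]  # [N, M, D]


class KernelSketch(abc.ABC):
    """Hypothetical stand-in mirroring the documented __call__ behavior."""

    def __init__(self, tile_input: bool = True):
        self.tile_input = tile_input

    def __call__(self, x, y=None):
        # If y is None, the kernel compares x against itself.
        if y is None:
            y = x
        if self.tile_input:
            # [N, D] and [M, D] -> two [N, M, D] inputs by explicit copying
            # (TensorFlow would broadcast instead of materializing copies).
            n, m = len(x), len(y)
            x = [[x[i] for _ in range(m)] for i in range(n)]
            y = [[y[j] for j in range(m)] for _ in range(n)]
        return self.call(x, y)

    @abc.abstractmethod
    def call(self, x: Tensor3, y: Tensor3) -> Matrix:
        """Maps two [N, M, D] inputs to an [N, M] kernel matrix."""


class GaussKernelSketch(KernelSketch):
    """Gaussian kernel with the fixed 0.01 scale from the example above."""

    def call(self, x: Tensor3, y: Tensor3) -> Matrix:
        return [
            [math.exp(-sum((a - b) ** 2 for a, b in zip(xv, yv)) / 0.01)
             for xv, yv in zip(x_row, y_row)]
            for x_row, y_row in zip(x, y)
        ]


points = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]  # [N, D] with N=3, D=2
k = GaussKernelSketch()(points)  # y defaults to x, so k is [3, 3]
# Diagonal entries are exp(0) = 1.0.
```

With `tile_input=False`, the caller is responsible for passing inputs that are already `[N, M, D]`, for example `GaussKernelSketch(tile_input=False)([[[0.0], [0.0]]], [[[0.0], [1.0]]])`, which returns a `[1, 2]` result.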