/
op.h
326 lines (275 loc) · 12 KB
/
op.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_H_
#define TENSORFLOW_CORE_FRAMEWORK_OP_H_
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/full_type_inference_util.h" // IWYU pragma: export
#include "tensorflow/core/framework/full_type_util.h" // IWYU pragma: export
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h" // IWYU pragma: export
#include "tensorflow/core/framework/registration/registration.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
// Read-only view of a set of registered ops, indexed by op type name.
//
// Code that only needs to resolve an op type name to its OpDef (or to its
// full OpRegistrationData) should accept an OpRegistryInterface rather than a
// concrete registry. LookUp() may be called from multiple threads through a
// (const) OpRegistryInterface*.
class OpRegistryInterface {
 public:
  virtual ~OpRegistryInterface() = default;

  // Resolves `op_type_name` to its registration data.
  //
  // When no op is registered under that name, returns an error status and
  // sets *op_reg_data to nullptr; otherwise *op_reg_data is set to the
  // registered data. The pointer remains owned by the registry: callers must
  // not delete it.
  virtual Status LookUp(const std::string& op_type_name,
                        const OpRegistrationData** op_reg_data) const = 0;

  // Convenience wrapper around LookUp() that yields only the OpDef.
  Status LookUpOpDef(const std::string& op_type_name,
                     const OpDef** op_def) const;
};
// The standard implementation of OpRegistryInterface, along with a
// global singleton used for registering ops via the REGISTER
// macros below. Thread-safe.
//
// Example registration:
// OpRegistry::Global()->Register(
// [](OpRegistrationData* op_reg_data)->Status {
// // Populate *op_reg_data here.
// return OkStatus();
// });
class OpRegistry : public OpRegistryInterface {
public:
typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;
OpRegistry();
void Register(const OpRegistrationDataFactory& op_data_factory);
Status LookUp(const std::string& op_type_name,
const OpRegistrationData** op_reg_data) const override;
// Returns OpRegistrationData* of registered op type, else returns nullptr.
const OpRegistrationData* LookUp(const std::string& op_type_name) const;
// Fills *ops with all registered OpDefs (except those with names
// starting with '_' if include_internal == false) sorted in
// ascending alphabetical order.
void Export(bool include_internal, OpList* ops) const;
// Returns ASCII-format OpList for all registered OpDefs (except
// those with names starting with '_' if include_internal == false).
std::string DebugString(bool include_internal) const;
// A singleton available at startup.
static OpRegistry* Global();
// Get all registered ops.
void GetRegisteredOps(std::vector<OpDef>* op_defs);
// Get all `OpRegistrationData`s.
void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data);
// Registers a function that validates op registry.
void RegisterValidator(
std::function<Status(const OpRegistryInterface&)> validator) {
op_registry_validator_ = std::move(validator);
}
// Watcher, a function object.
// The watcher, if set by SetWatcher(), is called every time an op is
// registered via the Register function. The watcher is passed the Status
// obtained from building and adding the OpDef to the registry, and the OpDef
// itself if it was successfully built. A watcher returns a Status which is in
// turn returned as the final registration status.
typedef std::function<Status(const Status&, const OpDef&)> Watcher;
// An OpRegistry object has only one watcher. This interface is not thread
// safe, as different clients are free to set the watcher any time.
// Clients are expected to atomically perform the following sequence of
// operations :
// SetWatcher(a_watcher);
// Register some ops;
// op_registry->ProcessRegistrations();
// SetWatcher(nullptr);
// Returns a non-OK status if a non-null watcher is over-written by another
// non-null watcher.
Status SetWatcher(const Watcher& watcher);
// Process the current list of deferred registrations. Note that calls to
// Export, LookUp and DebugString would also implicitly process the deferred
// registrations. Returns the status of the first failed op registration or
// OkStatus() otherwise.
Status ProcessRegistrations() const;
// Defer the registrations until a later call to a function that processes
// deferred registrations are made. Normally, registrations that happen after
// calls to Export, LookUp, ProcessRegistrations and DebugString are processed
// immediately. Call this to defer future registrations.
void DeferRegistrations();
// Clear the registrations that have been deferred.
void ClearDeferredRegistrations();
private:
// Ensures that all the functions in deferred_ get called, their OpDef's
// registered, and returns with deferred_ empty. Returns true the first
// time it is called. Prints a fatal log if any op registration fails.
bool MustCallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Calls the functions in deferred_ and registers their OpDef's
// It returns the Status of the first failed op registration or OkStatus()
// otherwise.
Status CallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Add 'def' to the registry with additional data 'data'. On failure, or if
// there is already an OpDef with that name registered, returns a non-okay
// status.
Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory)
const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
const OpRegistrationData* LookUpSlow(const std::string& op_type_name) const;
mutable mutex mu_;
// Functions in deferred_ may only be called with mu_ held.
mutable std::vector<OpRegistrationDataFactory> deferred_ TF_GUARDED_BY(mu_);
// Values are owned.
mutable absl::flat_hash_map<string, std::unique_ptr<const OpRegistrationData>>
registry_ TF_GUARDED_BY(mu_);
mutable bool initialized_ TF_GUARDED_BY(mu_);
// Registry watcher.
mutable Watcher watcher_ TF_GUARDED_BY(mu_);
std::function<Status(const OpRegistryInterface&)> op_registry_validator_;
};
// Adapts an OpList so it can be used wherever an OpRegistryInterface is
// expected.
//
// OpListOpRegistry is never handed shape inference functions, so every op
// it serves comes with an unusable shape inference function. Use it only
// where that limitation is acceptable.
class OpListOpRegistry : public OpRegistryInterface {
 public:
  // `op_list` is borrowed, not owned: *op_list must outlive this object.
  explicit OpListOpRegistry(const OpList* op_list);

  Status LookUp(const std::string& op_type_name,
                const OpRegistrationData** op_reg_data) const override;

  // Returns the OpRegistrationData for `op_type_name` when the list contains
  // that op, and nullptr when it does not.
  const OpRegistrationData* LookUp(const std::string& op_type_name) const;

 private:
  // Lookup table (presumably keyed by op type name); this object owns the
  // mapped values.
  absl::flat_hash_map<string, std::unique_ptr<const OpRegistrationData>> index_;
};
// Support for defining an op's OpDef (the op's semantics and how it is
// created) and registering it in the OpRegistry::Global() registry. Usage:
//
//   REGISTER_OP("my_op_name")
//       .Attr("<name>:<type>")
//       .Attr("<name>:<type>=<default>")
//       .Input("<name>:<type-expr>")
//       .Input("<name>:Ref(<type-expr>)")
//       .Output("<name>:<type-expr>")
//       .Doc(R"(
//   <1-line summary>
//   <rest of the description (potentially many lines)>
//   <name-of-attr-input-or-output>: <description of name>
//   <name-of-attr-input-or-output>: <description of name;
//     if long, indent the description on subsequent lines>
//   )");
//
// Keep .Doc() as the last call in the chain. The spec syntax is described by
// the OpDefBuilder class in op_def_builder.h.
namespace register_op {

// Chainable front end for an OpDefBuilder. Every setter forwards to the
// wrapped builder and returns *this, so calls can be chained. operator() is
// only declared here. NOTE(review): it presumably hands the accumulated
// builder to the global registry; confirm in the out-of-line definition.
class OpDefBuilderWrapper {
 public:
  explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {}

  // ---- Attributes, inputs and outputs ------------------------------------
  // The std::string overloads take the spec by value and move it into the
  // builder. The const char* overloads copy the text into a std::string and
  // delegate. They are marked TF_ATTRIBUTE_NOINLINE, presumably to keep the
  // many REGISTER_OP expansions small.
  OpDefBuilderWrapper& Attr(std::string spec) {
    builder_.Attr(std::move(spec));
    return *this;
  }
  OpDefBuilderWrapper& Attr(const char* spec) TF_ATTRIBUTE_NOINLINE {
    std::string spec_str(spec);
    return Attr(std::move(spec_str));
  }

  OpDefBuilderWrapper& Input(std::string spec) {
    builder_.Input(std::move(spec));
    return *this;
  }
  OpDefBuilderWrapper& Input(const char* spec) TF_ATTRIBUTE_NOINLINE {
    std::string spec_str(spec);
    return Input(std::move(spec_str));
  }

  OpDefBuilderWrapper& Output(std::string spec) {
    builder_.Output(std::move(spec));
    return *this;
  }
  OpDefBuilderWrapper& Output(const char* spec) TF_ATTRIBUTE_NOINLINE {
    std::string spec_str(spec);
    return Output(std::move(spec_str));
  }

  // ---- Op properties -----------------------------------------------------
  OpDefBuilderWrapper& SetIsCommutative() {
    builder_.SetIsCommutative();
    return *this;
  }

  OpDefBuilderWrapper& SetIsAggregate() {
    builder_.SetIsAggregate();
    return *this;
  }

  OpDefBuilderWrapper& SetIsStateful() {
    builder_.SetIsStateful();
    return *this;
  }

  // There is no dedicated flag for disabling optimizations such as constant
  // folding and CSE, so this simply marks the op as stateful.
  OpDefBuilderWrapper& SetDoNotOptimize() { return SetIsStateful(); }

  OpDefBuilderWrapper& SetAllowsUninitializedInput() {
    builder_.SetAllowsUninitializedInput();
    return *this;
  }

  OpDefBuilderWrapper& SetIsDistributedCommunication() {
    builder_.SetIsDistributedCommunication();
    return *this;
  }

  // ---- Documentation and deprecation -------------------------------------
  OpDefBuilderWrapper& Deprecated(int version, std::string explanation) {
    builder_.Deprecated(version, std::move(explanation));
    return *this;
  }

  OpDefBuilderWrapper& Doc(std::string text) {
    builder_.Doc(std::move(text));
    return *this;
  }

  // ---- Shape and type inference hooks ------------------------------------
  OpDefBuilderWrapper& SetShapeFn(OpShapeInferenceFn fn) {
    builder_.SetShapeFn(std::move(fn));
    return *this;
  }

  OpDefBuilderWrapper& SetTypeConstructor(OpTypeConstructor fn) {
    builder_.SetTypeConstructor(std::move(fn));
    return *this;
  }

  OpDefBuilderWrapper& SetForwardTypeFn(TypeInferenceFn fn) {
    builder_.SetForwardTypeFn(std::move(fn));
    return *this;
  }

  OpDefBuilderWrapper& SetReverseTypeFn(int input_number, TypeInferenceFn fn) {
    builder_.SetReverseTypeFn(input_number, std::move(fn));
    return *this;
  }

  // Read-only access to the wrapped builder.
  const ::tensorflow::OpDefBuilder& builder() const { return builder_; }

  // Completes the chain; defined out of line.
  InitOnStartupMarker operator()();

 private:
  // NOTE(review): no const member visible in this header modifies builder_,
  // so `mutable` may be vestigial. Check the out-of-line definitions before
  // removing it.
  mutable ::tensorflow::OpDefBuilder builder_;
};

}  // namespace register_op
// Implementation detail shared by REGISTER_OP and REGISTER_SYSTEM_OP.
//
// Expands to the definition of a static const InitOnStartupMarker named
// register_op<ctr>. `ctr` comes from TF_NEW_ID_FOR_INIT (below), presumably a
// per-use unique id so several registrations in one translation unit get
// distinct variable names. The initializer streams a fresh
// OpDefBuilderWrapper for `name` into TF_INIT_ON_STARTUP_IF. The condition
// passed there is `is_system_op || SHOULD_REGISTER_OP(name)`, i.e. system
// ops bypass the selective-registration check. Because member access binds
// tighter than `<<`, builder calls chained after the macro (.Input(...),
// .Attr(...), ...) are applied to the wrapper before it is streamed.
#define REGISTER_OP_IMPL(ctr, name, is_system_op) \
static ::tensorflow::InitOnStartupMarker const register_op##ctr \
TF_ATTRIBUTE_UNUSED = \
TF_INIT_ON_STARTUP_IF(is_system_op || SHOULD_REGISTER_OP(name)) \
<< ::tensorflow::register_op::OpDefBuilderWrapper(name)
// Registers an op in OpRegistry::Global(), subject to selective registration
// (is_system_op is false). The "tf:op" annotation tags the definition;
// presumably it is consumed by tooling that scans for op registrations.
#define REGISTER_OP(name) \
TF_ATTRIBUTE_ANNOTATE("tf:op") \
TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, false)
// The `REGISTER_SYSTEM_OP()` macro acts as `REGISTER_OP()` except
// that the op is registered unconditionally even when selective
// registration is used. It also carries the extra "tf:op:system"
// annotation.
#define REGISTER_SYSTEM_OP(name) \
TF_ATTRIBUTE_ANNOTATE("tf:op") \
TF_ATTRIBUTE_ANNOTATE("tf:op:system") \
TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, true)
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_OP_H_