Skip to content

Commit ca8e327

Browse files
Revert "Deprecate gcs-config (#1024)"
This reverts commit 9702a15.
1 parent 59ddbc4 commit ca8e327

File tree

10 files changed

+642
-12
lines changed

10 files changed

+642
-12
lines changed

tensorflow_io/core/BUILD

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -744,18 +744,6 @@ cc_binary(
744744
}),
745745
)
746746

747-
cc_binary(
748-
name = "python/ops/libtensorflow_io_plugins.so",
749-
copts = tf_io_copts(),
750-
linkshared = 1,
751-
deps = select({
752-
"//tensorflow_io/core:static_build_on": [],
753-
"//conditions:default": [
754-
"//tensorflow_io/core/plugins:plugins",
755-
],
756-
}),
757-
)
758-
759747
cc_binary(
760748
name = "python/ops/libtensorflow_io_golang.so",
761749
copts = tf_io_copts(),

tensorflow_io/gcs/BUILD

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
licenses(["notice"]) # Apache 2.0
2+
3+
package(default_visibility = ["//visibility:public"])
4+
5+
load(
6+
"//:tools/build/tensorflow_io.bzl",
7+
"tf_io_copts",
8+
)
9+
10+
cc_library(
11+
name = "gcs_config_ops",
12+
srcs = [
13+
"kernels/gcs_config_op_kernels.cc",
14+
"ops/gcs_config_ops.cc",
15+
],
16+
copts = tf_io_copts(),
17+
linkstatic = True,
18+
deps = [
19+
"@curl",
20+
"@jsoncpp_git//:jsoncpp",
21+
"@local_config_tf//:libtensorflow_framework",
22+
"@local_config_tf//:tf_header_lib",
23+
],
24+
alwayslink = 1,
25+
)

tensorflow_io/gcs/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
## Cloud Storage (GCS) ##
2+
3+
The Google Cloud Storage ops allow the user to configure the GCS File System.

tensorflow_io/gcs/__init__.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Module for cloud ops."""
16+
17+
18+
from tensorflow.python.util.all_util import remove_undocumented
19+
20+
# pylint: disable=line-too-long,wildcard-import,g-import-not-at-top
21+
from tensorflow_io.gcs.python.ops.gcs_config_ops import *
22+
23+
_allowed_symbols = [
24+
"configure_colab_session",
25+
"configure_gcs",
26+
"BlockCacheParams",
27+
"ConfigureGcsHook",
28+
]
29+
remove_undocumented(__name__, _allowed_symbols)
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <sstream>
17+
18+
#include "include/json/json.h"
19+
#include "tensorflow/core/framework/op_kernel.h"
20+
#include "tensorflow/core/framework/tensor_shape.h"
21+
#include "tensorflow/core/platform/cloud/curl_http_request.h"
22+
#include "tensorflow/core/platform/cloud/gcs_file_system.h"
23+
#include "tensorflow/core/platform/cloud/oauth_client.h"
24+
#include "tensorflow/core/util/ptr_util.h"
25+
26+
namespace tensorflow {
27+
namespace {
28+
29+
// The default initial delay between retries with exponential backoff.
30+
constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec
31+
32+
// The minimum time delta between now and the token expiration time
33+
// for the token to be re-used.
34+
constexpr int kExpirationTimeMarginSec = 60;
35+
36+
// The URL to retrieve the auth bearer token via OAuth with a refresh token.
37+
constexpr char kOAuthV3Url[] = "https://www.googleapis.com/oauth2/v3/token";
38+
39+
// The URL to retrieve the auth bearer token via OAuth with a private key.
40+
constexpr char kOAuthV4Url[] = "https://www.googleapis.com/oauth2/v4/token";
41+
42+
// The authentication token scope to request.
43+
constexpr char kOAuthScope[] = "https://www.googleapis.com/auth/cloud-platform";
44+
45+
Status RetrieveGcsFs(OpKernelContext* ctx, RetryingGcsFileSystem** fs) {
46+
DCHECK(fs != nullptr);
47+
*fs = nullptr;
48+
49+
FileSystem* filesystem = nullptr;
50+
TF_RETURN_IF_ERROR(
51+
ctx->env()->GetFileSystemForFile("gs://fake/file.text", &filesystem));
52+
if (filesystem == nullptr) {
53+
return errors::FailedPrecondition("The GCS file system is not registered.");
54+
}
55+
56+
*fs = dynamic_cast<RetryingGcsFileSystem*>(filesystem);
57+
if (*fs == nullptr) {
58+
return errors::Internal(
59+
"The filesystem registered under the 'gs://' scheme was not a "
60+
"tensorflow::RetryingGcsFileSystem*.");
61+
}
62+
return Status::OK();
63+
}
64+
65+
template <typename T>
66+
Status ParseScalarArgument(OpKernelContext* ctx, StringPiece argument_name,
67+
T* output) {
68+
const Tensor* argument_t;
69+
TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
70+
if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
71+
return errors::InvalidArgument(argument_name, " must be a scalar");
72+
}
73+
*output = argument_t->scalar<T>()();
74+
return Status::OK();
75+
}
76+
77+
// GcsCredentialsOpKernel overrides the credentials used by the gcs_filesystem.
78+
class GcsCredentialsOpKernel : public OpKernel {
79+
public:
80+
explicit GcsCredentialsOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
81+
void Compute(OpKernelContext* ctx) override {
82+
// Get a handle to the GCS file system.
83+
RetryingGcsFileSystem* gcs = nullptr;
84+
OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs));
85+
86+
tstring json_string;
87+
OP_REQUIRES_OK(ctx,
88+
ParseScalarArgument<tstring>(ctx, "json", &json_string));
89+
90+
Json::Value json;
91+
Json::Reader reader;
92+
std::stringstream json_stream(json_string);
93+
OP_REQUIRES(ctx, reader.parse(json_stream, json),
94+
errors::InvalidArgument("Could not parse json: ", json_string));
95+
96+
OP_REQUIRES(
97+
ctx, json.isMember("refresh_token") || json.isMember("private_key"),
98+
errors::InvalidArgument("JSON format incompatible; did not find fields "
99+
"`refresh_token` or `private_key`."));
100+
101+
auto provider =
102+
tensorflow::MakeUnique<ConstantAuthProvider>(json, ctx->env());
103+
104+
// Test getting a token
105+
string dummy_token;
106+
OP_REQUIRES_OK(ctx, provider->GetToken(&dummy_token));
107+
OP_REQUIRES(ctx, !dummy_token.empty(),
108+
errors::InvalidArgument(
109+
"Could not retrieve a token with the given credentials."));
110+
111+
// Set the provider.
112+
gcs->underlying()->SetAuthProvider(std::move(provider));
113+
}
114+
115+
private:
116+
class ConstantAuthProvider : public AuthProvider {
117+
public:
118+
ConstantAuthProvider(const Json::Value& json,
119+
std::unique_ptr<OAuthClient> oauth_client, Env* env,
120+
int64 initial_retry_delay_usec)
121+
: json_(json),
122+
oauth_client_(std::move(oauth_client)),
123+
env_(env),
124+
initial_retry_delay_usec_(initial_retry_delay_usec) {}
125+
126+
ConstantAuthProvider(const Json::Value& json, Env* env)
127+
: ConstantAuthProvider(json, tensorflow::MakeUnique<OAuthClient>(), env,
128+
kInitialRetryDelayUsec) {}
129+
130+
~ConstantAuthProvider() override {}
131+
132+
Status GetToken(string* token) override {
133+
mutex_lock l(mu_);
134+
const uint64 now_sec = env_->NowSeconds();
135+
136+
if (!current_token_.empty() &&
137+
now_sec + kExpirationTimeMarginSec < expiration_timestamp_sec_) {
138+
*token = current_token_;
139+
return Status::OK();
140+
}
141+
if (json_.isMember("refresh_token")) {
142+
TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromRefreshTokenJson(
143+
json_, kOAuthV3Url, &current_token_, &expiration_timestamp_sec_));
144+
} else if (json_.isMember("private_key")) {
145+
TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromServiceAccountJson(
146+
json_, kOAuthV4Url, kOAuthScope, &current_token_,
147+
&expiration_timestamp_sec_));
148+
} else {
149+
return errors::FailedPrecondition(
150+
"Unexpected content of the JSON credentials file.");
151+
}
152+
153+
*token = current_token_;
154+
return Status::OK();
155+
}
156+
157+
private:
158+
Json::Value json_;
159+
std::unique_ptr<OAuthClient> oauth_client_;
160+
Env* env_;
161+
162+
mutex mu_;
163+
string current_token_ TF_GUARDED_BY(mu_);
164+
uint64 expiration_timestamp_sec_ TF_GUARDED_BY(mu_) = 0;
165+
166+
// The initial delay for exponential backoffs when retrying failed calls.
167+
const int64 initial_retry_delay_usec_;
168+
TF_DISALLOW_COPY_AND_ASSIGN(ConstantAuthProvider);
169+
};
170+
};
171+
172+
REGISTER_KERNEL_BUILDER(Name("IO>GcsConfigureCredentials").Device(DEVICE_CPU),
173+
GcsCredentialsOpKernel);
174+
175+
class GcsBlockCacheOpKernel : public OpKernel {
176+
public:
177+
explicit GcsBlockCacheOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
178+
void Compute(OpKernelContext* ctx) override {
179+
// Get a handle to the GCS file system.
180+
RetryingGcsFileSystem* gcs = nullptr;
181+
OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs));
182+
183+
size_t max_cache_size, block_size, max_staleness;
184+
OP_REQUIRES_OK(ctx, ParseScalarArgument<size_t>(ctx, "max_cache_size",
185+
&max_cache_size));
186+
OP_REQUIRES_OK(ctx,
187+
ParseScalarArgument<size_t>(ctx, "block_size", &block_size));
188+
OP_REQUIRES_OK(
189+
ctx, ParseScalarArgument<size_t>(ctx, "max_staleness", &max_staleness));
190+
191+
if (gcs->underlying()->block_size() == block_size &&
192+
gcs->underlying()->max_bytes() == max_cache_size &&
193+
gcs->underlying()->max_staleness() == max_staleness) {
194+
LOG(INFO) << "Skipping resetting the GCS block cache.";
195+
return;
196+
}
197+
gcs->underlying()->ResetFileBlockCache(block_size, max_cache_size,
198+
max_staleness);
199+
}
200+
};
201+
202+
REGISTER_KERNEL_BUILDER(Name("IO>GcsConfigureBlockCache").Device(DEVICE_CPU),
203+
GcsBlockCacheOpKernel);
204+
205+
} // namespace
206+
} // namespace tensorflow
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/core/framework/common_shape_fns.h"
17+
#include "tensorflow/core/framework/op.h"
18+
#include "tensorflow/core/framework/shape_inference.h"
19+
20+
namespace tensorflow {
21+
22+
using shape_inference::InferenceContext;
23+
24+
REGISTER_OP("IO>GcsConfigureCredentials")
25+
.Input("json: string")
26+
.SetShapeFn(shape_inference::NoOutputs)
27+
.Doc(R"doc(
28+
Configures the credentials used by the GCS client of the local TF runtime.
29+
The json input can be of the format:
30+
1. Refresh Token:
31+
{
32+
"client_id": "<redacted>",
33+
"client_secret": "<redacted>",
34+
"refresh_token: "<redacted>",
35+
"type": "authorized_user",
36+
}
37+
2. Service Account:
38+
{
39+
"type": "service_account",
40+
"project_id": "<redacted>",
41+
"private_key_id": "<redacted>",
42+
"private_key": "------BEGIN PRIVATE KEY-----\n<REDACTED>\n-----END PRIVATE KEY------\n",
43+
"client_email": "<REDACTED>@<REDACTED>.iam.gserviceaccount.com",
44+
"client_id": "<REDACTED>",
45+
# Some additional fields elided
46+
}
47+
Note the credentials established through this method are shared across all
48+
sessions run on this runtime.
49+
Note be sure to feed the inputs to this op to ensure the credentials are not
50+
stored in a constant op within the graph that might accidentally be checkpointed
51+
or in other ways be persisted or exfiltrated.
52+
)doc");
53+
54+
REGISTER_OP("IO>GcsConfigureBlockCache")
55+
.Input("max_cache_size: uint64")
56+
.Input("block_size: uint64")
57+
.Input("max_staleness: uint64")
58+
.SetShapeFn(shape_inference::NoOutputs)
59+
.Doc(R"doc(
60+
Re-configures the GCS block cache with the new configuration values.
61+
If the values are the same as already configured values, this op is a no-op. If
62+
they are different, the current contents of the block cache is dropped, and a
63+
new block cache is created fresh.
64+
)doc");
65+
66+
} // namespace tensorflow

tensorflow_io/gcs/python/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
"""This module contains Python API methods for GCS integration."""
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
"""This module contains the Python API methods for GCS integration."""

0 commit comments

Comments
 (0)