2
2
mirror of https://github.com/octoleo/restic.git synced 2024-06-01 08:30:49 +00:00
restic/vendor/github.com/Azure/go-autorest/autorest/azure/example/main.go
2017-10-22 10:07:36 +02:00

273 lines
7.7 KiB
Go

package main
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"crypto/rsa"
"crypto/x509"
"encoding/json"
"flag"
"fmt"
"io/ioutil"
"log"
"net/http"
"strings"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
"golang.org/x/crypto/pkcs12"
)
// resourceGroupURLTemplate is the Azure Resource Manager base URL to which the
// resource-group listing path is appended.
const resourceGroupURLTemplate = "https://management.azure.com"

// apiVersion is sent as the api-version query parameter when listing
// resource groups.
const apiVersion = "2015-01-01"

// nativeAppClientID is the client ID (called the `azkube` clientID in init's
// log output) used in device mode when no -applicationId is supplied.
const nativeAppClientID = "a87032a7-203c-4bf7-913c-44c50d23409a"

// resource is the resource identifier that tokens are requested for.
const resource = "https://management.core.windows.net/"
// Command-line configuration, filled in from flags by init. The flag that
// sets each variable is noted beside it.
var (
	// Authentication settings.
	mode            string // -mode: "device" or "certificate"
	certificatePath string // -certificatePath: PKCS#12 file used in certificate mode
	applicationID   string // -applicationId: client ID to authenticate as

	// Azure identity.
	tenantID       string // -tenantId
	subscriptionID string // -subscriptionId

	// Token handling.
	tokenCachePath string // -tokenCachePath: where tokens are loaded from and saved to
	forceRefresh   bool   // -forceRefresh: refresh the token and list groups a second time
	impatient      bool   // NOTE(review): not bound to any flag nor read anywhere visible
)
// init registers the command-line flags, parses them and validates the
// resulting configuration. Invalid usage terminates the process through
// log.Fatalln. In device mode an empty -applicationId is replaced with
// nativeAppClientID.
//
// NOTE(review): parsing flags inside init() is generally discouraged;
// consider moving flag.Parse and validation into main.
func init() {
	flag.StringVar(&mode, "mode", "device", "mode of operation for SPT creation")
	flag.StringVar(&certificatePath, "certificatePath", "", "path to pk12/pfx certificate")
	flag.StringVar(&applicationID, "applicationId", "", "application id")
	flag.StringVar(&tenantID, "tenantId", "", "tenant id")
	flag.StringVar(&subscriptionID, "subscriptionId", "", "subscription id")
	flag.StringVar(&tokenCachePath, "tokenCachePath", "", "location of oauth token cache")
	flag.BoolVar(&forceRefresh, "forceRefresh", false, "pass true to force a token refresh")
	flag.Parse()

	log.Printf("mode(%s) certPath(%s) appID(%s) tenantID(%s), subID(%s)\n",
		mode, certificatePath, applicationID, tenantID, subscriptionID)

	isBlank := func(s string) bool { return strings.TrimSpace(s) == "" }

	switch mode {
	case "device":
		if isBlank(applicationID) {
			log.Println("Using device mode auth. Will use `azkube` clientID since none was specified on the command line.")
			applicationID = nativeAppClientID
		}
	case "certificate":
		if isBlank(tenantID) || isBlank(subscriptionID) {
			log.Fatalln("Bad usage. Using certificate mode. Please specify tenantID, subscriptionID")
		}
		if isBlank(certificatePath) {
			log.Fatalln("Bad usage. Mode 'certificate' requires the 'certificatePath' argument.")
		}
	default:
		log.Fatalln("Bad usage. Mode must be one of 'certificate' or 'device'.")
	}

	// Applies to every mode. Certificate mode has no default application ID,
	// so the message must mention applicationId as well, or users who only
	// forgot -applicationId get a misleading error.
	if isBlank(tenantID) || isBlank(subscriptionID) || isBlank(applicationID) {
		log.Fatalln("Bad usage. Must specify the 'tenantId', 'subscriptionId' and 'applicationId'")
	}
}
// getSptFromCachedToken builds a service principal token from the token
// previously saved at tokenCachePath. The result is bound to oauthConfig,
// clientID and resource; callbacks are passed through as refresh callbacks.
// It returns an error if the cache cannot be loaded or if the token cannot be
// constructed from it. It never returns a nil token with a nil error.
func getSptFromCachedToken(oauthConfig adal.OAuthConfig, clientID, resource string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
	token, err := adal.LoadToken(tokenCachePath)
	if err != nil {
		return nil, fmt.Errorf("failed to load token from cache: %v", err)
	}
	spt, err := adal.NewServicePrincipalTokenFromManualToken(
		oauthConfig,
		clientID,
		resource,
		*token,
		callbacks...)
	if err != nil {
		// Previously discarded, which could hand callers (nil, nil).
		return nil, fmt.Errorf("failed to create token from cache: %v", err)
	}
	return spt, nil
}
// decodePkcs12 unpacks a PKCS#12 bundle protected by password. It returns the
// bundle's certificate together with its private key, which must be an RSA
// key. Decoding failures are returned unchanged. A non-RSA key yields a
// descriptive error.
func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) {
	key, cert, decodeErr := pkcs12.Decode(pkcs, password)
	if decodeErr != nil {
		return nil, nil, decodeErr
	}
	if rsaKey, ok := key.(*rsa.PrivateKey); ok {
		return cert, rsaKey, nil
	}
	return nil, nil, fmt.Errorf("PKCS#12 certificate must contain an RSA private key")
}
// getSptFromCertificate builds a service principal token that authenticates
// clientID for resource using the PKCS#12 certificate at certificatePath.
// The file is decoded with an empty password and must hold an RSA private
// key. callbacks are passed through as refresh callbacks. It returns an error
// if the file cannot be read or decoded, or if the token cannot be created.
func getSptFromCertificate(oauthConfig adal.OAuthConfig, clientID, resource, certificatePath string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
	// certificatePath is the parameter here. It shadows the package-level flag
	// of the same name, so the caller decides which file is read. The old
	// misspelled parameter was ignored in favour of the global.
	certData, err := ioutil.ReadFile(certificatePath)
	if err != nil {
		return nil, fmt.Errorf("failed to read the certificate file (%s): %v", certificatePath, err)
	}
	certificate, rsaPrivateKey, err := decodePkcs12(certData, "")
	if err != nil {
		return nil, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %v", err)
	}
	spt, err := adal.NewServicePrincipalTokenFromCertificate(
		oauthConfig,
		clientID,
		certificate,
		rsaPrivateKey,
		resource,
		callbacks...)
	if err != nil {
		return nil, fmt.Errorf("failed to create spt from certificate: %v", err)
	}
	return spt, nil
}
// getSptFromDeviceFlow runs the OAuth device-code flow for clientID and
// resource, using a fresh autorest.Client as the sender:
//  1. it starts the flow and prints the message returned by the service;
//  2. it waits (adal.WaitForUserCompletion) for the user to complete it;
//  3. it wraps the resulting token in a service principal token.
// callbacks are passed through as refresh callbacks.
func getSptFromDeviceFlow(oauthConfig adal.OAuthConfig, clientID, resource string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
	sender := &autorest.Client{}

	code, err := adal.InitiateDeviceAuth(sender, oauthConfig, clientID, resource)
	if err != nil {
		return nil, fmt.Errorf("failed to start device auth flow: %s", err)
	}

	// Show the service-provided prompt before blocking on the user.
	fmt.Println(*code.Message)

	tok, err := adal.WaitForUserCompletion(sender, code)
	if err != nil {
		return nil, fmt.Errorf("failed to finish device auth flow: %s", err)
	}

	spt, err := adal.NewServicePrincipalTokenFromManualToken(oauthConfig, clientID, resource, *tok, callbacks...)
	if err != nil {
		return nil, fmt.Errorf("failed to get oauth token from device flow: %v", err)
	}
	return spt, nil
}
// printResourceGroups lists the resource groups of subscriptionID through the
// Azure Resource Manager REST API. It logs their names on one comma-separated
// line. It returns an error if the request cannot be prepared or sent, if the
// service answers with a status other than 200 OK, or if the response body
// cannot be decoded.
func printResourceGroups(client *autorest.Client) error {
	p := map[string]interface{}{"subscription-id": subscriptionID}
	q := map[string]interface{}{"api-version": apiVersion}
	req, err := autorest.Prepare(&http.Request{},
		autorest.AsGet(),
		autorest.WithBaseURL(resourceGroupURLTemplate),
		autorest.WithPathParameters("/subscriptions/{subscription-id}/resourcegroups", p),
		autorest.WithQueryParameters(q))
	if err != nil {
		return fmt.Errorf("failed to prepare resource group request: %v", err)
	}
	resp, err := autorest.SendWithSender(client, req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// The decode target only has a "value" field. An error payload without
	// that field would decode into an empty list and look like success, so
	// reject non-OK responses up front.
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("failed to list resource groups: unexpected status %s", resp.Status)
	}

	var value struct {
		ResourceGroups []struct {
			Name string `json:"name"`
		} `json:"value"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&value); err != nil {
		return err
	}

	groupNames := make([]string, 0, len(value.ResourceGroups))
	for _, rg := range value.ResourceGroups {
		groupNames = append(groupNames, rg.Name)
	}
	log.Println("Groups:", strings.Join(groupNames, ", "))
	return nil
}
// saveToken writes spt to tokenCachePath with file mode 0600. It does nothing
// when no cache path is configured. Saving is best-effort: the outcome is
// logged rather than returned.
func saveToken(spt adal.Token) {
	if tokenCachePath == "" {
		return
	}
	if err := adal.SaveToken(tokenCachePath, 0600, spt); err != nil {
		log.Println("error saving token", err)
		return
	}
	log.Println("saved token to", tokenCachePath)
}
// main obtains a service principal token and uses it to list the
// subscription's resource groups. The token comes from the cache when
// -tokenCachePath is set and loading succeeds. Otherwise it is obtained via
// the selected -mode. With -forceRefresh the token is refreshed and the
// listing is repeated. Unrecoverable failures exit via log.Fatalln.
func main() {
	var spt *adal.ServicePrincipalToken

	// Persist every refreshed token so the cache stays current.
	callback := func(t adal.Token) error {
		log.Println("refresh callback was called")
		saveToken(t)
		return nil
	}

	oauthConfig, err := adal.NewOAuthConfig(azure.PublicCloud.ActiveDirectoryEndpoint, tenantID)
	if err != nil {
		log.Fatalln("failed to create oauth config:", err)
	}

	if tokenCachePath != "" {
		log.Println("tokenCachePath specified; attempting to load from", tokenCachePath)
		spt, err = getSptFromCachedToken(*oauthConfig, applicationID, resource, callback)
		if err != nil {
			spt = nil // just in case, this is the condition below
			log.Println("loading from cache failed:", err)
		}
	}

	if spt == nil {
		log.Println("authenticating via 'mode'", mode)
		switch mode {
		case "device":
			spt, err = getSptFromDeviceFlow(*oauthConfig, applicationID, resource, callback)
		case "certificate":
			spt, err = getSptFromCertificate(*oauthConfig, applicationID, resource, certificatePath, callback)
		default:
			err = fmt.Errorf("unsupported mode %q", mode)
		}
		// Guard against a helper returning (nil, nil). Otherwise spt.Token
		// below would panic.
		if err == nil && spt == nil {
			err = fmt.Errorf("no token returned for mode %q", mode)
		}
		if err != nil {
			log.Fatalln("failed to retrieve token:", err)
		}
		// Save right away, since Refresh (and thus the callback) won't be
		// called for some time.
		if tokenCachePath != "" {
			saveToken(spt.Token)
		}
	}

	client := &autorest.Client{}
	client.Authorizer = autorest.NewBearerAuthorizer(spt)
	if err := printResourceGroups(client); err != nil {
		log.Println("failed to list resource groups:", err)
	}

	if forceRefresh {
		if err := spt.Refresh(); err != nil {
			log.Fatalln("failed to refresh token:", err)
		}
		if err := printResourceGroups(client); err != nil {
			log.Println("failed to list resource groups after refresh:", err)
		}
	}
}